Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,5 @@ __pycache__/

# lockfile is updated by automation
MODULE.bazel.lock

.jetskicli
13 changes: 13 additions & 0 deletions lib/Dialect/JaxiteWord/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ cc_library(
name = "Dialect",
srcs = ["JaxiteWordDialect.cpp"],
hdrs = [
"JaxiteWordAttributes.h",
"JaxiteWordDialect.h",
"JaxiteWordOps.h",
"JaxiteWordTypes.h",
],
deps = [
":attributes_inc_gen",
":dialect_inc_gen",
":ops_inc_gen",
":types_inc_gen",
Expand All @@ -32,6 +34,7 @@ cc_library(
td_library(
name = "td_files",
srcs = [
"JaxiteWordAttributes.td",
"JaxiteWordDialect.td",
"JaxiteWordOps.td",
"JaxiteWordTypes.td",
Expand Down Expand Up @@ -63,6 +66,16 @@ add_heir_dialect_library(
],
)

add_heir_dialect_library(
name = "attributes_inc_gen",
dialect = "JaxiteWord",
kind = "attribute",
td_file = "JaxiteWordAttributes.td",
deps = [
":td_files",
],
)

add_heir_dialect_library(
name = "ops_inc_gen",
dialect = "JaxiteWord",
Expand Down
9 changes: 9 additions & 0 deletions lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#ifndef LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_H_
#define LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_H_

#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h"

#define GET_ATTRDEF_CLASSES
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h.inc"

#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_H_
32 changes: 32 additions & 0 deletions lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#ifndef LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_TD_
#define LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_TD_

include "JaxiteWordDialect.td"

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"

class JaxiteWord_Attribute<string attrName, string attrMnemonic>
: AttrDef<JaxiteWord_Dialect, attrName> {
let mnemonic = attrMnemonic;
let assemblyFormat = "`<` struct(params) `>`";
}

def JaxiteWord_CkksParameters : JaxiteWord_Attribute<"CkksParameters", "ckks_parameters"> {
let summary = "Jaxite CKKS parameters";
let description = [{
Parameters for Jaxite CKKS backend.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JianmingTONG Could you provide a quick description of the parameters that would be suitable to include in the description field here? Particularly, r, c, dnum, composite_degree, batch.

}];

let parameters = (ins
"DenseI64ArrayAttr":$q_towers,
"DenseI64ArrayAttr":$p_towers,
"int":$r,
"int":$c,
"int":$dnum,
"int":$composite_degree,
"int":$batch
);
}

#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_TD_
7 changes: 7 additions & 0 deletions lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h"

#include "lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h"
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.cpp.inc"
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h"
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.h"
Expand All @@ -8,6 +9,8 @@
#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

#define GET_ATTRDEF_CLASSES
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.cpp.inc"
#define GET_OP_CLASSES
Expand All @@ -18,6 +21,10 @@ namespace heir {
namespace jaxiteword {

void JaxiteWordDialect::initialize() {
addAttributes<
#define GET_ATTRDEF_LIST
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.cpp.inc"
>();
addTypes<
#define GET_TYPEDEF_LIST
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.cpp.inc"
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def JaxiteWord_Dialect : Dialect {
let cppNamespace = "::mlir::heir::jaxiteword";

let useDefaultTypePrinterParser = 1;
let useDefaultAttributePrinterParser = 1;
}

#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDDIALECT_H_
39 changes: 39 additions & 0 deletions lib/Dialect/JaxiteWord/Transforms/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
load("@heir//lib/Transforms:transforms.bzl", "add_heir_transforms")
load("@rules_cc//cc:cc_library.bzl", "cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "Transforms",
hdrs = ["Passes.h"],
deps = [
":JaxiteCkksParameterSelection",
":pass_inc_gen",
"@heir//lib/Dialect/JaxiteWord/IR:Dialect",
"@llvm-project//mlir:IR",
],
)

cc_library(
name = "JaxiteCkksParameterSelection",
srcs = ["JaxiteCkksParameterSelection.cpp"],
hdrs = ["JaxiteCkksParameterSelection.h"],
deps = [
":pass_inc_gen",
"@heir//lib/Dialect/CKKS/IR:Dialect",
"@heir//lib/Dialect/JaxiteWord/IR:Dialect",
"@heir//lib/Parameters:RLWEParams",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms",
],
)

add_heir_transforms(
header_filename = "Passes.h.inc",
pass_name = "JaxiteWord",
td_file = "Passes.td",
)
85 changes: 85 additions & 0 deletions lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#include "lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h"

#include "lib/Dialect/CKKS/IR/CKKSAttributes.h"
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h"
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h"
#include "lib/Parameters/RLWEParams.h"
#include "llvm/include/llvm/ADT/APInt.h" // from @llvm-project
#include "llvm/include/llvm/Support/MathExtras.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace jaxiteword {

#define GEN_PASS_DEF_JAXITECKKSPARAMETERSELECTION
#include "lib/Dialect/JaxiteWord/Transforms/Passes.h.inc"

struct JaxiteCkksParameterSelection
: impl::JaxiteCkksParameterSelectionBase<JaxiteCkksParameterSelection> {
using JaxiteCkksParameterSelectionBase::JaxiteCkksParameterSelectionBase;

void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();

auto schemeParamAttr = module->getAttrOfType<ckks::SchemeParamAttr>(
ckks::CKKSDialect::kSchemeParamAttrName);
if (!schemeParamAttr) {
module->emitOpError() << "Missing ckks.schemeParam attribute";
signalPassFailure();
return;
}

int logN = schemeParamAttr.getLogN();
int ringDim = 1 << logN;

auto Q = schemeParamAttr.getQ().asArrayRef();
auto P = schemeParamAttr.getP().asArrayRef();

int totalBitsQ = 0;
for (auto q : Q) {
totalBitsQ += llvm::APInt(64, q).getActiveBits();
}

int totalBitsP = 0;
for (auto p : P) {
totalBitsP += llvm::APInt(64, p).getActiveBits();
}

std::vector<int64_t> existingPrimes;
std::vector<int64_t> qTowers;
std::vector<int64_t> pTowers;

int bitsGeneratedQ = 0;
while (bitsGeneratedQ < totalBitsQ) {
int64_t prime = findPrime(30, ringDim, existingPrimes);
qTowers.push_back(prime);
existingPrimes.push_back(prime);
bitsGeneratedQ += 30;
}

int bitsGeneratedP = 0;
while (bitsGeneratedP < totalBitsP) {
int64_t prime = findPrime(30, ringDim, existingPrimes);
pTowers.push_back(prime);
existingPrimes.push_back(prime);
bitsGeneratedP += 30;
}

auto qTowersAttr = DenseI64ArrayAttr::get(context, qTowers);
auto pTowersAttr = DenseI64ArrayAttr::get(context, pTowers);

int dnum = computeDnum(Q.size() - 1);

// FIXME: Replace dummy value for composite_degree.
auto ckksParamsAttr = CkksParametersAttr::get(
context, qTowersAttr, pTowersAttr, 4, 4, dnum, 7, 1);

module->setAttr("jaxiteword.ckks_params", ckksParamsAttr);
}
};

} // namespace jaxiteword
} // namespace heir
} // namespace mlir
17 changes: 17 additions & 0 deletions lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef LIB_DIALECT_JAXITEWORD_TRANSFORMS_JAXITECKKSPARAMETERSELECTION_H_
#define LIB_DIALECT_JAXITEWORD_TRANSFORMS_JAXITECKKSPARAMETERSELECTION_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace jaxiteword {

#define GEN_PASS_DECL_JAXITECKKSPARAMETERSELECTION
#include "lib/Dialect/JaxiteWord/Transforms/Passes.h.inc"

} // namespace jaxiteword
} // namespace heir
} // namespace mlir

#endif // LIB_DIALECT_JAXITEWORD_TRANSFORMS_JAXITECKKSPARAMETERSELECTION_H_
18 changes: 18 additions & 0 deletions lib/Dialect/JaxiteWord/Transforms/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_H_
#define LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_H_

#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h"
#include "lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h"

namespace mlir {
namespace heir {
namespace jaxiteword {

#define GEN_PASS_REGISTRATION
#include "lib/Dialect/JaxiteWord/Transforms/Passes.h.inc"

} // namespace jaxiteword
} // namespace heir
} // namespace mlir

#endif // LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_H_
16 changes: 16 additions & 0 deletions lib/Dialect/JaxiteWord/Transforms/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_TD_
#define LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_TD_

include "mlir/Pass/PassBase.td"

def JaxiteCkksParameterSelection : Pass<"jaxite-ckks-parameter-selection", "mlir::ModuleOp"> {
let summary = "Selects parameters for Jaxite CKKS backend";
let description = [{
This pass selects parameters for the Jaxite CKKS backend and annotates them on the module.

(* example filepath=tests/Dialect/JaxiteWord/Transforms/doctest.mlir *)
}];
let dependentDialects = ["mlir::heir::jaxiteword::JaxiteWordDialect", "mlir::heir::ckks::CKKSDialect"];
}

#endif // LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_TD_
15 changes: 15 additions & 0 deletions tests/Dialect/JaxiteWord/IR/attr_test.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: heir-opt %s | FileCheck %s

// CHECK: module attributes {jaxiteword.ckks_params = #jaxiteword.ckks_parameters<q_towers = [1, 2], p_towers = [3], r = 4, c = 5, dnum = 6, composite_degree = 7, batch = 8>}
module attributes {
jaxiteword.ckks_params = #jaxiteword.ckks_parameters<
q_towers = [1, 2],
p_towers = [3],
r = 4,
c = 5,
dnum = 6,
composite_degree = 7,
batch = 8
>
} {
}
10 changes: 10 additions & 0 deletions tests/Dialect/JaxiteWord/Transforms/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
load("//bazel:lit.bzl", "glob_lit_tests")

package(default_applicable_licenses = ["@heir//:license"])

glob_lit_tests(
name = "all_tests",
data = ["@heir//tests:test_utilities"],
driver = "@heir//tests:run_lit.sh",
test_file_exts = ["mlir"],
)
19 changes: 19 additions & 0 deletions tests/Dialect/JaxiteWord/Transforms/doctest.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: heir-opt --jaxite-ckks-parameter-selection %s | FileCheck %s

// CHECK: jaxiteword.ckks_params = #jaxiteword.ckks_parameters<{{.*}}>
!ct = !jaxiteword.ciphertext<2, 3, 4>
!ml = !jaxiteword.modulus_list<65536, 1152921504606844513, 1152921504606844417>

module attributes {
ckks.schemeParam = #ckks.scheme_param<
logN = 13,
Q = [36028797018652673],
P = [1152921504606994433],
logDefaultScale = 45
>
} {
func.func @test_add(%ct1 : !ct, %ct2 : !ct, %modulus_list: !ml) -> !ct {
%out = jaxiteword.add %ct1, %ct2, %modulus_list: (!ct, !ct, !ml) -> !ct
return %out : !ct
}
}
32 changes: 32 additions & 0 deletions tests/Dialect/JaxiteWord/Transforms/large_test.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// RUN: heir-opt --jaxite-ckks-parameter-selection %s | FileCheck %s

!ct = !jaxiteword.ciphertext<2, 3, 4>
!ml = !jaxiteword.modulus_list<65536, 1152921504606844513, 1152921504606844417>

// CHECK: jaxiteword.ckks_params = #jaxiteword.ckks_parameters<{{.*}}>
module attributes {
ckks.schemeParam = #ckks.scheme_param<
logN = 13,
Q = [
7896856388305998031, 8335717806483771817, 7621929371556188363, 8941345776919444657,
7943813361973406531, 7742501181933711653, 7673257225347932497, 7210067971330841557,
8234891178228564671, 7847526270039855001, 8245181310374330081, 8960862465870304837,
8718902402328186751, 9031509869954283143, 7789630786405883791, 8945030373143909771,
7258099451375055763, 8999881575504424663, 9020740517063589967, 7906610589161779643,
7256670403940451583, 7215881909751066997, 7261482118667644289, 6918930965025587023,
7552875336759771971, 7264322706790679029, 7035727842643806041, 8663275797836175071,
7348375621176293489, 8101412547026401381
],
P = [
8046990677865391223, 8262056840302532089, 7520591891579404973, 8469636204033924593,
7515061052621148421, 8671733300942445233, 9061065578563297193, 8446495666365292607,
8329800933433096669, 7565030516258039723
],
logDefaultScale = 45
>
} {
func.func @test_add(%ct1 : !ct, %ct2 : !ct, %modulus_list: !ml) -> !ct {
%out = jaxiteword.add %ct1, %ct2, %modulus_list: (!ct, !ct, !ml) -> !ct
return %out : !ct
}
}
1 change: 1 addition & 0 deletions tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ cc_binary(
"@heir//lib/Dialect/Debug/Transforms",
"@heir//lib/Dialect/Jaxite/IR:Dialect",
"@heir//lib/Dialect/JaxiteWord/IR:Dialect",
"@heir//lib/Dialect/JaxiteWord/Transforms",
"@heir//lib/Dialect/KeyMgmt/IR:Dialect",
"@heir//lib/Dialect/LWE/Conversions/LWEToLattigo",
"@heir//lib/Dialect/LWE/Conversions/LWEToOpenfhe",
Expand Down
Loading
Loading