Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,31 @@ namespace mlir {
namespace heir {
class ILPBootstrapPlacementAnalysis {
public:
// A level reduction attached to a single use of a value, i.e. to one operand
// of one consuming operation.
struct OperandLevelReduction {
  // The consuming operation whose operand is reduced; null when unset.
  Operation* op = nullptr;
  // Index of the affected operand within `op`.
  unsigned operandNumber = 0;
  // Level-drop amount recorded for this use.
  int levelToDrop = 0;
};

// A level reduction attached to a produced SSA value, shared by all of its
// users.
struct OutputLevelReduction {
  // The value whose level is reduced; default-constructed (null) when unset.
  Value value;
  // Level-drop amount recorded for this value.
  int levelToDrop = 0;
};

ILPBootstrapPlacementAnalysis(Operation* op, DataFlowSolver* solver,
int bootstrapWaterline)
: opToRunOn(op), solver(solver), bootstrapWaterline(bootstrapWaterline) {}
int bootstrapWaterline, int scaleWaterline,
int scaleFactorBits,
int bootstrapLevelLowerBound, int bootstrapCost,
int rescaleCost, bool useOrbitCompression)
: opToRunOn(op),
solver(solver),
bootstrapWaterline(bootstrapWaterline),
scaleWaterline(scaleWaterline),
scaleFactorBits(scaleFactorBits),
bootstrapLevelLowerBound(bootstrapLevelLowerBound),
bootstrapCost(bootstrapCost),
rescaleCost(rescaleCost),
useOrbitCompression(useOrbitCompression) {}
~ILPBootstrapPlacementAnalysis() = default;

LogicalResult solve();
Expand All @@ -29,6 +51,17 @@ class ILPBootstrapPlacementAnalysis {
// relinearize insertion.
llvm::SmallVector<Value, 32> getValuesToBootstrap() const;

// Return a copy of the per-use (per-operand) level reductions stored in
// `operandLevelReductions`.
// NOTE(review): presumably populated by solve(); confirm against the
// implementation before relying on the result of an unsolved analysis.
llvm::SmallVector<OperandLevelReduction, 32> getOperandLevelReductions()
const {
return operandLevelReductions;
}

// Return a copy of the per-result level reductions stored in
// `outputLevelReductions`.
// NOTE(review): presumably populated by solve(); confirm against the
// implementation before relying on the result of an unsolved analysis.
llvm::SmallVector<OutputLevelReduction, 32> getOutputLevelReductions() const {
return outputLevelReductions;
}

// Return the level at the given SSA value, as determined by the
// solution to the optimization problem. When the input value is the result
// of an op, and the model solution suggests a bootstrap should be
Expand All @@ -47,9 +80,17 @@ class ILPBootstrapPlacementAnalysis {
Operation* opToRunOn;
DataFlowSolver* solver;
int bootstrapWaterline;
int scaleWaterline;
int scaleFactorBits;
int bootstrapLevelLowerBound;
int bootstrapCost;
int rescaleCost;
bool useOrbitCompression;
llvm::DenseMap<Operation*, bool> solution;
llvm::DenseMap<Value, int> solutionLevelBeforeBootstrap;
llvm::DenseMap<Value, int> solutionLevelAfterBootstrap;
llvm::SmallVector<OperandLevelReduction, 32> operandLevelReductions;
llvm::SmallVector<OutputLevelReduction, 32> outputLevelReductions;
};
} // namespace heir
} // namespace mlir
Expand Down
165 changes: 149 additions & 16 deletions lib/Transforms/ILPBootstrapPlacement/ILPBootstrapPlacement.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
#include "lib/Transforms/ILPBootstrapPlacement/ILPBootstrapPlacement.h"

#include <cmath>
#include <limits>
#include <optional>
#include <string>
#include <utility>

#include "lib/Analysis/ILPBootstrapPlacementAnalysis/ILPBootstrapPlacementAnalysis.h"
#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
#include "lib/Dialect/Mgmt/Transforms/AnnotateMgmt.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "lib/Transforms/SecretInsertMgmt/Pipeline.h"
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "llvm/include/llvm/Support/JSON.h" // from @llvm-project
#include "llvm/include/llvm/Support/MemoryBuffer.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/Utils.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
Expand All @@ -26,20 +33,90 @@ namespace heir {
#define GEN_PASS_DEF_ILPBOOTSTRAPPLACEMENT
#include "lib/Transforms/ILPBootstrapPlacement/ILPBootstrapPlacement.h.inc"

// ILP objective costs extracted from an Orbit JSON cost model.
struct OrbitCostModel {
  // Cost charged for one bootstrap.
  int bootstrapCost = 0;
  // Cost charged for one rescale / level reduction.
  int rescaleCost = 0;
};

// Computes the rounded mean of the strictly positive numeric entries of
// `root["latencyTable"][opName]`.
//
// Returns std::nullopt when `latencyTable` is absent or not an object, when
// `opName` is absent or not an array, or when the array holds no positive
// numeric entry. Non-numeric and non-positive entries are skipped.
//
// The result is saturated at the largest representable int: the table comes
// from an external file, and converting an out-of-range mean to int would
// otherwise yield a meaningless (possibly negative) cost.
static std::optional<int> averagePositiveLatency(const llvm::json::Object& root,
                                                 llvm::StringRef opName) {
  const llvm::json::Object* latencyTable = root.getObject("latencyTable");
  if (!latencyTable) return std::nullopt;

  const llvm::json::Array* latencies = latencyTable->getArray(opName);
  if (!latencies) return std::nullopt;

  double sum = 0;
  int count = 0;
  for (const llvm::json::Value& latencyValue : *latencies) {
    std::optional<double> latency = latencyValue.getAsNumber();
    if (!latency || *latency <= 0) continue;
    sum += *latency;
    ++count;
  }
  if (count == 0) return std::nullopt;

  const double average = sum / count;
  constexpr int kMaxCost = std::numeric_limits<int>::max();
  // Also covers an infinite sum produced by extremely large entries.
  if (average >= static_cast<double>(kMaxCost)) return kMaxCost;
  return static_cast<int>(std::llround(average));
}

// Reads the JSON file at `path` and extracts the bootstrap and rescale costs
// from its latency table.
//
// Fails when the file cannot be read, when it is not valid JSON, when the
// top-level value is not an object, or when either the
// "earth.bootstrap_single" or "earth.rescale_single" latency cannot be
// averaged. Parse errors are consumed here; no diagnostic is emitted.
static FailureOr<OrbitCostModel> loadOrbitCostModel(llvm::StringRef path) {
  auto fileOrError = llvm::MemoryBuffer::getFile(path);
  if (!fileOrError) return failure();

  llvm::Expected<llvm::json::Value> document =
      llvm::json::parse((*fileOrError)->getBuffer());
  if (!document) {
    // An llvm::Expected in the error state must be consumed before it is
    // destroyed.
    llvm::consumeError(document.takeError());
    return failure();
  }

  const llvm::json::Object* topLevel = document->getAsObject();
  if (!topLevel) return failure();

  std::optional<int> bootstrapLatency =
      averagePositiveLatency(*topLevel, "earth.bootstrap_single");
  std::optional<int> rescaleLatency =
      averagePositiveLatency(*topLevel, "earth.rescale_single");
  if (bootstrapLatency && rescaleLatency)
    return OrbitCostModel{*bootstrapLatency, *rescaleLatency};

  return failure();
}

struct ILPBootstrapPlacement
: impl::ILPBootstrapPlacementBase<ILPBootstrapPlacement> {
using ILPBootstrapPlacementBase::ILPBootstrapPlacementBase;

LogicalResult processSecretGenericOp(
secret::GenericOp genericOp, DataFlowSolver* solver,
SmallVector<Value, 32>* valuesToBootstrap) {
SmallVector<Value, 32>* valuesToBootstrap,
SmallVector<ILPBootstrapPlacementAnalysis::OutputLevelReduction, 32>*
outputLevelReductions,
SmallVector<ILPBootstrapPlacementAnalysis::OperandLevelReduction, 32>*
operandLevelReductions) {
genericOp->walk([&](mgmt::BootstrapOp op) {
op.getResult().replaceAllUsesWith(op.getOperand());
op.erase();
});

ILPBootstrapPlacementAnalysis analysis(genericOp, solver,
bootstrapWaterline);
int effectiveBootstrapCost = bootstrapCost;
int effectiveRescaleCost = rescaleCost;
if (!orbitCostModel.empty()) {
FailureOr<OrbitCostModel> loadedCostModel =
loadOrbitCostModel(orbitCostModel);
if (failed(loadedCostModel)) {
llvm::errs() << "failed to load Orbit cost model from `"
<< orbitCostModel << "`\n";
genericOp->emitError() << "failed to load Orbit cost model from `"
<< orbitCostModel << "`";
return failure();
}
effectiveBootstrapCost = loadedCostModel->bootstrapCost;
effectiveRescaleCost = loadedCostModel->rescaleCost;
}

ILPBootstrapPlacementAnalysis analysis(
genericOp, solver, bootstrapWaterline, scaleWaterline, scaleFactorBits,
bootstrapLevelLowerBound, effectiveBootstrapCost, effectiveRescaleCost,
useOrbitCompression);
if (failed(analysis.solve())) {
genericOp->emitError(
"Failed to solve the bootstrap placement optimization problem");
Expand All @@ -48,26 +125,69 @@ struct ILPBootstrapPlacement
LLVM_DEBUG(analysis.printSolution(llvm::dbgs()));
for (Value v : analysis.getValuesToBootstrap())
valuesToBootstrap->push_back(v);
for (auto reduction : analysis.getOutputLevelReductions())
outputLevelReductions->push_back(reduction);
for (auto reduction : analysis.getOperandLevelReductions())
operandLevelReductions->push_back(reduction);
return success();
}

// Walks forward from `value` through single-use mgmt.relinearize /
// mgmt.modreduce users and returns the last value of that chain together with
// the operation that produced it.
//
// When `value` has no such user, the result is `value` itself paired with its
// defining op, which is null for a value without a defining operation.
std::pair<Value, Operation*> followRelinearizeModReduceChain(Value value) {
  Value current = value;
  Operation* lastOp = value.getDefiningOp();
  while (current.hasOneUse()) {
    Operation* soleUser = *current.getUsers().begin();
    // Stop at the first user that is not part of the management chain.
    if (!isa<mgmt::RelinearizeOp, mgmt::ModReduceOp>(soleUser)) break;
    lastOp = soleUser;
    current = soleUser->getResult(0);
  }
  return {current, lastOp};
}

// Materializes each producer-side level reduction as a mgmt.level_reduce op.
//
// For every entry, the relinearize/modreduce chain hanging off the value is
// followed to its end. The new op is inserted right after the chain's last
// operation, and every other use of the chain's final value is redirected to
// the new op's result. Entries whose chain has no anchoring operation are
// skipped.
void insertOutputLevelReductions(
    ArrayRef<ILPBootstrapPlacementAnalysis::OutputLevelReduction>
        outputLevelReductions) {
  OpBuilder builder(&getContext());
  for (const auto& entry : outputLevelReductions) {
    std::pair<Value, Operation*> chain =
        followRelinearizeModReduceChain(entry.value);
    Value chainValue = chain.first;
    Operation* anchor = chain.second;
    if (anchor == nullptr) continue;

    builder.setInsertionPointAfter(anchor);
    auto reduceOp = mgmt::LevelReduceOp::create(
        builder, anchor->getLoc(), chainValue,
        builder.getI64IntegerAttr(entry.levelToDrop));
    // The new op itself must keep consuming the original value.
    chainValue.replaceAllUsesExcept(reduceOp.getResult(), {reduceOp});
  }
}

// Materializes each per-use level reduction as a mgmt.level_reduce op placed
// immediately before the consuming operation, and rewires only the affected
// operand to the new op's result.
//
// Entries with a null op or an out-of-range operand index are skipped.
void insertOperandLevelReductions(
    ArrayRef<ILPBootstrapPlacementAnalysis::OperandLevelReduction>
        operandLevelReductions) {
  OpBuilder builder(&getContext());
  for (const auto& entry : operandLevelReductions) {
    Operation* consumer = entry.op;
    const bool hasValidOperand =
        consumer != nullptr &&
        entry.operandNumber < consumer->getNumOperands();
    if (!hasValidOperand) continue;

    const unsigned index = entry.operandNumber;
    Value original = consumer->getOperand(index);
    builder.setInsertionPoint(consumer);
    auto reduceOp = mgmt::LevelReduceOp::create(
        builder, consumer->getLoc(), original,
        builder.getI64IntegerAttr(entry.levelToDrop));
    consumer->setOperand(index, reduceOp.getResult());
  }
}

void insertBootstrapsForValues(ArrayRef<Value> valuesToBootstrap) {
OpBuilder b(&getContext());
for (Value v : valuesToBootstrap) {
// After modreduce/relinearize we have mul -> relinearize -> modreduce.
// Follow the chain so we bootstrap the modreduce result (correct level
// refresh) and insert after it.
Value toBootstrap = v;
Operation* insertAfter = v.getDefiningOp();
while (toBootstrap.hasOneUse()) {
Operation* user = *toBootstrap.getUsers().begin();
if (isa<mgmt::RelinearizeOp>(user) || isa<mgmt::ModReduceOp>(user)) {
toBootstrap = user->getResult(0);
insertAfter = user;
} else {
break;
}
}
auto [toBootstrap, insertAfter] = followRelinearizeModReduceChain(v);
if (!insertAfter) continue;
b.setInsertionPointAfter(insertAfter);
auto bootstrapOp =
mgmt::BootstrapOp::create(b, insertAfter->getLoc(), toBootstrap);
Expand All @@ -89,9 +209,14 @@ struct ILPBootstrapPlacement
}

SmallVector<Value, 32> valuesToBootstrap;
SmallVector<ILPBootstrapPlacementAnalysis::OutputLevelReduction, 32>
outputLevelReductions;
SmallVector<ILPBootstrapPlacementAnalysis::OperandLevelReduction, 32>
operandLevelReductions;
auto result = module->walk([&](secret::GenericOp genericOp) {
if (failed(
processSecretGenericOp(genericOp, &solver, &valuesToBootstrap)))
if (failed(processSecretGenericOp(genericOp, &solver, &valuesToBootstrap,
&outputLevelReductions,
&operandLevelReductions)))
return WalkResult::interrupt();
return WalkResult::advance();
});
Expand All @@ -100,6 +225,10 @@ struct ILPBootstrapPlacement
return;
}

// Insert per-use level reductions before consumers, matching Orbit-style
// edge rescale placement.
insertOperandLevelReductions(operandLevelReductions);

// Modreduce after every mul.
insertModReduceBeforeOrAfterMult(getOperation(), /*afterMul=*/true,
/*beforeMulIncludeFirstMul=*/false,
Expand All @@ -108,6 +237,10 @@ struct ILPBootstrapPlacement
// Relinearize after every mul.
insertRelinearizeAfterMult(getOperation(), /*includeFloats=*/true);

// Insert shared producer-output level reductions after mul management and
// before bootstraps, matching Orbit's node rescale decisions.
insertOutputLevelReductions(outputLevelReductions);

// Insert bootstraps at the Values the ILP chose. Values remain valid.
insertBootstrapsForValues(valuesToBootstrap);

Expand Down
49 changes: 44 additions & 5 deletions lib/Transforms/ILPBootstrapPlacement/ILPBootstrapPlacement.td
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,23 @@ include "mlir/Pass/PassBase.td"
def ILPBootstrapPlacement : Pass <"ilp-bootstrap-placement"> {
let summary = "Optimize placement of bootstrap ops using ILP";
let description = [{
This pass uses an integer linear program to determine the optimal level
of each term in the MLIR, and thus the placement of bootstrap and
modreduce operations.
This pass uses an integer linear program to determine feasible levels
and CKKS-style scales for each term in the MLIR, and thus the placement
of bootstrap and level-reduction operations.

The pass runs on [ciphertext-semantic](https://heir.dev/docs/design/layout/#data-semantic-and-ciphertext-semantic-tensors)
IR (secret.generic with arith ops operating on pre-packed tensors). It
1) Inserts mgmt.modreduce after each level-consuming op (e.g. mul in
CKKS, where level drops only at multiplications).
2) Inserts mgmt.bootstrap at the positions chosen by the ILP.
3) Inserts mgmt.relinearize after each mul. Resulting order is mul ->
2) Inserts per-use mgmt.level_reduce ops for edge rescale decisions
chosen by the ILP.
3) Inserts mgmt.bootstrap at the positions chosen by the ILP.
4) Inserts mgmt.relinearize after each mul. Resulting order is mul ->
relinearize -> modreduce, with bootstrap after modreduce or after
the op where the ILP chose.

The ILP formulation is inspired by [Orbit](https://eprint.iacr.org/2026/213.pdf).

Note: The ILP formulation does not account for a freshly encrypted
ciphertext starting at a higher level than the bootstrap waterline.
This will be implemented as future work.
Expand All @@ -32,6 +36,41 @@ def ILPBootstrapPlacement : Pass <"ilp-bootstrap-placement"> {
"int",
/*default=*/"3",
"Bootstrap waterline (max level). Levels are 0..bootstrap-waterline (inclusive); inputs start at bootstrap-waterline.">,
Option<"scaleWaterline",
"scale-waterline",
"int",
/*default=*/"40",
"Minimum CKKS scale budget used by the Orbit-inspired scale constraints.">,
Option<"scaleFactorBits",
"scale-factor-bits",
"int",
/*default=*/"51",
"Scale bits dropped by one rescale/modreduce in the Orbit-inspired scale constraints.">,
Option<"bootstrapLevelLowerBound",
"bootstrap-level-lower-bound",
"int",
/*default=*/"0",
"Minimum input level at which bootstrap is allowed in the Orbit-inspired scale constraints.">,
Option<"orbitCostModel",
"orbit-cost-model",
"std::string",
/*default=*/"""",
"Path to an Orbit JSON cost model. When provided, bootstrap-cost and rescale-cost are loaded from latencyTable.">,
Option<"bootstrapCost",
"bootstrap-cost",
"int",
/*default=*/"69320650",
"Cost of one bootstrap in the ILP objective. Default is the positive-latency average from Orbit's profiled 64k Lattigo base cost model.">,
Option<"rescaleCost",
"rescale-cost",
"int",
/*default=*/"40988",
"Cost of one level-reduction/rescale in the ILP objective. Default is the positive-latency average from Orbit's profiled 64k Lattigo base cost model.">,
Option<"useOrbitCompression",
"use-orbit-compression",
"bool",
/*default=*/"true",
"When true, add Orbit-inspired structural compression constraints so equivalent ops share ILP state and bootstrap decisions.">,
];
}

Expand Down
13 changes: 12 additions & 1 deletion tests/Transforms/ilp_bootstrap_placement/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,20 @@ load("//bazel:lit.bzl", "glob_lit_tests")

package(default_applicable_licenses = ["@heir//:license"])

filegroup(
name = "orbit_cost_model",
srcs = [
"orbit_bad_cost_model.json",
"orbit_cost_model.json",
],
)

glob_lit_tests(
name = "all_tests",
data = ["@heir//tests:test_utilities"],
data = [
":orbit_cost_model",
"@heir//tests:test_utilities",
],
driver = "@heir//tests:run_lit.sh",
test_file_exts = ["mlir"],
)
Loading
Loading