Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
256 changes: 247 additions & 9 deletions lib/Dialect/TensorExt/Transforms/ImplementShiftNetwork.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ namespace tensor_ext {
#define GEN_PASS_DEF_IMPLEMENTSHIFTNETWORK
#include "lib/Dialect/TensorExt/Transforms/Passes.h.inc"

ShiftScheme VosVosErkinShiftNetworks::findShiftScheme(
ShiftSchemeResult VosVosErkinShiftNetworks::findShiftScheme(
const Mapping& mapping, ArrayRef<int64_t> shiftOrder) {
CacheKey cacheKey = makeCacheKey(mapping, shiftOrder);
if (schemeCache.count(cacheKey)) {
return schemeCache[cacheKey];
return ShiftSchemeResult{schemeCache[cacheKey], mapping};
}

ShiftStrategy strategy = evaluateShiftStrategy(mapping, shiftOrder);
Expand Down Expand Up @@ -91,6 +91,242 @@ ShiftScheme VosVosErkinShiftNetworks::findShiftScheme(
}
});

// Cleaning subroutine, identify if overlapping sources for targets
Mapping cleanedMapping = mapping;
auto targetToSources = mapping.getTargetToSources();
bool hasMultiSourceTargets =
std::any_of(targetToSources.begin(), targetToSources.end(),
[](const auto& kv) { return kv.second.size() > 1; });
LLVM_DEBUG({
llvm::dbgs() << "Has overlapping sources for targets - "
<< hasMultiSourceTargets << "\n";
});
if (hasMultiSourceTargets) {
// Create a full conflict graph
graph::UndirectedGraph<CtSlot> fullConflictGraph;
for (const auto& [target, sources] : targetToSources) {
for (const auto& ctmatch : sources) {
fullConflictGraph.addVertex(ctmatch.source);
}
}
SmallVector<int64_t> defaultTestShiftOrder = defaultShiftOrder(
mapping.getCiphertextSize() * mapping.getNumCiphertexts());

ShiftStrategy fullConflictStrategy =
evaluateShiftStrategy(mapping, defaultTestShiftOrder, true);

for (const auto& [roundNum, round] :
llvm::enumerate(fullConflictStrategy.getRounds())) {
if (roundNum == 0) continue;
auto posns = round.positions;
for (auto it1 = posns.begin(); it1 != posns.end(); ++it1) {
for (auto it2 = std::next(it1); it2 != posns.end(); ++it2) {
const SourceShift& ss1 = it1->first;
const SourceShift& ss2 = it2->first;
if (ss1.source != ss2.source && it1->second == it2->second) {
LLVM_DEBUG(llvm::dbgs()
<< "Round " << roundNum << ": collision between " << "{"
<< ss1.source.ct << "," << ss1.source.slot << "}"
<< " and " << "{" << ss2.source.ct << ","
<< ss2.source.slot << "}" << " at " << "{"
<< it1->second.ct << "," << it1->second.slot << "}\n");
fullConflictGraph.addEdge(ss1.source, ss2.source);
}
}
}
}

struct SourceInfo {
int numConflicts;
int numInserts;
};

std::unordered_map<CtSlot, SourceInfo> sourceMap;

for (CtSlot vertex : fullConflictGraph.getVertices()) {
sourceMap[vertex] = SourceInfo{
static_cast<int>(fullConflictGraph.edgesIncidentTo(vertex).size()),
0};
}

LLVM_DEBUG({
llvm::dbgs() << "Full Conflict graph:\n";
for (CtSlot vertex : fullConflictGraph.getVertices()) {
llvm::dbgs() << vertex.ct << "," << vertex.slot << " <-> {";
for (CtSlot neighbor : fullConflictGraph.edgesIncidentTo(vertex)) {
llvm::dbgs() << neighbor.ct << "," << neighbor.slot << "; ";
}
llvm::dbgs() << "}\n";
}
});

// Create seeds for tests
static std::mt19937 masterRng(std::random_device{}());
static std::mt19937 rng(std::random_device{}());
std::uniform_int_distribution<std::uint32_t> seedDist;

// Store best mapping
int minNumRotations = std::numeric_limits<int>::max();
Mapping bestMapping = mapping;
graph::UndirectedGraph<CtSlot> bestConflictGraph;

// Copy the target to Sources for testing
std::vector<std::pair<CtSlot, SmallVector<Mapping::CtMatch>*>> allTargets;
allTargets.reserve(targetToSources.size());
for (auto& entry : targetToSources) {
allTargets.emplace_back(entry.first, &entry.second);
}

// Run 50 tests
for (int i = 0; i < 50; ++i) {
// track source numberConflicts , inserts
std::unordered_map<CtSlot, SourceInfo> testSourceMap = sourceMap;

// create test mapping and conflictgraph
Mapping testMapping = mapping;
graph::UndirectedGraph<CtSlot> testConflictGraph;

// shuffle targets and fill
std::uint32_t trialSeed = seedDist(masterRng);
rng.seed(trialSeed);
std::shuffle(allTargets.begin(), allTargets.end(), rng);

for (auto& [slot, matchesPtr] : allTargets) {
SmallVector<Mapping::CtMatch>& matches = *matchesPtr;
std::vector<std::pair<CtSlot, SourceInfo*>> candidates;
candidates.reserve(matches.size());

// Get and sort sources by 1) inserts , 2) conflicts ascending
for (auto& match : matches) {
CtSlot src = match.source;
auto it = testSourceMap.find(src);
candidates.emplace_back(src, &it->second);
}
std::sort(candidates.begin(), candidates.end(),
[](const auto& a, const auto& b) {
const SourceInfo& sa = *a.second;
const SourceInfo& sb = *b.second;
if (sa.numInserts != sb.numInserts)
return sa.numInserts < sb.numInserts;
return sa.numConflicts < sb.numConflicts;
});

LLVM_DEBUG({
llvm::dbgs() << "Candidates after sort for target " << slot.ct << ","
<< slot.slot << ":\n";
for (const auto& entry : candidates) {
const CtSlot& src = entry.first;
const SourceInfo& info = *entry.second;
llvm::dbgs() << " src=(" << src.ct << "," << src.slot
<< "), numInserts=" << info.numInserts
<< ", numConflicts=" << info.numConflicts << "\n";
}
});

// Pick any source with lowest at random and save choice
const SourceInfo& best = *candidates.front().second;
int minInserts = best.numInserts;
int minConflicts = best.numConflicts;

std::vector<size_t> bestIdx;
bestIdx.reserve(candidates.size());

// any source with lowest
for (size_t i = 0; i < candidates.size(); ++i) {
const SourceInfo& info = *candidates[i].second;
if (info.numInserts == minInserts &&
info.numConflicts == minConflicts)
bestIdx.push_back(i);
else
break;
}

// pick and save choice
std::uniform_int_distribution<size_t> dist(0, bestIdx.size() - 1);
size_t chosenIndex = bestIdx[dist(rng)];

CtSlot chosen = candidates[chosenIndex].first;
SourceInfo& info = *candidates[chosenIndex].second;
info.numInserts += 1;
testMapping.add(chosen, slot);
testConflictGraph.addVertex(chosen);
}

// Create test strategy, fill edges
ShiftStrategy testStrategy =
evaluateShiftStrategy(testMapping, defaultTestShiftOrder);
for (const auto& [roundNum, round] :
llvm::enumerate(testStrategy.getRounds())) {
if (roundNum == 0) continue;
auto posns = round.positions;
for (auto it1 = posns.begin(); it1 != posns.end(); ++it1) {
for (auto it2 = std::next(it1); it2 != posns.end(); ++it2) {
const SourceShift& ss1 = it1->first;
const SourceShift& ss2 = it2->first;
if (ss1.source != ss2.source && it1->second == it2->second) {
LLVM_DEBUG(llvm::dbgs()
<< "Round " << roundNum << ": collision between "
<< "{" << ss1.source.ct << "," << ss1.source.slot
<< "}" << " and " << "{" << ss2.source.ct << ","
<< ss2.source.slot << "}" << " at " << "{"
<< it1->second.ct << "," << it1->second.slot << "}\n");
testConflictGraph.addEdge(ss1.source, ss2.source);
}
}
}
}

// Find number of rotations needed
graph::GreedyGraphColoring<CtSlot> testColorer;
std::unordered_map<CtSlot, int> testColoring =
testColorer.color(testConflictGraph);

SmallVector<RotationGroup> testResultRotationGroups;
testResultRotationGroups.reserve(5);

for (const auto& entry : testColoring) {
CtSlot source = entry.first;
int64_t color = entry.second;
if (color >= static_cast<int64_t>(testResultRotationGroups.size()))
testResultRotationGroups.resize(color + 1);
testResultRotationGroups[color].insert(source);
}

// Update best number rotations & mapping
int testNumberRotations =
static_cast<int>(testResultRotationGroups.size());
if (testNumberRotations < minNumRotations) {
minNumRotations = testNumberRotations;
bestMapping = testMapping;
bestConflictGraph = testConflictGraph;
}
}

cleanedMapping = bestMapping;
conflictGraph = bestConflictGraph;

LLVM_DEBUG({
llvm::dbgs() << "Cleaning subroutine finished. "
<< "Best num rotations = " << minNumRotations << "\n";

llvm::dbgs() << "Best mapping targetToSource:\n";
for (const auto& [target, source] : bestMapping.getTargetToSource()) {
llvm::dbgs() << " target=(" << target.ct << "," << target.slot << ")"
<< " <- source=(" << source.ct << "," << source.slot
<< ")\n";
}

llvm::dbgs() << "Conflict graph after cleaning:\n";
for (CtSlot vertex : conflictGraph.getVertices()) {
llvm::dbgs() << " " << vertex.ct << "," << vertex.slot << " <-> {";
for (CtSlot neighbor : conflictGraph.edgesIncidentTo(vertex)) {
llvm::dbgs() << neighbor.ct << "," << neighbor.slot << "; ";
}
llvm::dbgs() << "}\n";
}
});
}

graph::GreedyGraphColoring<CtSlot> colorer;
std::unordered_map<CtSlot, int> coloring = colorer.color(conflictGraph);

Expand Down Expand Up @@ -121,10 +357,10 @@ ShiftScheme VosVosErkinShiftNetworks::findShiftScheme(

ShiftScheme scheme{resultRotationGroups, strategy};
schemeCache[cacheKey] = scheme;
return schemeCache[cacheKey];
return ShiftSchemeResult{schemeCache[cacheKey], cleanedMapping};
}

ShiftScheme VosVosErkinShiftNetworks::findBestShiftScheme(
ShiftSchemeResult VosVosErkinShiftNetworks::findBestShiftScheme(
const Mapping& mapping, std::size_t randomSeed, unsigned randomTries) {
SmallVector<int64_t> initShiftOrder = defaultShiftOrder(
mapping.getCiphertextSize() * mapping.getNumCiphertexts());
Expand Down Expand Up @@ -156,15 +392,15 @@ ShiftScheme VosVosErkinShiftNetworks::findBestShiftScheme(
}

ShiftStrategy VosVosErkinShiftNetworks::evaluateShiftStrategy(
const Mapping& mapping, ArrayRef<int64_t> shiftOrder) {
const Mapping& mapping, ArrayRef<int64_t> shiftOrder, bool useSources) {
CacheKey cacheKey = makeCacheKey(mapping, shiftOrder);
if (strategyCache.count(cacheKey)) {
return strategyCache[cacheKey];
}

ShiftStrategy strategy(mapping.getCiphertextSize(),
mapping.getNumCiphertexts(), shiftOrder);
strategy.evaluate(mapping);
strategy.evaluate(mapping, useSources);
strategyCache[cacheKey] = strategy;
return strategy;
}
Expand Down Expand Up @@ -230,7 +466,9 @@ LogicalResult convertRemapOp(RemapOp op,
"DenseIntElementsAttr";
}

ShiftScheme scheme = shiftNetworks.findShiftScheme(mapping);
auto shiftSchemeResult = shiftNetworks.findShiftScheme(mapping);
ShiftScheme& scheme = shiftSchemeResult.scheme;
Mapping& cleanedMapping = shiftSchemeResult.cleanedMapping;
auto rotationGroups = scheme.rotationGroups;

assert(!rotationGroups.empty() &&
Expand Down Expand Up @@ -260,8 +498,8 @@ LogicalResult convertRemapOp(RemapOp op,
ciphertexts.push_back(kernel::SSAValue(slice.getResult()));
}

auto resultNodes =
implementShiftNetwork(ciphertexts, mapping, scheme, ciphertextSize);
auto resultNodes = implementShiftNetwork(ciphertexts, cleanedMapping, scheme,
ciphertextSize);

kernel::IRMaterializingVisitor visitor(b, singleCiphertextType);
std::vector<Value> result = visitor.process(resultNodes);
Expand Down
18 changes: 12 additions & 6 deletions lib/Dialect/TensorExt/Transforms/ImplementShiftNetwork.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ namespace tensor_ext {
#define GEN_PASS_DECL_IMPLEMENTSHIFTNETWORK
#include "lib/Dialect/TensorExt/Transforms/Passes.h.inc"

struct ShiftSchemeResult {
ShiftScheme scheme;
Mapping cleanedMapping;
};

// Cf. https://link.springer.com/chapter/10.1007/978-3-031-17140-6_20
// for an explanation of the algorithm.
class VosVosErkinShiftNetworks {
Expand All @@ -41,18 +46,19 @@ class VosVosErkinShiftNetworks {
// on further calls to avoid recomputing the shift network.
//
// The default shiftOrder is LSB to MSB, i.e. 1, 2, 4, 8, ...
ShiftScheme findShiftScheme(const Mapping& mapping,
ArrayRef<int64_t> shiftOrder = {});
ShiftSchemeResult findShiftScheme(const Mapping& mapping,
ArrayRef<int64_t> shiftOrder = {});

// Like findShiftScheme but randomly draw from a uniform distribution over all
// possible shift orders and use the one that results in the best network.
ShiftScheme findBestShiftScheme(const Mapping& mapping,
std::size_t randomSeed,
unsigned randomTries = 100);
ShiftSchemeResult findBestShiftScheme(const Mapping& mapping,
std::size_t randomSeed,
unsigned randomTries = 100);

private:
ShiftStrategy evaluateShiftStrategy(const Mapping& mapping,
ArrayRef<int64_t> shiftOrder);
ArrayRef<int64_t> shiftOrder,
bool useSources = false);

CacheKey makeCacheKey(const Mapping& mapping, ArrayRef<int64_t> shiftOrder);

Expand Down
10 changes: 6 additions & 4 deletions lib/Dialect/TensorExt/Transforms/ImplementShiftNetworkTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ ::testing::AssertionResult checkMapping(const Mapping& mapping,
int64_t ciphertextSize,
unsigned naiveNumRGExpected = 0) {
VosVosErkinShiftNetworks shiftNetworks;

auto naiveScheme = shiftNetworks.findShiftScheme(mapping);
auto naiveSchemeResult = shiftNetworks.findShiftScheme(mapping);
const ShiftScheme& naiveScheme = naiveSchemeResult.scheme;
unsigned naiveNumRG = naiveScheme.rotationGroups.size();
unsigned naiveNumRounds = naiveScheme.strategy.getRounds().size();
if (naiveNumRGExpected > 0 && naiveNumRG != naiveNumRGExpected)
Expand All @@ -106,8 +106,10 @@ ::testing::AssertionResult checkMapping(const Mapping& mapping,
// We try a large number of shift orders here such that we can be effectively
// certain that we will find a network that is at least as good as the "naive"
// one.
auto bestScheme = shiftNetworks.findBestShiftScheme(
mapping, /*randomSeed=*/42, /*randomTries=*/1000);
auto bestSchemeResult =
shiftNetworks.findBestShiftScheme(mapping, /*randomSeed=*/42,
/*randomTries=*/1000);
const ShiftScheme& bestScheme = bestSchemeResult.scheme;
unsigned bestNumRounds = bestScheme.strategy.getRounds().size();
if (bestNumRounds > naiveNumRounds)
return ::testing::AssertionFailure()
Expand Down
Loading
Loading