Skip to content

Commit 61bc180

Browse files
committed
extend handling shift network source conflicts
1 parent eaeff61 commit 61bc180

4 files changed

Lines changed: 300 additions & 13 deletions

File tree

lib/Dialect/TensorExt/Transforms/ImplementShiftNetwork.cpp

Lines changed: 243 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ namespace tensor_ext {
4444
#include "lib/Dialect/TensorExt/Transforms/Passes.h.inc"
4545

4646
ShiftScheme VosVosErkinShiftNetworks::findShiftScheme(
47-
const Mapping& mapping, ArrayRef<int64_t> shiftOrder) {
47+
Mapping& mapping, ArrayRef<int64_t> shiftOrder) {
4848
CacheKey cacheKey = makeCacheKey(mapping, shiftOrder);
4949
if (schemeCache.count(cacheKey)) {
5050
return schemeCache[cacheKey];
@@ -91,6 +91,241 @@ ShiftScheme VosVosErkinShiftNetworks::findShiftScheme(
9191
}
9292
});
9393

94+
// Cleaning subroutine, identify if overlapping sources for targets
95+
auto targetToSources = mapping.getTargetToSources();
96+
bool hasMultiSourceTargets =
97+
std::any_of(targetToSources.begin(), targetToSources.end(),
98+
[](const auto& kv) { return kv.second.size() > 1; });
99+
LLVM_DEBUG({
100+
llvm::dbgs() << "Has overlapping sources for targets - "
101+
<< hasMultiSourceTargets << "\n";
102+
});
103+
if (hasMultiSourceTargets) {
104+
105+
// Create a full conflict graph
106+
graph::UndirectedGraph<CtSlot> fullConflictGraph;
107+
for (const auto& [target, sources] : targetToSources) {
108+
for (const auto& ctmatch : sources) {
109+
fullConflictGraph.addVertex(ctmatch.source);
110+
}
111+
}
112+
SmallVector<int64_t> defaultTestShiftOrder = defaultShiftOrder(
113+
mapping.getCiphertextSize() * mapping.getNumCiphertexts());
114+
115+
ShiftStrategy fullConflictStrategy =
116+
evaluateShiftStrategy(mapping, defaultTestShiftOrder, true);
117+
118+
for (const auto& [roundNum, round] :
119+
llvm::enumerate(fullConflictStrategy.getRounds())) {
120+
if (roundNum == 0) continue;
121+
auto posns = round.positions;
122+
for (auto it1 = posns.begin(); it1 != posns.end(); ++it1) {
123+
for (auto it2 = std::next(it1); it2 != posns.end(); ++it2) {
124+
const SourceShift& ss1 = it1->first;
125+
const SourceShift& ss2 = it2->first;
126+
if (ss1.source != ss2.source && it1->second == it2->second) {
127+
LLVM_DEBUG(llvm::dbgs()
128+
<< "Round " << roundNum << ": collision between "
129+
<< "{" << ss1.source.ct << "," << ss1.source.slot << "}"
130+
<< " and " << "{" << ss2.source.ct << ","
131+
<< ss2.source.slot << "}" << " at " << "{"
132+
<< it1->second.ct << "," << it1->second.slot << "}\n");
133+
fullConflictGraph.addEdge(ss1.source, ss2.source);
134+
}
135+
}
136+
}
137+
}
138+
139+
140+
struct SourceInfo {
141+
int numConflicts;
142+
int numInserts;
143+
};
144+
145+
std::unordered_map<CtSlot, SourceInfo> sourceMap;
146+
147+
for (CtSlot vertex : fullConflictGraph.getVertices()) {
148+
sourceMap[vertex] = SourceInfo{
149+
static_cast<int>(fullConflictGraph.edgesIncidentTo(vertex).size()),
150+
0};
151+
}
152+
153+
LLVM_DEBUG({
154+
llvm::dbgs() << "Full Conflict graph:\n";
155+
for (CtSlot vertex : fullConflictGraph.getVertices()) {
156+
llvm::dbgs() << vertex.ct << "," << vertex.slot << " <-> {";
157+
for (CtSlot neighbor : fullConflictGraph.edgesIncidentTo(vertex)) {
158+
llvm::dbgs() << neighbor.ct << "," << neighbor.slot << "; ";
159+
}
160+
llvm::dbgs() << "}\n";
161+
}
162+
});
163+
164+
// Create seeds for tests
165+
static std::mt19937 masterRng(std::random_device{}());
166+
static std::mt19937 rng(std::random_device{}());
167+
std::uniform_int_distribution<std::uint32_t> seedDist;
168+
169+
// Store best mapping
170+
int minNumRotations = std::numeric_limits<int>::max();
171+
Mapping bestMapping = mapping;
172+
graph::UndirectedGraph<CtSlot> bestConflictGraph;
173+
174+
// Copy the target to Sources for testing
175+
std::vector<std::pair<CtSlot, SmallVector<Mapping::CtMatch>*>> allTargets;
176+
allTargets.reserve(targetToSources.size());
177+
for (auto &entry : targetToSources) {
178+
allTargets.emplace_back(entry.first, &entry.second);
179+
}
180+
181+
// Run 50 tests
182+
for (int i = 0; i < 50; ++i) {
183+
// track source numberConflicts , inserts
184+
std::unordered_map<CtSlot, SourceInfo> testSourceMap = sourceMap;
185+
186+
// create test mapping and conflictgraph
187+
Mapping testMapping = mapping;
188+
graph::UndirectedGraph<CtSlot> testConflictGraph;
189+
190+
// shuffle targets and fill
191+
std::uint32_t trialSeed = seedDist(masterRng);
192+
rng.seed(trialSeed);
193+
std::shuffle(allTargets.begin(), allTargets.end(), rng);
194+
195+
for (auto &[slot, matchesPtr] : allTargets) {
196+
SmallVector<Mapping::CtMatch> &matches = *matchesPtr;
197+
std::vector<std::pair<CtSlot, SourceInfo*>> candidates;
198+
candidates.reserve(matches.size());
199+
200+
// Get and sort sources by 1) inserts , 2) conflicts ascending
201+
for (auto &match : matches) {
202+
CtSlot src = match.source;
203+
auto it = testSourceMap.find(src);
204+
candidates.emplace_back(src, &it->second);
205+
}
206+
std::sort(candidates.begin(), candidates.end(),
207+
[](const auto &a, const auto &b) {
208+
const SourceInfo &sa = *a.second;
209+
const SourceInfo &sb = *b.second;
210+
if (sa.numInserts != sb.numInserts)
211+
return sa.numInserts < sb.numInserts;
212+
return sa.numConflicts < sb.numConflicts;
213+
});
214+
215+
LLVM_DEBUG({
216+
llvm::dbgs() << "Candidates after sort for target "
217+
<< slot.ct << "," << slot.slot << ":\n";
218+
for (const auto &entry : candidates) {
219+
const CtSlot &src = entry.first;
220+
const SourceInfo &info = *entry.second;
221+
llvm::dbgs() << " src=(" << src.ct << "," << src.slot
222+
<< "), numInserts=" << info.numInserts
223+
<< ", numConflicts=" << info.numConflicts << "\n";
224+
}
225+
});
226+
227+
// Pick any source with lowest at random and save choice
228+
const SourceInfo &best = *candidates.front().second;
229+
int minInserts = best.numInserts;
230+
int minConflicts = best.numConflicts;
231+
232+
std::vector<size_t> bestIdx;
233+
bestIdx.reserve(candidates.size());
234+
235+
// any source with lowest
236+
for (size_t i = 0; i < candidates.size(); ++i) {
237+
const SourceInfo &info = *candidates[i].second;
238+
if (info.numInserts == minInserts && info.numConflicts == minConflicts)
239+
bestIdx.push_back(i);
240+
else
241+
break;
242+
}
243+
244+
// pick and save choice
245+
std::uniform_int_distribution<size_t> dist(0, bestIdx.size() - 1);
246+
size_t chosenIndex = bestIdx[dist(rng)];
247+
248+
CtSlot chosen = candidates[chosenIndex].first;
249+
SourceInfo &info = *candidates[chosenIndex].second;
250+
info.numInserts += 1;
251+
testMapping.add(chosen, slot);
252+
testConflictGraph.addVertex(chosen);
253+
}
254+
255+
256+
// Create test strategy, fill edges
257+
ShiftStrategy testStrategy =
258+
evaluateShiftStrategy(testMapping, defaultTestShiftOrder);
259+
for (const auto& [roundNum, round] :
260+
llvm::enumerate(testStrategy.getRounds())) {
261+
if (roundNum == 0) continue;
262+
auto posns = round.positions;
263+
for (auto it1 = posns.begin(); it1 != posns.end(); ++it1) {
264+
for (auto it2 = std::next(it1); it2 != posns.end(); ++it2) {
265+
const SourceShift& ss1 = it1->first;
266+
const SourceShift& ss2 = it2->first;
267+
if (ss1.source != ss2.source && it1->second == it2->second) {
268+
LLVM_DEBUG(llvm::dbgs()
269+
<< "Round " << roundNum << ": collision between "
270+
<< "{" << ss1.source.ct << "," << ss1.source.slot << "}"
271+
<< " and " << "{" << ss2.source.ct << ","
272+
<< ss2.source.slot << "}" << " at " << "{"
273+
<< it1->second.ct << "," << it1->second.slot << "}\n");
274+
testConflictGraph.addEdge(ss1.source, ss2.source);
275+
}
276+
}
277+
}
278+
}
279+
280+
// Find number of rotations needed
281+
graph::GreedyGraphColoring<CtSlot> testColorer;
282+
std::unordered_map<CtSlot, int> testColoring =
283+
testColorer.color(testConflictGraph);
284+
285+
SmallVector<RotationGroup> testResultRotationGroups;
286+
testResultRotationGroups.reserve(5);
287+
288+
for (const auto &entry : testColoring) {
289+
CtSlot source = entry.first;
290+
int64_t color = entry.second;
291+
if (color >= static_cast<int64_t>(testResultRotationGroups.size()))
292+
testResultRotationGroups.resize(color + 1);
293+
testResultRotationGroups[color].insert(source);
294+
}
295+
296+
// Update best number rotations & mapping
297+
int testNumberRotations = static_cast<int>(testResultRotationGroups.size());
298+
if (testNumberRotations < minNumRotations) {
299+
minNumRotations = testNumberRotations;
300+
bestMapping = testMapping;
301+
bestConflictGraph = testConflictGraph;
302+
}
303+
}
304+
305+
mapping.setTargetToSource(bestMapping.getTargetToSource());
306+
conflictGraph = bestConflictGraph;
307+
308+
LLVM_DEBUG({
309+
llvm::dbgs() << "Cleaning subroutine finished. "
310+
<< "Best num rotations = " << minNumRotations << "\n";
311+
312+
llvm::dbgs() << "Best mapping targetToSource:\n";
313+
for (const auto &[target, source] : bestMapping.getTargetToSource()) {
314+
llvm::dbgs() << " target=(" << target.ct << "," << target.slot << ")"
315+
<< " <- source=(" << source.ct << "," << source.slot << ")\n";
316+
}
317+
318+
llvm::dbgs() << "Conflict graph after cleaning:\n";
319+
for (CtSlot vertex : conflictGraph.getVertices()) {
320+
llvm::dbgs() << " " << vertex.ct << "," << vertex.slot << " <-> {";
321+
for (CtSlot neighbor : conflictGraph.edgesIncidentTo(vertex)) {
322+
llvm::dbgs() << neighbor.ct << "," << neighbor.slot << "; ";
323+
}
324+
llvm::dbgs() << "}\n";
325+
}
326+
});
327+
}
328+
94329
graph::GreedyGraphColoring<CtSlot> colorer;
95330
std::unordered_map<CtSlot, int> coloring = colorer.color(conflictGraph);
96331

@@ -125,7 +360,7 @@ ShiftScheme VosVosErkinShiftNetworks::findShiftScheme(
125360
}
126361

127362
ShiftScheme VosVosErkinShiftNetworks::findBestShiftScheme(
128-
const Mapping& mapping, std::size_t randomSeed, unsigned randomTries) {
363+
Mapping& mapping, std::size_t randomSeed, unsigned randomTries) {
129364
SmallVector<int64_t> initShiftOrder = defaultShiftOrder(
130365
mapping.getCiphertextSize() * mapping.getNumCiphertexts());
131366

@@ -156,15 +391,19 @@ ShiftScheme VosVosErkinShiftNetworks::findBestShiftScheme(
156391
}
157392

158393
ShiftStrategy VosVosErkinShiftNetworks::evaluateShiftStrategy(
159-
const Mapping& mapping, ArrayRef<int64_t> shiftOrder) {
394+
const Mapping& mapping, ArrayRef<int64_t> shiftOrder, bool useSources) {
395+
160396
CacheKey cacheKey = makeCacheKey(mapping, shiftOrder);
161397
if (strategyCache.count(cacheKey)) {
162398
return strategyCache[cacheKey];
163399
}
400+
LLVM_DEBUG(llvm::dbgs() << "Making shift strategy..." << "\n");
164401

165402
ShiftStrategy strategy(mapping.getCiphertextSize(),
166403
mapping.getNumCiphertexts(), shiftOrder);
167-
strategy.evaluate(mapping);
404+
LLVM_DEBUG(llvm::dbgs() << "Evaluating shift strategy..." << "\n");
405+
406+
strategy.evaluate(mapping, useSources);
168407
strategyCache[cacheKey] = strategy;
169408
return strategy;
170409
}

lib/Dialect/TensorExt/Transforms/ImplementShiftNetwork.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,19 @@ class VosVosErkinShiftNetworks {
4141
// on further calls to avoid recomputing the shift network.
4242
//
4343
// The default shiftOrder is LSB to MSB, i.e. 1, 2, 4, 8, ...
44-
ShiftScheme findShiftScheme(const Mapping& mapping,
44+
ShiftScheme findShiftScheme(Mapping& mapping,
4545
ArrayRef<int64_t> shiftOrder = {});
4646

4747
// Like findShiftScheme but randomly draw from a uniform distribution over all
4848
// possible shift orders and use the one that results in the best network.
49-
ShiftScheme findBestShiftScheme(const Mapping& mapping,
49+
ShiftScheme findBestShiftScheme(Mapping& mapping,
5050
std::size_t randomSeed,
5151
unsigned randomTries = 100);
5252

5353
private:
5454
ShiftStrategy evaluateShiftStrategy(const Mapping& mapping,
55-
ArrayRef<int64_t> shiftOrder);
55+
ArrayRef<int64_t> shiftOrder,
56+
bool useSources = false);
5657

5758
CacheKey makeCacheKey(const Mapping& mapping, ArrayRef<int64_t> shiftOrder);
5859

lib/Dialect/TensorExt/Transforms/ShiftScheme.cpp

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,40 @@ int64_t ShiftStrategy::getVirtualShift(const CtSlot& source,
4242
return normalizeShift(sourceIndex, targetIndex, virtualCiphertextSize);
4343
}
4444

45-
void ShiftStrategy::evaluate(const Mapping& mapping) {
45+
void ShiftStrategy::evaluate(const Mapping& mapping, bool useSources) {
4646
// First compute the virtual shifts needed for each source slot
4747
SmallVector<SourceShift> sourceShifts;
48-
sourceShifts.reserve(mapping.size());
49-
for (const auto& [target, source] : mapping.getTargetToSource()) {
50-
int64_t shift = getVirtualShift(source, target);
51-
sourceShifts.push_back({source, shift});
48+
if (!useSources) {
49+
sourceShifts.reserve(mapping.size());
50+
for (const auto& [target, source] : mapping.getTargetToSource()) {
51+
int64_t shift = getVirtualShift(source, target);
52+
sourceShifts.push_back({source, shift});
53+
}
54+
} else {
55+
LLVM_DEBUG(llvm::dbgs() << "Evaluating with all targets \n");
56+
57+
auto allTargets = mapping.getTargetToSources();
58+
59+
int counter = 0;
60+
llvm::DenseSet<SourceShift> uniqueShifts;
61+
for (const auto& kv : allTargets) {
62+
const CtSlot& target = kv.first;
63+
const auto& sources = kv.second;
64+
for (const auto& ctmatch : sources) {
65+
int64_t shift = getVirtualShift(ctmatch.source, target);
66+
SourceShift ss{ctmatch.source, shift};
67+
68+
// Ensuring no duplicate pairs, can delete
69+
counter += 1;
70+
if (uniqueShifts.insert(ss).second) {
71+
sourceShifts.push_back(ss);
72+
}
73+
}
74+
}
75+
LLVM_DEBUG(llvm::dbgs()
76+
<< "Total SourceShifts: " << counter << "\n";
77+
llvm::dbgs()
78+
<< "Unique SourceShifts: " << sourceShifts.size() << "\n";);
5279
}
5380

5481
// Compute the corresponding table of positions after each rotation,

0 commit comments

Comments
 (0)