Skip to content

Commit 0c4f829

Browse files
committed
modify handling shift network source conflicts
1 parent eaeff61 commit 0c4f829

5 files changed

Lines changed: 317 additions & 24 deletions

File tree

lib/Dialect/TensorExt/Transforms/ImplementShiftNetwork.cpp

Lines changed: 247 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@ namespace tensor_ext {
4343
#define GEN_PASS_DEF_IMPLEMENTSHIFTNETWORK
4444
#include "lib/Dialect/TensorExt/Transforms/Passes.h.inc"
4545

46-
ShiftScheme VosVosErkinShiftNetworks::findShiftScheme(
46+
ShiftSchemeResult VosVosErkinShiftNetworks::findShiftScheme(
4747
const Mapping& mapping, ArrayRef<int64_t> shiftOrder) {
4848
CacheKey cacheKey = makeCacheKey(mapping, shiftOrder);
4949
if (schemeCache.count(cacheKey)) {
50-
return schemeCache[cacheKey];
50+
return ShiftSchemeResult{schemeCache[cacheKey], mapping};
5151
}
5252

5353
ShiftStrategy strategy = evaluateShiftStrategy(mapping, shiftOrder);
@@ -91,6 +91,242 @@ ShiftScheme VosVosErkinShiftNetworks::findShiftScheme(
9191
}
9292
});
9393

94+
// Cleaning subroutine, identify if overlapping sources for targets
95+
Mapping cleanedMapping = mapping;
96+
auto targetToSources = mapping.getTargetToSources();
97+
bool hasMultiSourceTargets =
98+
std::any_of(targetToSources.begin(), targetToSources.end(),
99+
[](const auto& kv) { return kv.second.size() > 1; });
100+
LLVM_DEBUG({
101+
llvm::dbgs() << "Has overlapping sources for targets - "
102+
<< hasMultiSourceTargets << "\n";
103+
});
104+
if (hasMultiSourceTargets) {
105+
// Create a full conflict graph
106+
graph::UndirectedGraph<CtSlot> fullConflictGraph;
107+
for (const auto& [target, sources] : targetToSources) {
108+
for (const auto& ctmatch : sources) {
109+
fullConflictGraph.addVertex(ctmatch.source);
110+
}
111+
}
112+
SmallVector<int64_t> defaultTestShiftOrder = defaultShiftOrder(
113+
mapping.getCiphertextSize() * mapping.getNumCiphertexts());
114+
115+
ShiftStrategy fullConflictStrategy =
116+
evaluateShiftStrategy(mapping, defaultTestShiftOrder, true);
117+
118+
for (const auto& [roundNum, round] :
119+
llvm::enumerate(fullConflictStrategy.getRounds())) {
120+
if (roundNum == 0) continue;
121+
auto posns = round.positions;
122+
for (auto it1 = posns.begin(); it1 != posns.end(); ++it1) {
123+
for (auto it2 = std::next(it1); it2 != posns.end(); ++it2) {
124+
const SourceShift& ss1 = it1->first;
125+
const SourceShift& ss2 = it2->first;
126+
if (ss1.source != ss2.source && it1->second == it2->second) {
127+
LLVM_DEBUG(llvm::dbgs()
128+
<< "Round " << roundNum << ": collision between " << "{"
129+
<< ss1.source.ct << "," << ss1.source.slot << "}"
130+
<< " and " << "{" << ss2.source.ct << ","
131+
<< ss2.source.slot << "}" << " at " << "{"
132+
<< it1->second.ct << "," << it1->second.slot << "}\n");
133+
fullConflictGraph.addEdge(ss1.source, ss2.source);
134+
}
135+
}
136+
}
137+
}
138+
139+
struct SourceInfo {
140+
int numConflicts;
141+
int numInserts;
142+
};
143+
144+
std::unordered_map<CtSlot, SourceInfo> sourceMap;
145+
146+
for (CtSlot vertex : fullConflictGraph.getVertices()) {
147+
sourceMap[vertex] = SourceInfo{
148+
static_cast<int>(fullConflictGraph.edgesIncidentTo(vertex).size()),
149+
0};
150+
}
151+
152+
LLVM_DEBUG({
153+
llvm::dbgs() << "Full Conflict graph:\n";
154+
for (CtSlot vertex : fullConflictGraph.getVertices()) {
155+
llvm::dbgs() << vertex.ct << "," << vertex.slot << " <-> {";
156+
for (CtSlot neighbor : fullConflictGraph.edgesIncidentTo(vertex)) {
157+
llvm::dbgs() << neighbor.ct << "," << neighbor.slot << "; ";
158+
}
159+
llvm::dbgs() << "}\n";
160+
}
161+
});
162+
163+
// Create seeds for tests
164+
static std::mt19937 masterRng(std::random_device{}());
165+
static std::mt19937 rng(std::random_device{}());
166+
std::uniform_int_distribution<std::uint32_t> seedDist;
167+
168+
// Store best mapping
169+
int minNumRotations = std::numeric_limits<int>::max();
170+
Mapping bestMapping = mapping;
171+
graph::UndirectedGraph<CtSlot> bestConflictGraph;
172+
173+
// Copy the target to Sources for testing
174+
std::vector<std::pair<CtSlot, SmallVector<Mapping::CtMatch>*>> allTargets;
175+
allTargets.reserve(targetToSources.size());
176+
for (auto& entry : targetToSources) {
177+
allTargets.emplace_back(entry.first, &entry.second);
178+
}
179+
180+
// Run 50 tests
181+
for (int i = 0; i < 50; ++i) {
182+
// track source numberConflicts , inserts
183+
std::unordered_map<CtSlot, SourceInfo> testSourceMap = sourceMap;
184+
185+
// create test mapping and conflictgraph
186+
Mapping testMapping = mapping;
187+
graph::UndirectedGraph<CtSlot> testConflictGraph;
188+
189+
// shuffle targets and fill
190+
std::uint32_t trialSeed = seedDist(masterRng);
191+
rng.seed(trialSeed);
192+
std::shuffle(allTargets.begin(), allTargets.end(), rng);
193+
194+
for (auto& [slot, matchesPtr] : allTargets) {
195+
SmallVector<Mapping::CtMatch>& matches = *matchesPtr;
196+
std::vector<std::pair<CtSlot, SourceInfo*>> candidates;
197+
candidates.reserve(matches.size());
198+
199+
// Get and sort sources by 1) inserts , 2) conflicts ascending
200+
for (auto& match : matches) {
201+
CtSlot src = match.source;
202+
auto it = testSourceMap.find(src);
203+
candidates.emplace_back(src, &it->second);
204+
}
205+
std::sort(candidates.begin(), candidates.end(),
206+
[](const auto& a, const auto& b) {
207+
const SourceInfo& sa = *a.second;
208+
const SourceInfo& sb = *b.second;
209+
if (sa.numInserts != sb.numInserts)
210+
return sa.numInserts < sb.numInserts;
211+
return sa.numConflicts < sb.numConflicts;
212+
});
213+
214+
LLVM_DEBUG({
215+
llvm::dbgs() << "Candidates after sort for target " << slot.ct << ","
216+
<< slot.slot << ":\n";
217+
for (const auto& entry : candidates) {
218+
const CtSlot& src = entry.first;
219+
const SourceInfo& info = *entry.second;
220+
llvm::dbgs() << " src=(" << src.ct << "," << src.slot
221+
<< "), numInserts=" << info.numInserts
222+
<< ", numConflicts=" << info.numConflicts << "\n";
223+
}
224+
});
225+
226+
// Pick any source with lowest at random and save choice
227+
const SourceInfo& best = *candidates.front().second;
228+
int minInserts = best.numInserts;
229+
int minConflicts = best.numConflicts;
230+
231+
std::vector<size_t> bestIdx;
232+
bestIdx.reserve(candidates.size());
233+
234+
// any source with lowest
235+
for (size_t i = 0; i < candidates.size(); ++i) {
236+
const SourceInfo& info = *candidates[i].second;
237+
if (info.numInserts == minInserts &&
238+
info.numConflicts == minConflicts)
239+
bestIdx.push_back(i);
240+
else
241+
break;
242+
}
243+
244+
// pick and save choice
245+
std::uniform_int_distribution<size_t> dist(0, bestIdx.size() - 1);
246+
size_t chosenIndex = bestIdx[dist(rng)];
247+
248+
CtSlot chosen = candidates[chosenIndex].first;
249+
SourceInfo& info = *candidates[chosenIndex].second;
250+
info.numInserts += 1;
251+
testMapping.add(chosen, slot);
252+
testConflictGraph.addVertex(chosen);
253+
}
254+
255+
// Create test strategy, fill edges
256+
ShiftStrategy testStrategy =
257+
evaluateShiftStrategy(testMapping, defaultTestShiftOrder);
258+
for (const auto& [roundNum, round] :
259+
llvm::enumerate(testStrategy.getRounds())) {
260+
if (roundNum == 0) continue;
261+
auto posns = round.positions;
262+
for (auto it1 = posns.begin(); it1 != posns.end(); ++it1) {
263+
for (auto it2 = std::next(it1); it2 != posns.end(); ++it2) {
264+
const SourceShift& ss1 = it1->first;
265+
const SourceShift& ss2 = it2->first;
266+
if (ss1.source != ss2.source && it1->second == it2->second) {
267+
LLVM_DEBUG(llvm::dbgs()
268+
<< "Round " << roundNum << ": collision between "
269+
<< "{" << ss1.source.ct << "," << ss1.source.slot
270+
<< "}" << " and " << "{" << ss2.source.ct << ","
271+
<< ss2.source.slot << "}" << " at " << "{"
272+
<< it1->second.ct << "," << it1->second.slot << "}\n");
273+
testConflictGraph.addEdge(ss1.source, ss2.source);
274+
}
275+
}
276+
}
277+
}
278+
279+
// Find number of rotations needed
280+
graph::GreedyGraphColoring<CtSlot> testColorer;
281+
std::unordered_map<CtSlot, int> testColoring =
282+
testColorer.color(testConflictGraph);
283+
284+
SmallVector<RotationGroup> testResultRotationGroups;
285+
testResultRotationGroups.reserve(5);
286+
287+
for (const auto& entry : testColoring) {
288+
CtSlot source = entry.first;
289+
int64_t color = entry.second;
290+
if (color >= static_cast<int64_t>(testResultRotationGroups.size()))
291+
testResultRotationGroups.resize(color + 1);
292+
testResultRotationGroups[color].insert(source);
293+
}
294+
295+
// Update best number rotations & mapping
296+
int testNumberRotations =
297+
static_cast<int>(testResultRotationGroups.size());
298+
if (testNumberRotations < minNumRotations) {
299+
minNumRotations = testNumberRotations;
300+
bestMapping = testMapping;
301+
bestConflictGraph = testConflictGraph;
302+
}
303+
}
304+
305+
cleanedMapping = bestMapping;
306+
conflictGraph = bestConflictGraph;
307+
308+
LLVM_DEBUG({
309+
llvm::dbgs() << "Cleaning subroutine finished. "
310+
<< "Best num rotations = " << minNumRotations << "\n";
311+
312+
llvm::dbgs() << "Best mapping targetToSource:\n";
313+
for (const auto& [target, source] : bestMapping.getTargetToSource()) {
314+
llvm::dbgs() << " target=(" << target.ct << "," << target.slot << ")"
315+
<< " <- source=(" << source.ct << "," << source.slot
316+
<< ")\n";
317+
}
318+
319+
llvm::dbgs() << "Conflict graph after cleaning:\n";
320+
for (CtSlot vertex : conflictGraph.getVertices()) {
321+
llvm::dbgs() << " " << vertex.ct << "," << vertex.slot << " <-> {";
322+
for (CtSlot neighbor : conflictGraph.edgesIncidentTo(vertex)) {
323+
llvm::dbgs() << neighbor.ct << "," << neighbor.slot << "; ";
324+
}
325+
llvm::dbgs() << "}\n";
326+
}
327+
});
328+
}
329+
94330
graph::GreedyGraphColoring<CtSlot> colorer;
95331
std::unordered_map<CtSlot, int> coloring = colorer.color(conflictGraph);
96332

@@ -121,10 +357,10 @@ ShiftScheme VosVosErkinShiftNetworks::findShiftScheme(
121357

122358
ShiftScheme scheme{resultRotationGroups, strategy};
123359
schemeCache[cacheKey] = scheme;
124-
return schemeCache[cacheKey];
360+
return ShiftSchemeResult{schemeCache[cacheKey], cleanedMapping};
125361
}
126362

127-
ShiftScheme VosVosErkinShiftNetworks::findBestShiftScheme(
363+
ShiftSchemeResult VosVosErkinShiftNetworks::findBestShiftScheme(
128364
const Mapping& mapping, std::size_t randomSeed, unsigned randomTries) {
129365
SmallVector<int64_t> initShiftOrder = defaultShiftOrder(
130366
mapping.getCiphertextSize() * mapping.getNumCiphertexts());
@@ -156,15 +392,15 @@ ShiftScheme VosVosErkinShiftNetworks::findBestShiftScheme(
156392
}
157393

158394
ShiftStrategy VosVosErkinShiftNetworks::evaluateShiftStrategy(
159-
const Mapping& mapping, ArrayRef<int64_t> shiftOrder) {
395+
const Mapping& mapping, ArrayRef<int64_t> shiftOrder, bool useSources) {
160396
CacheKey cacheKey = makeCacheKey(mapping, shiftOrder);
161397
if (strategyCache.count(cacheKey)) {
162398
return strategyCache[cacheKey];
163399
}
164400

165401
ShiftStrategy strategy(mapping.getCiphertextSize(),
166402
mapping.getNumCiphertexts(), shiftOrder);
167-
strategy.evaluate(mapping);
403+
strategy.evaluate(mapping, useSources);
168404
strategyCache[cacheKey] = strategy;
169405
return strategy;
170406
}
@@ -230,7 +466,9 @@ LogicalResult convertRemapOp(RemapOp op,
230466
"DenseIntElementsAttr";
231467
}
232468

233-
ShiftScheme scheme = shiftNetworks.findShiftScheme(mapping);
469+
auto shiftSchemeResult = shiftNetworks.findShiftScheme(mapping);
470+
ShiftScheme& scheme = shiftSchemeResult.scheme;
471+
Mapping& cleanedMapping = shiftSchemeResult.cleanedMapping;
234472
auto rotationGroups = scheme.rotationGroups;
235473

236474
assert(!rotationGroups.empty() &&
@@ -260,8 +498,8 @@ LogicalResult convertRemapOp(RemapOp op,
260498
ciphertexts.push_back(kernel::SSAValue(slice.getResult()));
261499
}
262500

263-
auto resultNodes =
264-
implementShiftNetwork(ciphertexts, mapping, scheme, ciphertextSize);
501+
auto resultNodes = implementShiftNetwork(ciphertexts, cleanedMapping, scheme,
502+
ciphertextSize);
265503

266504
kernel::IRMaterializingVisitor visitor(b, singleCiphertextType);
267505
std::vector<Value> result = visitor.process(resultNodes);

lib/Dialect/TensorExt/Transforms/ImplementShiftNetwork.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ namespace tensor_ext {
2323
#define GEN_PASS_DECL_IMPLEMENTSHIFTNETWORK
2424
#include "lib/Dialect/TensorExt/Transforms/Passes.h.inc"
2525

26+
struct ShiftSchemeResult {
27+
ShiftScheme scheme;
28+
Mapping cleanedMapping;
29+
};
30+
2631
// Cf. https://link.springer.com/chapter/10.1007/978-3-031-17140-6_20
2732
// for an explanation of the algorithm.
2833
class VosVosErkinShiftNetworks {
@@ -41,18 +46,19 @@ class VosVosErkinShiftNetworks {
4146
// on further calls to avoid recomputing the shift network.
4247
//
4348
// The default shiftOrder is LSB to MSB, i.e. 1, 2, 4, 8, ...
44-
ShiftScheme findShiftScheme(const Mapping& mapping,
45-
ArrayRef<int64_t> shiftOrder = {});
49+
ShiftSchemeResult findShiftScheme(const Mapping& mapping,
50+
ArrayRef<int64_t> shiftOrder = {});
4651

4752
// Like findShiftScheme but randomly draw from a uniform distribution over all
4853
// possible shift orders and use the one that results in the best network.
49-
ShiftScheme findBestShiftScheme(const Mapping& mapping,
50-
std::size_t randomSeed,
51-
unsigned randomTries = 100);
54+
ShiftSchemeResult findBestShiftScheme(const Mapping& mapping,
55+
std::size_t randomSeed,
56+
unsigned randomTries = 100);
5257

5358
private:
5459
ShiftStrategy evaluateShiftStrategy(const Mapping& mapping,
55-
ArrayRef<int64_t> shiftOrder);
60+
ArrayRef<int64_t> shiftOrder,
61+
bool useSources = false);
5662

5763
CacheKey makeCacheKey(const Mapping& mapping, ArrayRef<int64_t> shiftOrder);
5864

lib/Dialect/TensorExt/Transforms/ShiftScheme.cpp

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,39 @@ int64_t ShiftStrategy::getVirtualShift(const CtSlot& source,
4242
return normalizeShift(sourceIndex, targetIndex, virtualCiphertextSize);
4343
}
4444

45-
void ShiftStrategy::evaluate(const Mapping& mapping) {
45+
void ShiftStrategy::evaluate(const Mapping& mapping, bool useSources) {
4646
// First compute the virtual shifts needed for each source slot
4747
SmallVector<SourceShift> sourceShifts;
48-
sourceShifts.reserve(mapping.size());
49-
for (const auto& [target, source] : mapping.getTargetToSource()) {
50-
int64_t shift = getVirtualShift(source, target);
51-
sourceShifts.push_back({source, shift});
48+
if (!useSources) {
49+
sourceShifts.reserve(mapping.size());
50+
for (const auto& [target, source] : mapping.getTargetToSource()) {
51+
int64_t shift = getVirtualShift(source, target);
52+
sourceShifts.push_back({source, shift});
53+
}
54+
} else {
55+
LLVM_DEBUG(llvm::dbgs() << "Evaluating with all targets \n");
56+
57+
auto allTargets = mapping.getTargetToSources();
58+
59+
int counter = 0;
60+
llvm::DenseSet<SourceShift> uniqueShifts;
61+
for (const auto& kv : allTargets) {
62+
const CtSlot& target = kv.first;
63+
const auto& sources = kv.second;
64+
for (const auto& ctmatch : sources) {
65+
int64_t shift = getVirtualShift(ctmatch.source, target);
66+
SourceShift ss{ctmatch.source, shift};
67+
68+
// Ensuring no duplicate pairs, can delete
69+
counter += 1;
70+
if (uniqueShifts.insert(ss).second) {
71+
sourceShifts.push_back(ss);
72+
}
73+
}
74+
}
75+
LLVM_DEBUG(llvm::dbgs() << "Total SourceShifts: " << counter << "\n";
76+
llvm::dbgs()
77+
<< "Unique SourceShifts: " << sourceShifts.size() << "\n";);
5278
}
5379

5480
// Compute the corresponding table of positions after each rotation,

0 commit comments

Comments
 (0)