@@ -44,7 +44,7 @@ namespace tensor_ext {
4444#include " lib/Dialect/TensorExt/Transforms/Passes.h.inc"
4545
4646ShiftScheme VosVosErkinShiftNetworks::findShiftScheme (
47- const Mapping& mapping, ArrayRef<int64_t > shiftOrder) {
47+ Mapping& mapping, ArrayRef<int64_t > shiftOrder) {
4848 CacheKey cacheKey = makeCacheKey (mapping, shiftOrder);
4949 if (schemeCache.count (cacheKey)) {
5050 return schemeCache[cacheKey];
@@ -91,6 +91,241 @@ ShiftScheme VosVosErkinShiftNetworks::findShiftScheme(
9191 }
9292 });
9393
94+ // Cleaning subroutine, identify if overlapping sources for targets
95+ auto targetToSources = mapping.getTargetToSources ();
96+ bool hasMultiSourceTargets =
97+ std::any_of (targetToSources.begin (), targetToSources.end (),
98+ [](const auto & kv) { return kv.second .size () > 1 ; });
99+ LLVM_DEBUG ({
100+ llvm::dbgs () << " Has overlapping sources for targets - "
101+ << hasMultiSourceTargets << " \n " ;
102+ });
103+ if (hasMultiSourceTargets) {
104+
105+ // Create a full conflict graph
106+ graph::UndirectedGraph<CtSlot> fullConflictGraph;
107+ for (const auto & [target, sources] : targetToSources) {
108+ for (const auto & ctmatch : sources) {
109+ fullConflictGraph.addVertex (ctmatch.source );
110+ }
111+ }
112+ SmallVector<int64_t > defaultTestShiftOrder = defaultShiftOrder (
113+ mapping.getCiphertextSize () * mapping.getNumCiphertexts ());
114+
115+ ShiftStrategy fullConflictStrategy =
116+ evaluateShiftStrategy (mapping, defaultTestShiftOrder, true );
117+
118+ for (const auto & [roundNum, round] :
119+ llvm::enumerate (fullConflictStrategy.getRounds ())) {
120+ if (roundNum == 0 ) continue ;
121+ auto posns = round.positions ;
122+ for (auto it1 = posns.begin (); it1 != posns.end (); ++it1) {
123+ for (auto it2 = std::next (it1); it2 != posns.end (); ++it2) {
124+ const SourceShift& ss1 = it1->first ;
125+ const SourceShift& ss2 = it2->first ;
126+ if (ss1.source != ss2.source && it1->second == it2->second ) {
127+ LLVM_DEBUG (llvm::dbgs ()
128+ << " Round " << roundNum << " : collision between "
129+ << " {" << ss1.source .ct << " ," << ss1.source .slot << " }"
130+ << " and " << " {" << ss2.source .ct << " ,"
131+ << ss2.source .slot << " }" << " at " << " {"
132+ << it1->second .ct << " ," << it1->second .slot << " }\n " );
133+ fullConflictGraph.addEdge (ss1.source , ss2.source );
134+ }
135+ }
136+ }
137+ }
138+
139+
140+ struct SourceInfo {
141+ int numConflicts;
142+ int numInserts;
143+ };
144+
145+ std::unordered_map<CtSlot, SourceInfo> sourceMap;
146+
147+ for (CtSlot vertex : fullConflictGraph.getVertices ()) {
148+ sourceMap[vertex] = SourceInfo{
149+ static_cast <int >(fullConflictGraph.edgesIncidentTo (vertex).size ()),
150+ 0 };
151+ }
152+
153+ LLVM_DEBUG ({
154+ llvm::dbgs () << " Full Conflict graph:\n " ;
155+ for (CtSlot vertex : fullConflictGraph.getVertices ()) {
156+ llvm::dbgs () << vertex.ct << " ," << vertex.slot << " <-> {" ;
157+ for (CtSlot neighbor : fullConflictGraph.edgesIncidentTo (vertex)) {
158+ llvm::dbgs () << neighbor.ct << " ," << neighbor.slot << " ; " ;
159+ }
160+ llvm::dbgs () << " }\n " ;
161+ }
162+ });
163+
164+ // Create seeds for tests
165+ static std::mt19937 masterRng (std::random_device{}());
166+ static std::mt19937 rng (std::random_device{}());
167+ std::uniform_int_distribution<std::uint32_t > seedDist;
168+
169+ // Store best mapping
170+ int minNumRotations = std::numeric_limits<int >::max ();
171+ Mapping bestMapping = mapping;
172+ graph::UndirectedGraph<CtSlot> bestConflictGraph;
173+
174+ // Copy the target to Sources for testing
175+ std::vector<std::pair<CtSlot, SmallVector<Mapping::CtMatch>*>> allTargets;
176+ allTargets.reserve (targetToSources.size ());
177+ for (auto &entry : targetToSources) {
178+ allTargets.emplace_back (entry.first , &entry.second );
179+ }
180+
181+ // Run 50 tests
182+ for (int i = 0 ; i < 50 ; ++i) {
183+ // track source numberConflicts , inserts
184+ std::unordered_map<CtSlot, SourceInfo> testSourceMap = sourceMap;
185+
186+ // create test mapping and conflictgraph
187+ Mapping testMapping = mapping;
188+ graph::UndirectedGraph<CtSlot> testConflictGraph;
189+
190+ // shuffle targets and fill
191+ std::uint32_t trialSeed = seedDist (masterRng);
192+ rng.seed (trialSeed);
193+ std::shuffle (allTargets.begin (), allTargets.end (), rng);
194+
195+ for (auto &[slot, matchesPtr] : allTargets) {
196+ SmallVector<Mapping::CtMatch> &matches = *matchesPtr;
197+ std::vector<std::pair<CtSlot, SourceInfo*>> candidates;
198+ candidates.reserve (matches.size ());
199+
200+ // Get and sort sources by 1) inserts , 2) conflicts ascending
201+ for (auto &match : matches) {
202+ CtSlot src = match.source ;
203+ auto it = testSourceMap.find (src);
204+ candidates.emplace_back (src, &it->second );
205+ }
206+ std::sort (candidates.begin (), candidates.end (),
207+ [](const auto &a, const auto &b) {
208+ const SourceInfo &sa = *a.second ;
209+ const SourceInfo &sb = *b.second ;
210+ if (sa.numInserts != sb.numInserts )
211+ return sa.numInserts < sb.numInserts ;
212+ return sa.numConflicts < sb.numConflicts ;
213+ });
214+
215+ LLVM_DEBUG ({
216+ llvm::dbgs () << " Candidates after sort for target "
217+ << slot.ct << " ," << slot.slot << " :\n " ;
218+ for (const auto &entry : candidates) {
219+ const CtSlot &src = entry.first ;
220+ const SourceInfo &info = *entry.second ;
221+ llvm::dbgs () << " src=(" << src.ct << " ," << src.slot
222+ << " ), numInserts=" << info.numInserts
223+ << " , numConflicts=" << info.numConflicts << " \n " ;
224+ }
225+ });
226+
227+ // Pick any source with lowest at random and save choice
228+ const SourceInfo &best = *candidates.front ().second ;
229+ int minInserts = best.numInserts ;
230+ int minConflicts = best.numConflicts ;
231+
232+ std::vector<size_t > bestIdx;
233+ bestIdx.reserve (candidates.size ());
234+
235+ // any source with lowest
236+ for (size_t i = 0 ; i < candidates.size (); ++i) {
237+ const SourceInfo &info = *candidates[i].second ;
238+ if (info.numInserts == minInserts && info.numConflicts == minConflicts)
239+ bestIdx.push_back (i);
240+ else
241+ break ;
242+ }
243+
244+ // pick and save choice
245+ std::uniform_int_distribution<size_t > dist (0 , bestIdx.size () - 1 );
246+ size_t chosenIndex = bestIdx[dist (rng)];
247+
248+ CtSlot chosen = candidates[chosenIndex].first ;
249+ SourceInfo &info = *candidates[chosenIndex].second ;
250+ info.numInserts += 1 ;
251+ testMapping.add (chosen, slot);
252+ testConflictGraph.addVertex (chosen);
253+ }
254+
255+
256+ // Create test strategy, fill edges
257+ ShiftStrategy testStrategy =
258+ evaluateShiftStrategy (testMapping, defaultTestShiftOrder);
259+ for (const auto & [roundNum, round] :
260+ llvm::enumerate (testStrategy.getRounds ())) {
261+ if (roundNum == 0 ) continue ;
262+ auto posns = round.positions ;
263+ for (auto it1 = posns.begin (); it1 != posns.end (); ++it1) {
264+ for (auto it2 = std::next (it1); it2 != posns.end (); ++it2) {
265+ const SourceShift& ss1 = it1->first ;
266+ const SourceShift& ss2 = it2->first ;
267+ if (ss1.source != ss2.source && it1->second == it2->second ) {
268+ LLVM_DEBUG (llvm::dbgs ()
269+ << " Round " << roundNum << " : collision between "
270+ << " {" << ss1.source .ct << " ," << ss1.source .slot << " }"
271+ << " and " << " {" << ss2.source .ct << " ,"
272+ << ss2.source .slot << " }" << " at " << " {"
273+ << it1->second .ct << " ," << it1->second .slot << " }\n " );
274+ testConflictGraph.addEdge (ss1.source , ss2.source );
275+ }
276+ }
277+ }
278+ }
279+
280+ // Find number of rotations needed
281+ graph::GreedyGraphColoring<CtSlot> testColorer;
282+ std::unordered_map<CtSlot, int > testColoring =
283+ testColorer.color (testConflictGraph);
284+
285+ SmallVector<RotationGroup> testResultRotationGroups;
286+ testResultRotationGroups.reserve (5 );
287+
288+ for (const auto &entry : testColoring) {
289+ CtSlot source = entry.first ;
290+ int64_t color = entry.second ;
291+ if (color >= static_cast <int64_t >(testResultRotationGroups.size ()))
292+ testResultRotationGroups.resize (color + 1 );
293+ testResultRotationGroups[color].insert (source);
294+ }
295+
296+ // Update best number rotations & mapping
297+ int testNumberRotations = static_cast <int >(testResultRotationGroups.size ());
298+ if (testNumberRotations < minNumRotations) {
299+ minNumRotations = testNumberRotations;
300+ bestMapping = testMapping;
301+ bestConflictGraph = testConflictGraph;
302+ }
303+ }
304+
305+ mapping.setTargetToSource (bestMapping.getTargetToSource ());
306+ conflictGraph = bestConflictGraph;
307+
308+ LLVM_DEBUG ({
309+ llvm::dbgs () << " Cleaning subroutine finished. "
310+ << " Best num rotations = " << minNumRotations << " \n " ;
311+
312+ llvm::dbgs () << " Best mapping targetToSource:\n " ;
313+ for (const auto &[target, source] : bestMapping.getTargetToSource ()) {
314+ llvm::dbgs () << " target=(" << target.ct << " ," << target.slot << " )"
315+ << " <- source=(" << source.ct << " ," << source.slot << " )\n " ;
316+ }
317+
318+ llvm::dbgs () << " Conflict graph after cleaning:\n " ;
319+ for (CtSlot vertex : conflictGraph.getVertices ()) {
320+ llvm::dbgs () << " " << vertex.ct << " ," << vertex.slot << " <-> {" ;
321+ for (CtSlot neighbor : conflictGraph.edgesIncidentTo (vertex)) {
322+ llvm::dbgs () << neighbor.ct << " ," << neighbor.slot << " ; " ;
323+ }
324+ llvm::dbgs () << " }\n " ;
325+ }
326+ });
327+ }
328+
94329 graph::GreedyGraphColoring<CtSlot> colorer;
95330 std::unordered_map<CtSlot, int > coloring = colorer.color (conflictGraph);
96331
@@ -125,7 +360,7 @@ ShiftScheme VosVosErkinShiftNetworks::findShiftScheme(
125360}
126361
127362ShiftScheme VosVosErkinShiftNetworks::findBestShiftScheme (
128- const Mapping& mapping, std::size_t randomSeed, unsigned randomTries) {
363+ Mapping& mapping, std::size_t randomSeed, unsigned randomTries) {
129364 SmallVector<int64_t > initShiftOrder = defaultShiftOrder (
130365 mapping.getCiphertextSize () * mapping.getNumCiphertexts ());
131366
@@ -156,15 +391,19 @@ ShiftScheme VosVosErkinShiftNetworks::findBestShiftScheme(
156391}
157392
158393ShiftStrategy VosVosErkinShiftNetworks::evaluateShiftStrategy (
159- const Mapping& mapping, ArrayRef<int64_t > shiftOrder) {
394+ const Mapping& mapping, ArrayRef<int64_t > shiftOrder, bool useSources) {
395+
160396 CacheKey cacheKey = makeCacheKey (mapping, shiftOrder);
161397 if (strategyCache.count (cacheKey)) {
162398 return strategyCache[cacheKey];
163399 }
400+ LLVM_DEBUG (llvm::dbgs () << " Making shift strategy..." << " \n " );
164401
165402 ShiftStrategy strategy (mapping.getCiphertextSize (),
166403 mapping.getNumCiphertexts (), shiftOrder);
167- strategy.evaluate (mapping);
404+ LLVM_DEBUG (llvm::dbgs () << " Evaluating shift strategy..." << " \n " );
405+
406+ strategy.evaluate (mapping, useSources);
168407 strategyCache[cacheKey] = strategy;
169408 return strategy;
170409}
0 commit comments