@@ -43,11 +43,11 @@ namespace tensor_ext {
4343#define GEN_PASS_DEF_IMPLEMENTSHIFTNETWORK
4444#include " lib/Dialect/TensorExt/Transforms/Passes.h.inc"
4545
46- ShiftScheme VosVosErkinShiftNetworks::findShiftScheme (
46+ ShiftSchemeResult VosVosErkinShiftNetworks::findShiftScheme (
4747 const Mapping& mapping, ArrayRef<int64_t > shiftOrder) {
4848 CacheKey cacheKey = makeCacheKey (mapping, shiftOrder);
4949 if (schemeCache.count (cacheKey)) {
50- return schemeCache[cacheKey];
50+ return ShiftSchemeResult{ schemeCache[cacheKey], mapping} ;
5151 }
5252
5353 ShiftStrategy strategy = evaluateShiftStrategy (mapping, shiftOrder);
@@ -91,6 +91,242 @@ ShiftScheme VosVosErkinShiftNetworks::findShiftScheme(
9191 }
9292 });
9393
94+ // Cleaning subroutine: check whether any target has more than one candidate source
95+ Mapping cleanedMapping = mapping;
96+ auto targetToSources = mapping.getTargetToSources ();
97+ bool hasMultiSourceTargets =
98+ std::any_of (targetToSources.begin (), targetToSources.end (),
99+ [](const auto & kv) { return kv.second .size () > 1 ; });
100+ LLVM_DEBUG ({
101+ llvm::dbgs () << " Has overlapping sources for targets - "
102+ << hasMultiSourceTargets << " \n " ;
103+ });
104+ if (hasMultiSourceTargets) {
105+ // Create a full conflict graph
106+ graph::UndirectedGraph<CtSlot> fullConflictGraph;
107+ for (const auto & [target, sources] : targetToSources) {
108+ for (const auto & ctmatch : sources) {
109+ fullConflictGraph.addVertex (ctmatch.source );
110+ }
111+ }
112+ SmallVector<int64_t > defaultTestShiftOrder = defaultShiftOrder (
113+ mapping.getCiphertextSize () * mapping.getNumCiphertexts ());
114+
115+ ShiftStrategy fullConflictStrategy =
116+ evaluateShiftStrategy (mapping, defaultTestShiftOrder, true );
117+
118+ for (const auto & [roundNum, round] :
119+ llvm::enumerate (fullConflictStrategy.getRounds ())) {
120+ if (roundNum == 0 ) continue ;
121+ auto posns = round.positions ;
122+ for (auto it1 = posns.begin (); it1 != posns.end (); ++it1) {
123+ for (auto it2 = std::next (it1); it2 != posns.end (); ++it2) {
124+ const SourceShift& ss1 = it1->first ;
125+ const SourceShift& ss2 = it2->first ;
126+ if (ss1.source != ss2.source && it1->second == it2->second ) {
127+ LLVM_DEBUG (llvm::dbgs ()
128+ << " Round " << roundNum << " : collision between " << " {"
129+ << ss1.source .ct << " ," << ss1.source .slot << " }"
130+ << " and " << " {" << ss2.source .ct << " ,"
131+ << ss2.source .slot << " }" << " at " << " {"
132+ << it1->second .ct << " ," << it1->second .slot << " }\n " );
133+ fullConflictGraph.addEdge (ss1.source , ss2.source );
134+ }
135+ }
136+ }
137+ }
138+
139+ struct SourceInfo {
140+ int numConflicts;
141+ int numInserts;
142+ };
143+
144+ std::unordered_map<CtSlot, SourceInfo> sourceMap;
145+
146+ for (CtSlot vertex : fullConflictGraph.getVertices ()) {
147+ sourceMap[vertex] = SourceInfo{
148+ static_cast <int >(fullConflictGraph.edgesIncidentTo (vertex).size ()),
149+ 0 };
150+ }
151+
152+ LLVM_DEBUG ({
153+ llvm::dbgs () << " Full Conflict graph:\n " ;
154+ for (CtSlot vertex : fullConflictGraph.getVertices ()) {
155+ llvm::dbgs () << vertex.ct << " ," << vertex.slot << " <-> {" ;
156+ for (CtSlot neighbor : fullConflictGraph.edgesIncidentTo (vertex)) {
157+ llvm::dbgs () << neighbor.ct << " ," << neighbor.slot << " ; " ;
158+ }
159+ llvm::dbgs () << " }\n " ;
160+ }
161+ });
162+
163+ // Create seeds for tests
164+ static std::mt19937 masterRng (std::random_device{}());
165+ static std::mt19937 rng (std::random_device{}());
166+ std::uniform_int_distribution<std::uint32_t > seedDist;
167+
168+ // Store best mapping
169+ int minNumRotations = std::numeric_limits<int >::max ();
170+ Mapping bestMapping = mapping;
171+ graph::UndirectedGraph<CtSlot> bestConflictGraph;
172+
173+ // Collect (target, pointer to its candidate sources) pairs so the targets can be shuffled per trial
174+ std::vector<std::pair<CtSlot, SmallVector<Mapping::CtMatch>*>> allTargets;
175+ allTargets.reserve (targetToSources.size ());
176+ for (auto & entry : targetToSources) {
177+ allTargets.emplace_back (entry.first , &entry.second );
178+ }
179+
180+ // Run 50 randomized trials and keep the mapping that needs the fewest rotation groups
181+ for (int i = 0 ; i < 50 ; ++i) {
182+ // track source numberConflicts , inserts
183+ std::unordered_map<CtSlot, SourceInfo> testSourceMap = sourceMap;
184+
185+ // create test mapping and conflictgraph
186+ Mapping testMapping = mapping;
187+ graph::UndirectedGraph<CtSlot> testConflictGraph;
188+
189+ // shuffle targets and fill
190+ std::uint32_t trialSeed = seedDist (masterRng);
191+ rng.seed (trialSeed);
192+ std::shuffle (allTargets.begin (), allTargets.end (), rng);
193+
194+ for (auto & [slot, matchesPtr] : allTargets) {
195+ SmallVector<Mapping::CtMatch>& matches = *matchesPtr;
196+ std::vector<std::pair<CtSlot, SourceInfo*>> candidates;
197+ candidates.reserve (matches.size ());
198+
199+ // Collect this target's candidate sources and sort them ascending by 1) numInserts, 2) numConflicts
200+ for (auto & match : matches) {
201+ CtSlot src = match.source ;
202+ auto it = testSourceMap.find (src);
203+ candidates.emplace_back (src, &it->second );
204+ }
205+ std::sort (candidates.begin (), candidates.end (),
206+ [](const auto & a, const auto & b) {
207+ const SourceInfo& sa = *a.second ;
208+ const SourceInfo& sb = *b.second ;
209+ if (sa.numInserts != sb.numInserts )
210+ return sa.numInserts < sb.numInserts ;
211+ return sa.numConflicts < sb.numConflicts ;
212+ });
213+
214+ LLVM_DEBUG ({
215+ llvm::dbgs () << " Candidates after sort for target " << slot.ct << " ,"
216+ << slot.slot << " :\n " ;
217+ for (const auto & entry : candidates) {
218+ const CtSlot& src = entry.first ;
219+ const SourceInfo& info = *entry.second ;
220+ llvm::dbgs () << " src=(" << src.ct << " ," << src.slot
221+ << " ), numInserts=" << info.numInserts
222+ << " , numConflicts=" << info.numConflicts << " \n " ;
223+ }
224+ });
225+
226+ // Among the candidates tied for the lowest (numInserts, numConflicts), pick one at random and record it
227+ const SourceInfo& best = *candidates.front ().second ;
228+ int minInserts = best.numInserts ;
229+ int minConflicts = best.numConflicts ;
230+
231+ std::vector<size_t > bestIdx;
232+ bestIdx.reserve (candidates.size ());
233+
234+ // Gather the leading run of candidates tied with the best one (candidates are sorted, so ties are contiguous)
235+ for (size_t i = 0 ; i < candidates.size (); ++i) {
236+ const SourceInfo& info = *candidates[i].second ;
237+ if (info.numInserts == minInserts &&
238+ info.numConflicts == minConflicts)
239+ bestIdx.push_back (i);
240+ else
241+ break ;
242+ }
243+
244+ // pick and save choice
245+ std::uniform_int_distribution<size_t > dist (0 , bestIdx.size () - 1 );
246+ size_t chosenIndex = bestIdx[dist (rng)];
247+
248+ CtSlot chosen = candidates[chosenIndex].first ;
249+ SourceInfo& info = *candidates[chosenIndex].second ;
250+ info.numInserts += 1 ;
251+ testMapping.add (chosen, slot);
252+ testConflictGraph.addVertex (chosen);
253+ }
254+
255+ // Create test strategy, fill edges
256+ ShiftStrategy testStrategy =
257+ evaluateShiftStrategy (testMapping, defaultTestShiftOrder);
258+ for (const auto & [roundNum, round] :
259+ llvm::enumerate (testStrategy.getRounds ())) {
260+ if (roundNum == 0 ) continue ;
261+ auto posns = round.positions ;
262+ for (auto it1 = posns.begin (); it1 != posns.end (); ++it1) {
263+ for (auto it2 = std::next (it1); it2 != posns.end (); ++it2) {
264+ const SourceShift& ss1 = it1->first ;
265+ const SourceShift& ss2 = it2->first ;
266+ if (ss1.source != ss2.source && it1->second == it2->second ) {
267+ LLVM_DEBUG (llvm::dbgs ()
268+ << " Round " << roundNum << " : collision between "
269+ << " {" << ss1.source .ct << " ," << ss1.source .slot
270+ << " }" << " and " << " {" << ss2.source .ct << " ,"
271+ << ss2.source .slot << " }" << " at " << " {"
272+ << it1->second .ct << " ," << it1->second .slot << " }\n " );
273+ testConflictGraph.addEdge (ss1.source , ss2.source );
274+ }
275+ }
276+ }
277+ }
278+
279+ // Find number of rotations needed
280+ graph::GreedyGraphColoring<CtSlot> testColorer;
281+ std::unordered_map<CtSlot, int > testColoring =
282+ testColorer.color (testConflictGraph);
283+
284+ SmallVector<RotationGroup> testResultRotationGroups;
285+ testResultRotationGroups.reserve (5 );
286+
287+ for (const auto & entry : testColoring) {
288+ CtSlot source = entry.first ;
289+ int64_t color = entry.second ;
290+ if (color >= static_cast <int64_t >(testResultRotationGroups.size ()))
291+ testResultRotationGroups.resize (color + 1 );
292+ testResultRotationGroups[color].insert (source);
293+ }
294+
295+ // Update best number rotations & mapping
296+ int testNumberRotations =
297+ static_cast <int >(testResultRotationGroups.size ());
298+ if (testNumberRotations < minNumRotations) {
299+ minNumRotations = testNumberRotations;
300+ bestMapping = testMapping;
301+ bestConflictGraph = testConflictGraph;
302+ }
303+ }
304+
305+ cleanedMapping = bestMapping;
306+ conflictGraph = bestConflictGraph;
307+
308+ LLVM_DEBUG ({
309+ llvm::dbgs () << " Cleaning subroutine finished. "
310+ << " Best num rotations = " << minNumRotations << " \n " ;
311+
312+ llvm::dbgs () << " Best mapping targetToSource:\n " ;
313+ for (const auto & [target, source] : bestMapping.getTargetToSource ()) {
314+ llvm::dbgs () << " target=(" << target.ct << " ," << target.slot << " )"
315+ << " <- source=(" << source.ct << " ," << source.slot
316+ << " )\n " ;
317+ }
318+
319+ llvm::dbgs () << " Conflict graph after cleaning:\n " ;
320+ for (CtSlot vertex : conflictGraph.getVertices ()) {
321+ llvm::dbgs () << " " << vertex.ct << " ," << vertex.slot << " <-> {" ;
322+ for (CtSlot neighbor : conflictGraph.edgesIncidentTo (vertex)) {
323+ llvm::dbgs () << neighbor.ct << " ," << neighbor.slot << " ; " ;
324+ }
325+ llvm::dbgs () << " }\n " ;
326+ }
327+ });
328+ }
329+
94330 graph::GreedyGraphColoring<CtSlot> colorer;
95331 std::unordered_map<CtSlot, int > coloring = colorer.color (conflictGraph);
96332
@@ -121,10 +357,10 @@ ShiftScheme VosVosErkinShiftNetworks::findShiftScheme(
121357
122358 ShiftScheme scheme{resultRotationGroups, strategy};
123359 schemeCache[cacheKey] = scheme;
124- return schemeCache[cacheKey];
360+ return ShiftSchemeResult{ schemeCache[cacheKey], cleanedMapping} ;
125361}
126362
127- ShiftScheme VosVosErkinShiftNetworks::findBestShiftScheme (
363+ ShiftSchemeResult VosVosErkinShiftNetworks::findBestShiftScheme (
128364 const Mapping& mapping, std::size_t randomSeed, unsigned randomTries) {
129365 SmallVector<int64_t > initShiftOrder = defaultShiftOrder (
130366 mapping.getCiphertextSize () * mapping.getNumCiphertexts ());
@@ -156,15 +392,15 @@ ShiftScheme VosVosErkinShiftNetworks::findBestShiftScheme(
156392}
157393
158394ShiftStrategy VosVosErkinShiftNetworks::evaluateShiftStrategy (
159- const Mapping& mapping, ArrayRef<int64_t > shiftOrder) {
395+ const Mapping& mapping, ArrayRef<int64_t > shiftOrder, bool useSources ) {
160396 CacheKey cacheKey = makeCacheKey (mapping, shiftOrder);
161397 if (strategyCache.count (cacheKey)) {
162398 return strategyCache[cacheKey];
163399 }
164400
165401 ShiftStrategy strategy (mapping.getCiphertextSize (),
166402 mapping.getNumCiphertexts (), shiftOrder);
167- strategy.evaluate (mapping);
403+ strategy.evaluate (mapping, useSources );
168404 strategyCache[cacheKey] = strategy;
169405 return strategy;
170406}
@@ -230,7 +466,9 @@ LogicalResult convertRemapOp(RemapOp op,
230466 " DenseIntElementsAttr" ;
231467 }
232468
233- ShiftScheme scheme = shiftNetworks.findShiftScheme (mapping);
469+ auto shiftSchemeResult = shiftNetworks.findShiftScheme (mapping);
470+ ShiftScheme& scheme = shiftSchemeResult.scheme ;
471+ Mapping& cleanedMapping = shiftSchemeResult.cleanedMapping ;
234472 auto rotationGroups = scheme.rotationGroups ;
235473
236474 assert (!rotationGroups.empty () &&
@@ -260,8 +498,8 @@ LogicalResult convertRemapOp(RemapOp op,
260498 ciphertexts.push_back (kernel::SSAValue (slice.getResult ()));
261499 }
262500
263- auto resultNodes =
264- implementShiftNetwork (ciphertexts, mapping, scheme, ciphertextSize);
501+ auto resultNodes = implementShiftNetwork (ciphertexts, cleanedMapping, scheme,
502+ ciphertextSize);
265503
266504 kernel::IRMaterializingVisitor visitor (b, singleCiphertextType);
267505 std::vector<Value> result = visitor.process (resultNodes);
0 commit comments