From 4d1911146f8bc3fd3332cb167255ba14298208bd Mon Sep 17 00:00:00 2001 From: Isabel Faulds Date: Tue, 6 Jan 2026 20:29:05 -0800 Subject: [PATCH] modify handling shift network source conflicts --- .../Transforms/ImplementShiftNetwork.cpp | 256 +++++++++++++++++- .../Transforms/ImplementShiftNetwork.h | 18 +- .../Transforms/ImplementShiftNetworkTest.cpp | 10 +- .../TensorExt/Transforms/ShiftScheme.cpp | 36 ++- .../TensorExt/Transforms/ShiftScheme.h | 23 +- .../LayoutConversionCost.cpp | 8 +- 6 files changed, 323 insertions(+), 28 deletions(-) diff --git a/lib/Dialect/TensorExt/Transforms/ImplementShiftNetwork.cpp b/lib/Dialect/TensorExt/Transforms/ImplementShiftNetwork.cpp index 2ef85230d6..cab978dceb 100644 --- a/lib/Dialect/TensorExt/Transforms/ImplementShiftNetwork.cpp +++ b/lib/Dialect/TensorExt/Transforms/ImplementShiftNetwork.cpp @@ -43,11 +43,11 @@ namespace tensor_ext { #define GEN_PASS_DEF_IMPLEMENTSHIFTNETWORK #include "lib/Dialect/TensorExt/Transforms/Passes.h.inc" -ShiftScheme VosVosErkinShiftNetworks::findShiftScheme( +ShiftSchemeResult VosVosErkinShiftNetworks::findShiftScheme( const Mapping& mapping, ArrayRef shiftOrder) { CacheKey cacheKey = makeCacheKey(mapping, shiftOrder); if (schemeCache.count(cacheKey)) { - return schemeCache[cacheKey]; + return ShiftSchemeResult{schemeCache[cacheKey], mapping}; } ShiftStrategy strategy = evaluateShiftStrategy(mapping, shiftOrder); @@ -91,6 +91,242 @@ ShiftScheme VosVosErkinShiftNetworks::findShiftScheme( } }); + // Cleaning subroutine, identify if overlapping sources for targets + Mapping cleanedMapping = mapping; + auto targetToSources = mapping.getTargetToSources(); + bool hasMultiSourceTargets = + std::any_of(targetToSources.begin(), targetToSources.end(), + [](const auto& kv) { return kv.second.size() > 1; }); + LLVM_DEBUG({ + llvm::dbgs() << "Has overlapping sources for targets - " + << hasMultiSourceTargets << "\n"; + }); + if (hasMultiSourceTargets) { + // Create a full conflict graph + graph::UndirectedGraph fullConflictGraph; + for (const auto& [target, sources] : targetToSources) { + for (const auto& ctmatch : sources) { + fullConflictGraph.addVertex(ctmatch.source); + } + } + SmallVector defaultTestShiftOrder = defaultShiftOrder( + mapping.getCiphertextSize() * mapping.getNumCiphertexts()); + + ShiftStrategy fullConflictStrategy = + evaluateShiftStrategy(mapping, defaultTestShiftOrder, true); + + for (const auto& [roundNum, round] : + llvm::enumerate(fullConflictStrategy.getRounds())) { + if (roundNum == 0) continue; + auto posns = round.positions; + for (auto it1 = posns.begin(); it1 != posns.end(); ++it1) { + for (auto it2 = std::next(it1); it2 != posns.end(); ++it2) { + const SourceShift& ss1 = it1->first; + const SourceShift& ss2 = it2->first; + if (ss1.source != ss2.source && it1->second == it2->second) { + LLVM_DEBUG(llvm::dbgs() + << "Round " << roundNum << ": collision between " << "{" + << ss1.source.ct << "," << ss1.source.slot << "}" + << " and " << "{" << ss2.source.ct << "," + << ss2.source.slot << "}" << " at " << "{" + << it1->second.ct << "," << it1->second.slot << "}\n"); + fullConflictGraph.addEdge(ss1.source, ss2.source); + } + } + } + } + + struct SourceInfo { + int numConflicts; + int numInserts; + }; + + std::unordered_map sourceMap; + + for (CtSlot vertex : fullConflictGraph.getVertices()) { + sourceMap[vertex] = SourceInfo{ + static_cast(fullConflictGraph.edgesIncidentTo(vertex).size()), + 0}; + } + + LLVM_DEBUG({ + llvm::dbgs() << "Full Conflict graph:\n"; + for (CtSlot vertex : fullConflictGraph.getVertices()) { + llvm::dbgs() << vertex.ct << "," << vertex.slot << " <-> {"; + for (CtSlot neighbor : fullConflictGraph.edgesIncidentTo(vertex)) { + llvm::dbgs() << neighbor.ct << "," << neighbor.slot << "; "; + } + llvm::dbgs() << "}\n"; + } + }); + + // Create seeds for tests + static std::mt19937 masterRng(std::random_device{}()); + static std::mt19937 rng(std::random_device{}()); + std::uniform_int_distribution seedDist; + + // Store best mapping + int minNumRotations = std::numeric_limits::max(); + Mapping bestMapping = mapping; + graph::UndirectedGraph bestConflictGraph; + + // Copy the target to Sources for testing + std::vector*>> allTargets; + allTargets.reserve(targetToSources.size()); + for (auto& entry : targetToSources) { + allTargets.emplace_back(entry.first, &entry.second); + } + + // Run 50 tests + for (int i = 0; i < 50; ++i) { + // track source numberConflicts , inserts + std::unordered_map testSourceMap = sourceMap; + + // create test mapping and conflictgraph + Mapping testMapping = mapping; + graph::UndirectedGraph testConflictGraph; + + // shuffle targets and fill + std::uint32_t trialSeed = seedDist(masterRng); + rng.seed(trialSeed); + std::shuffle(allTargets.begin(), allTargets.end(), rng); + + for (auto& [slot, matchesPtr] : allTargets) { + SmallVector& matches = *matchesPtr; + std::vector> candidates; + candidates.reserve(matches.size()); + + // Get and sort sources by 1) inserts , 2) conflicts ascending + for (auto& match : matches) { + CtSlot src = match.source; + auto it = testSourceMap.find(src); + candidates.emplace_back(src, &it->second); + } + std::sort(candidates.begin(), candidates.end(), + [](const auto& a, const auto& b) { + const SourceInfo& sa = *a.second; + const SourceInfo& sb = *b.second; + if (sa.numInserts != sb.numInserts) + return sa.numInserts < sb.numInserts; + return sa.numConflicts < sb.numConflicts; + }); + + LLVM_DEBUG({ + llvm::dbgs() << "Candidates after sort for target " << slot.ct << "," + << slot.slot << ":\n"; + for (const auto& entry : candidates) { + const CtSlot& src = entry.first; + const SourceInfo& info = *entry.second; + llvm::dbgs() << " src=(" << src.ct << "," << src.slot + << "), numInserts=" << info.numInserts + << ", numConflicts=" << info.numConflicts << "\n"; + } + }); + + // Pick any source with lowest at random and save choice + const SourceInfo& best = *candidates.front().second; + int minInserts = best.numInserts; + int minConflicts = best.numConflicts; + + std::vector bestIdx; + bestIdx.reserve(candidates.size()); + + // any source with lowest + for (size_t i = 0; i < candidates.size(); ++i) { + const SourceInfo& info = *candidates[i].second; + if (info.numInserts == minInserts && + info.numConflicts == minConflicts) + bestIdx.push_back(i); + else + break; + } + + // pick and save choice + std::uniform_int_distribution dist(0, bestIdx.size() - 1); + size_t chosenIndex = bestIdx[dist(rng)]; + + CtSlot chosen = candidates[chosenIndex].first; + SourceInfo& info = *candidates[chosenIndex].second; + info.numInserts += 1; + testMapping.add(chosen, slot); + testConflictGraph.addVertex(chosen); + } + + // Create test strategy, fill edges + ShiftStrategy testStrategy = + evaluateShiftStrategy(testMapping, defaultTestShiftOrder); + for (const auto& [roundNum, round] : + llvm::enumerate(testStrategy.getRounds())) { + if (roundNum == 0) continue; + auto posns = round.positions; + for (auto it1 = posns.begin(); it1 != posns.end(); ++it1) { + for (auto it2 = std::next(it1); it2 != posns.end(); ++it2) { + const SourceShift& ss1 = it1->first; + const SourceShift& ss2 = it2->first; + if (ss1.source != ss2.source && it1->second == it2->second) { + LLVM_DEBUG(llvm::dbgs() + << "Round " << roundNum << ": collision between " + << "{" << ss1.source.ct << "," << ss1.source.slot + << "}" << " and " << "{" << ss2.source.ct << "," + << ss2.source.slot << "}" << " at " << "{" + << it1->second.ct << "," << it1->second.slot << "}\n"); + testConflictGraph.addEdge(ss1.source, ss2.source); + } + } + } + } + + // Find number of rotations needed + graph::GreedyGraphColoring testColorer; + std::unordered_map testColoring = + testColorer.color(testConflictGraph); + + SmallVector testResultRotationGroups; + testResultRotationGroups.reserve(5); + + for (const auto& entry : testColoring) { + CtSlot source = entry.first; + int64_t color = entry.second; + if (color >= static_cast(testResultRotationGroups.size())) + testResultRotationGroups.resize(color + 1); + testResultRotationGroups[color].insert(source); + } + + // Update best number rotations & mapping + int testNumberRotations = + static_cast(testResultRotationGroups.size()); + if (testNumberRotations < minNumRotations) { + minNumRotations = testNumberRotations; + bestMapping = testMapping; + bestConflictGraph = testConflictGraph; + } + } + + cleanedMapping = bestMapping; + conflictGraph = bestConflictGraph; + + LLVM_DEBUG({ + llvm::dbgs() << "Cleaning subroutine finished. " + << "Best num rotations = " << minNumRotations << "\n"; + + llvm::dbgs() << "Best mapping targetToSource:\n"; + for (const auto& [target, source] : bestMapping.getTargetToSource()) { + llvm::dbgs() << " target=(" << target.ct << "," << target.slot << ")" + << " <- source=(" << source.ct << "," << source.slot + << ")\n"; + } + + llvm::dbgs() << "Conflict graph after cleaning:\n"; + for (CtSlot vertex : conflictGraph.getVertices()) { + llvm::dbgs() << " " << vertex.ct << "," << vertex.slot << " <-> {"; + for (CtSlot neighbor : conflictGraph.edgesIncidentTo(vertex)) { + llvm::dbgs() << neighbor.ct << "," << neighbor.slot << "; "; + } + llvm::dbgs() << "}\n"; + } + }); + } + graph::GreedyGraphColoring colorer; std::unordered_map coloring = colorer.color(conflictGraph); @@ -121,10 +357,10 @@ ShiftScheme VosVosErkinShiftNetworks::findShiftScheme( ShiftScheme scheme{resultRotationGroups, strategy}; schemeCache[cacheKey] = scheme; - return schemeCache[cacheKey]; + return ShiftSchemeResult{schemeCache[cacheKey], cleanedMapping}; } -ShiftScheme VosVosErkinShiftNetworks::findBestShiftScheme( +ShiftSchemeResult VosVosErkinShiftNetworks::findBestShiftScheme( const Mapping& mapping, std::size_t randomSeed, unsigned randomTries) { SmallVector initShiftOrder = defaultShiftOrder( mapping.getCiphertextSize() * mapping.getNumCiphertexts()); @@ -156,7 +392,7 @@ ShiftScheme VosVosErkinShiftNetworks::findBestShiftScheme( } ShiftStrategy VosVosErkinShiftNetworks::evaluateShiftStrategy( - const Mapping& mapping, ArrayRef shiftOrder) { + const Mapping& mapping, ArrayRef shiftOrder, bool useSources) { CacheKey cacheKey = makeCacheKey(mapping, shiftOrder); if (strategyCache.count(cacheKey)) { return strategyCache[cacheKey]; @@ -164,7 +400,7 @@ ShiftStrategy VosVosErkinShiftNetworks::evaluateShiftStrategy( ShiftStrategy strategy(mapping.getCiphertextSize(), mapping.getNumCiphertexts(), shiftOrder); - strategy.evaluate(mapping); + strategy.evaluate(mapping, useSources); strategyCache[cacheKey] = strategy; return strategy; } @@ -230,7 +466,9 @@ LogicalResult convertRemapOp(RemapOp op, "DenseIntElementsAttr"; } - ShiftScheme scheme = shiftNetworks.findShiftScheme(mapping); + auto shiftSchemeResult = shiftNetworks.findShiftScheme(mapping); + ShiftScheme& scheme = shiftSchemeResult.scheme; + Mapping& cleanedMapping = shiftSchemeResult.cleanedMapping; auto rotationGroups = scheme.rotationGroups; assert(!rotationGroups.empty() && @@ -260,8 +498,8 @@ LogicalResult convertRemapOp(RemapOp op, ciphertexts.push_back(kernel::SSAValue(slice.getResult())); } - auto resultNodes = - implementShiftNetwork(ciphertexts, mapping, scheme, ciphertextSize); + auto resultNodes = implementShiftNetwork(ciphertexts, cleanedMapping, scheme, + ciphertextSize); kernel::IRMaterializingVisitor visitor(b, singleCiphertextType); std::vector result = visitor.process(resultNodes); diff --git a/lib/Dialect/TensorExt/Transforms/ImplementShiftNetwork.h b/lib/Dialect/TensorExt/Transforms/ImplementShiftNetwork.h index 152aabbf9d..bb92bb232a 100644 --- a/lib/Dialect/TensorExt/Transforms/ImplementShiftNetwork.h +++ b/lib/Dialect/TensorExt/Transforms/ImplementShiftNetwork.h @@ -23,6 +23,11 @@ namespace tensor_ext { #define GEN_PASS_DECL_IMPLEMENTSHIFTNETWORK #include "lib/Dialect/TensorExt/Transforms/Passes.h.inc" +struct ShiftSchemeResult { + ShiftScheme scheme; + Mapping cleanedMapping; +}; + // Cf. https://link.springer.com/chapter/10.1007/978-3-031-17140-6_20 // for an explanation of the algorithm. class VosVosErkinShiftNetworks { @@ -41,18 +46,19 @@ class VosVosErkinShiftNetworks { // on further calls to avoid recomputing the shift network. // // The default shiftOrder is LSB to MSB, i.e. 1, 2, 4, 8, ... - ShiftScheme findShiftScheme(const Mapping& mapping, - ArrayRef shiftOrder = {}); + ShiftSchemeResult findShiftScheme(const Mapping& mapping, + ArrayRef shiftOrder = {}); // Like findShiftScheme but randomly draw from a uniform distribution over all // possible shift orders and use the one that results in the best network. - ShiftScheme findBestShiftScheme(const Mapping& mapping, - std::size_t randomSeed, - unsigned randomTries = 100); + ShiftSchemeResult findBestShiftScheme(const Mapping& mapping, + std::size_t randomSeed, + unsigned randomTries = 100); private: ShiftStrategy evaluateShiftStrategy(const Mapping& mapping, - ArrayRef shiftOrder); + ArrayRef shiftOrder, + bool useSources = false); CacheKey makeCacheKey(const Mapping& mapping, ArrayRef shiftOrder); diff --git a/lib/Dialect/TensorExt/Transforms/ImplementShiftNetworkTest.cpp b/lib/Dialect/TensorExt/Transforms/ImplementShiftNetworkTest.cpp index 4fc84392bb..71ed26703d 100644 --- a/lib/Dialect/TensorExt/Transforms/ImplementShiftNetworkTest.cpp +++ b/lib/Dialect/TensorExt/Transforms/ImplementShiftNetworkTest.cpp @@ -91,8 +91,8 @@ ::testing::AssertionResult checkMapping(const Mapping& mapping, int64_t ciphertextSize, unsigned naiveNumRGExpected = 0) { VosVosErkinShiftNetworks shiftNetworks; - - auto naiveScheme = shiftNetworks.findShiftScheme(mapping); + auto naiveSchemeResult = shiftNetworks.findShiftScheme(mapping); + const ShiftScheme& naiveScheme = naiveSchemeResult.scheme; unsigned naiveNumRG = naiveScheme.rotationGroups.size(); unsigned naiveNumRounds = naiveScheme.strategy.getRounds().size(); if (naiveNumRGExpected > 0 && naiveNumRG != naiveNumRGExpected) @@ -106,8 +106,10 @@ ::testing::AssertionResult checkMapping(const Mapping& mapping, // We try a large number of shift orders here such that we can be effectively // certain that we will find a network that is at least as good as the "naive" // one. - auto bestScheme = shiftNetworks.findBestShiftScheme( - mapping, /*randomSeed=*/42, /*randomTries=*/1000); + auto bestSchemeResult = + shiftNetworks.findBestShiftScheme(mapping, /*randomSeed=*/42, + /*randomTries=*/1000); + const ShiftScheme& bestScheme = bestSchemeResult.scheme; unsigned bestNumRounds = bestScheme.strategy.getRounds().size(); if (bestNumRounds > naiveNumRounds) return ::testing::AssertionFailure() diff --git a/lib/Dialect/TensorExt/Transforms/ShiftScheme.cpp b/lib/Dialect/TensorExt/Transforms/ShiftScheme.cpp index 310f9e78a2..224f09df91 100644 --- a/lib/Dialect/TensorExt/Transforms/ShiftScheme.cpp +++ b/lib/Dialect/TensorExt/Transforms/ShiftScheme.cpp @@ -42,13 +42,39 @@ int64_t ShiftStrategy::getVirtualShift(const CtSlot& source, return normalizeShift(sourceIndex, targetIndex, virtualCiphertextSize); } -void ShiftStrategy::evaluate(const Mapping& mapping) { +void ShiftStrategy::evaluate(const Mapping& mapping, bool useSources) { // First compute the virtual shifts needed for each source slot SmallVector sourceShifts; - sourceShifts.reserve(mapping.size()); - for (const auto& [target, source] : mapping.getTargetToSource()) { - int64_t shift = getVirtualShift(source, target); - sourceShifts.push_back({source, shift}); + if (!useSources) { + sourceShifts.reserve(mapping.size()); + for (const auto& [target, source] : mapping.getTargetToSource()) { + int64_t shift = getVirtualShift(source, target); + sourceShifts.push_back({source, shift}); + } + } else { + LLVM_DEBUG(llvm::dbgs() << "Evaluating with all targets \n"); + + auto allTargets = mapping.getTargetToSources(); + + int counter = 0; + llvm::DenseSet uniqueShifts; + for (const auto& kv : allTargets) { + const CtSlot& target = kv.first; + const auto& sources = kv.second; + for (const auto& ctmatch : sources) { + int64_t shift = getVirtualShift(ctmatch.source, target); + SourceShift ss{ctmatch.source, shift}; + + // Ensuring no duplicate pairs, can delete + counter += 1; + if (uniqueShifts.insert(ss).second) { + sourceShifts.push_back(ss); + } + } + } + LLVM_DEBUG(llvm::dbgs() << "Total SourceShifts: " << counter << "\n"; + llvm::dbgs() + << "Unique SourceShifts: " << sourceShifts.size() << "\n";); } // Compute the corresponding table of positions after each rotation, diff --git a/lib/Dialect/TensorExt/Transforms/ShiftScheme.h b/lib/Dialect/TensorExt/Transforms/ShiftScheme.h index fec67e18b9..fed7dad61e 100644 --- a/lib/Dialect/TensorExt/Transforms/ShiftScheme.h +++ b/lib/Dialect/TensorExt/Transforms/ShiftScheme.h @@ -85,6 +85,11 @@ namespace tensor_ext { // An arbitrary mapping on the slots of a set of ciphertexts. class Mapping { public: + struct CtMatch { + CtSlot source; + int64_t distance; + }; + Mapping(int64_t ciphertextSize = 1, int64_t numCiphertexts = 1) : ciphertextSize(ciphertextSize), numCiphertexts(numCiphertexts) {} @@ -92,6 +97,11 @@ class Mapping { void add(CtSlot source, CtSlot target) { auto [it, inserted] = targetToSource.insert({target, source}); + // Save map of targets & all possible sources with their distances between + // target & source + targetToSources[target].push_back( + CtMatch{source, getVirtualDistance(target, source)}); + if (!inserted) { // Update the mapping if the new source is closer to the target than the // existing source. This will select for the closest source when there are @@ -106,7 +116,17 @@ class Mapping { } } + void clearTargetToSource() { targetToSource.clear(); } + DenseMap getTargetToSource() const { return targetToSource; } + DenseMap> getTargetToSources() const { + return targetToSources; + } + + void setTargetToSource( + const llvm::DenseMap& newTargetToSource) { + targetToSource = newTargetToSource; + } int64_t getCiphertextSize() const { return ciphertextSize; } int64_t getNumCiphertexts() const { return numCiphertexts; } @@ -117,6 +137,7 @@ class Mapping { // Map from target to source to ensure only a single source is mapped to any // target. DenseMap targetToSource; + DenseMap> targetToSources; int64_t getVirtualDistance(const CtSlot& lhs, const CtSlot& rhs) { return std::abs(lhs.ct - rhs.ct) * ciphertextSize + @@ -182,7 +203,7 @@ class ShiftStrategy { SmallVector getRounds() const { return rounds; } // Run the shifting strategy and populate the list of rounds in the strategy - void evaluate(const Mapping& mapping); + void evaluate(const Mapping& mapping, bool useSources = false); private: int64_t ciphertextSize; diff --git a/lib/Transforms/LayoutOptimization/LayoutConversionCost.cpp b/lib/Transforms/LayoutOptimization/LayoutConversionCost.cpp index eadf5b55f8..ce80e96206 100644 --- a/lib/Transforms/LayoutOptimization/LayoutConversionCost.cpp +++ b/lib/Transforms/LayoutOptimization/LayoutConversionCost.cpp @@ -100,15 +100,17 @@ Cost computeCostOfLayoutConversion(int64_t numCiphertexts, } tensor_ext::VosVosErkinShiftNetworks shiftNetwork; - ShiftScheme scheme = + auto schemeResult = shiftNetwork.findBestShiftScheme(mapping, vveRandomSeed, vveRandomTries); + ShiftScheme scheme = schemeResult.scheme; + Mapping cleanedMapping = schemeResult.cleanedMapping; using NodeTy = ArithmeticDagNode; using ValueTy = std::shared_ptr; SmallVector inputLeaves(numCiphertexts, SymbolicValue({ciphertextSize})); - SmallVector> groupResults = - implementRotationGroups(inputLeaves, mapping, scheme, ciphertextSize); + SmallVector> groupResults = implementRotationGroups( + inputLeaves, cleanedMapping, scheme, ciphertextSize); // The cost is the maximum number of rotations in any group Cost maxRotations = 0;