-
Notifications
You must be signed in to change notification settings - Fork 135
Expand file tree
/
Copy pathImplementShiftNetwork.h
More file actions
73 lines (58 loc) · 2.71 KB
/
ImplementShiftNetwork.h
File metadata and controls
73 lines (58 loc) · 2.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#ifndef LIB_DIALECT_TENSOREXT_TRANSFORMS_IMPLEMENTSHIFTNETWORK_H_
#define LIB_DIALECT_TENSOREXT_TRANSFORMS_IMPLEMENTSHIFTNETWORK_H_
/// An implementation of the graph coloring approach of Vos-Vos-Erkin 2022 from
/// http://dx.doi.org/10.1007/978-3-031-17140-6_20
///
/// This implements a version of the algorithm that supports arbitrary mappings
/// across multiple ciphertexts, including replication.
#include <cstddef>
#include <cstdint>
#include <utility>

#include "lib/Dialect/TensorExt/Transforms/ShiftScheme.h"
#include "lib/Utils/ADT/FrozenVector.h"
#include "lib/Utils/MathUtils.h"
#include "mlir/include/mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h"  // from @llvm-project
namespace mlir {
namespace heir {
namespace tensor_ext {
#define GEN_PASS_DECL_IMPLEMENTSHIFTNETWORK
#include "lib/Dialect/TensorExt/Transforms/Passes.h.inc"
// Result type of VosVosErkinShiftNetworks::findShiftScheme and
// VosVosErkinShiftNetworks::findBestShiftScheme.
//
// Plain aggregate; member order matters for brace-initialization, so keep
// `scheme` first and `cleanedMapping` second.
struct ShiftSchemeResult {
// The shift scheme that was found for the requested mapping.
ShiftScheme scheme;
// The mapping the scheme corresponds to, after cleanup.
// NOTE(review): the exact "cleaning" applied to the input mapping is defined
// in the implementation file — confirm there before relying on it.
Mapping cleanedMapping;
};
// Cf. https://link.springer.com/chapter/10.1007/978-3-031-17140-6_20
// for an explanation of the algorithm.
//
// Finds shift networks (power-of-two rotation schemes) realizing a given slot
// Mapping. Results are memoized on the instance, keyed by the pair
// (mapping, shift order), so one instance should be reused across queries to
// benefit from the cache.
class VosVosErkinShiftNetworks {
  // Cache key: the requested mapping together with the shift order used.
  using CacheKey = std::pair<Mapping, FrozenVector<int64_t>>;

 public:
  VosVosErkinShiftNetworks() = default;

  // Computes a partition of the slot indices of a ciphertext into
  // RotationGroups that are compatible with respect to the target permutation.
  // Each RotationGroup corresponds to a set of indices that should be rotated
  // together via power-of-two rotations.
  //
  // The result is returned by value; the caller owns it. Computed schemes are
  // cached on this instance, and the cache is consulted on further calls with
  // the same mapping and shift order to avoid recomputing the shift network.
  //
  // `shiftOrder` gives the order in which power-of-two shifts are applied. The
  // default (empty) shiftOrder means LSB to MSB, i.e. 1, 2, 4, 8, ...
  ShiftSchemeResult findShiftScheme(const Mapping& mapping,
                                    ArrayRef<int64_t> shiftOrder = {});

  // Like findShiftScheme, but randomly draws shift orders from a uniform
  // distribution over all possible shift orders and returns the scheme for the
  // one that results in the best network.
  //
  // NOTE(review): presumably `randomSeed` seeds the RNG (making the search
  // reproducible) and `randomTries` is the number of shift orders sampled —
  // confirm against the implementation.
  ShiftSchemeResult findBestShiftScheme(const Mapping& mapping,
                                        std::size_t randomSeed,
                                        unsigned randomTries = 100);

 private:
  // Evaluates the ShiftStrategy for `mapping` under the given `shiftOrder`.
  // NOTE(review): the meaning of `useSources` is defined in the
  // implementation file — confirm there.
  ShiftStrategy evaluateShiftStrategy(const Mapping& mapping,
                                      ArrayRef<int64_t> shiftOrder,
                                      bool useSources = false);

  // Builds the key shared by strategyCache and schemeCache from a mapping and
  // a shift order.
  CacheKey makeCacheKey(const Mapping& mapping, ArrayRef<int64_t> shiftOrder);

  // Memoized strategies, keyed by (mapping, shift order).
  DenseMap<CacheKey, ShiftStrategy> strategyCache;
  // Memoized schemes, keyed by (mapping, shift order).
  DenseMap<CacheKey, ShiftScheme> schemeCache;
};
} // namespace tensor_ext
} // namespace heir
} // namespace mlir
#endif // LIB_DIALECT_TENSOREXT_TRANSFORMS_IMPLEMENTSHIFTNETWORK_H_