Skip to content

Commit caf1818

Browse files
committed
Removed some redundant blocks
1 parent 1788401 commit caf1818

1 file changed

Lines changed: 1 addition & 86 deletions

File tree

Compiler/decision_tree_optimized.py

Lines changed: 1 addition & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from Compiler.types import *
22
from Compiler.sorting import *
33
from Compiler.library import *
4+
from Compiler.decision_tree import PrefixSum, PrefixSumR, PrefixSum_inv, PrefixSumR_inv, SortPerm, GroupSum, GroupPrefixSum, GroupFirstOne
45
from Compiler import util, oram
56

67
from itertools import accumulate
@@ -37,29 +38,6 @@ def GetSortPerm(keys, *to_sort, n_bits=None, time=False):
3738
res = res.transpose()
3839
return radix_sort_permutation_from_matrix(bs, res)
3940

40-
def PrefixSum(x):
41-
return x.get_vector().prefix_sum()
42-
43-
def PrefixSumR(x):
44-
tmp = get_type(x).Array(len(x))
45-
tmp.assign_vector(x)
46-
break_point()
47-
tmp[:] = tmp.get_reverse_vector().prefix_sum()
48-
break_point()
49-
return tmp.get_reverse_vector()
50-
51-
def PrefixSum_inv(x):
52-
tmp = get_type(x).Array(len(x) + 1)
53-
tmp.assign_vector(x, base=1)
54-
tmp[0] = 0
55-
return tmp.get_vector(size=len(x), base=1) - tmp.get_vector(size=len(x))
56-
57-
def PrefixSumR_inv(x):
58-
tmp = get_type(x).Array(len(x) + 1)
59-
tmp.assign_vector(x)
60-
tmp[-1] = 0
61-
return tmp.get_vector(size=len(x)) - tmp.get_vector(base=1, size=len(x))
62-
6341
def ApplyPermutation(perm, x):
6442
res = Array.create_from(x)
6543
reveal_sort(perm, res, False)
@@ -70,71 +48,13 @@ def ApplyInversePermutation(perm, x):
7048
reveal_sort(perm, res, True)
7149
return res
7250

73-
class SortPerm:
74-
def __init__(self, x):
75-
B = sint.Matrix(len(x), 2)
76-
B.set_column(0, 1 - x.get_vector())
77-
B.set_column(1, x.get_vector())
78-
self.perm = Array.create_from(dest_comp(B))
79-
def apply(self, x):
80-
res = Array.create_from(x)
81-
reveal_sort(self.perm, res, False)
82-
return res
83-
def unapply(self, x):
84-
res = Array.create_from(x)
85-
reveal_sort(self.perm, res, True)
86-
return res
87-
88-
def Sort(keys, *to_sort, n_bits=None, time=False):
89-
if time:
90-
start_timer(1)
91-
for k in keys:
92-
assert len(k) == len(keys[0])
93-
n_bits = n_bits or [None] * len(keys)
94-
bs = Matrix.create_from(
95-
sum([k.get_vector().bit_decompose(nb)
96-
for k, nb in reversed(list(zip(keys, n_bits)))], []))
97-
get_vec = lambda x: x[:] if isinstance(x, Array) else x
98-
res = Matrix.create_from(get_vec(x).v if isinstance(get_vec(x), sfix) else x
99-
for x in to_sort)
100-
res = res.transpose()
101-
if time:
102-
start_timer(11)
103-
radix_sort_from_matrix(bs, res)
104-
if time:
105-
stop_timer(11)
106-
stop_timer(1)
107-
res = res.transpose()
108-
return [sfix._new(get_vec(x), k=get_vec(y).k, f=get_vec(y).f)
109-
if isinstance(get_vec(y), sfix)
110-
else x for (x, y) in zip(res, to_sort)]
111-
11251
def VectMax(key, *data, debug=False):
11352
def reducer(x, y):
11453
b = x[0]*y[1] > y[0]*x[1]
11554
return [b.if_else(xx, yy) for xx, yy in zip(x, y)]
11655
res = util.tree_reduce(reducer, zip(key, *data))
11756
return res
11857

119-
def GroupSum(g, x):
120-
assert len(g) == len(x)
121-
p = PrefixSumR(x) * g
122-
pi = SortPerm(g.get_vector().bit_not())
123-
p1 = pi.apply(p)
124-
s1 = PrefixSumR_inv(p1)
125-
d1 = PrefixSum_inv(s1)
126-
d = pi.unapply(d1) * g
127-
return PrefixSum(d)
128-
129-
def GroupPrefixSum(g, x):
130-
assert len(g) == len(x)
131-
s = get_type(x).Array(len(x) + 1)
132-
s[0] = 0
133-
s.assign_vector(PrefixSum(x), base=1)
134-
q = get_type(s).Array(len(x))
135-
q.assign_vector(s.get_vector(size=len(x)) * g)
136-
return s.get_vector(size=len(x), base=1) - GroupSum(g, q)
137-
13858
def Custom_GT_Fractions(x_num, x_den, y_num, y_den, n_threads=2):
13959
b = (x_num*y_den) > (x_den*y_num)
14060
b = Array.create_from(b).get_vector()
@@ -223,11 +143,6 @@ def TrainLeafNodes(h, g, y, NID, Label, debug=False):
223143
assert len(g) == len(NID)
224144
return FormatLayer(h, g, NID, Label, debug=debug)
225145

226-
def GroupFirstOne(g, b):
227-
assert len(g) == len(b)
228-
s = GroupPrefixSum(g, b)
229-
return s * b == 1
230-
231146
class TreeTrainer:
232147
def GetInversePermutation(self, perm, n_threads=2):
233148
res = Array.create_from(self.identity_permutation)

0 commit comments

Comments
 (0)