11from Compiler .types import *
22from Compiler .sorting import *
33from Compiler .library import *
4+ from Compiler .decision_tree import PrefixSum , PrefixSumR , PrefixSum_inv , PrefixSumR_inv , SortPerm , GroupSum , GroupPrefixSum , GroupFirstOne
45from Compiler import util , oram
56
67from itertools import accumulate
@@ -37,29 +38,6 @@ def GetSortPerm(keys, *to_sort, n_bits=None, time=False):
3738 res = res .transpose ()
3839 return radix_sort_permutation_from_matrix (bs , res )
3940
40- def PrefixSum (x ):
41- return x .get_vector ().prefix_sum ()
42-
43- def PrefixSumR (x ):
44- tmp = get_type (x ).Array (len (x ))
45- tmp .assign_vector (x )
46- break_point ()
47- tmp [:] = tmp .get_reverse_vector ().prefix_sum ()
48- break_point ()
49- return tmp .get_reverse_vector ()
50-
51- def PrefixSum_inv (x ):
52- tmp = get_type (x ).Array (len (x ) + 1 )
53- tmp .assign_vector (x , base = 1 )
54- tmp [0 ] = 0
55- return tmp .get_vector (size = len (x ), base = 1 ) - tmp .get_vector (size = len (x ))
56-
57- def PrefixSumR_inv (x ):
58- tmp = get_type (x ).Array (len (x ) + 1 )
59- tmp .assign_vector (x )
60- tmp [- 1 ] = 0
61- return tmp .get_vector (size = len (x )) - tmp .get_vector (base = 1 , size = len (x ))
62-
6341def ApplyPermutation (perm , x ):
6442 res = Array .create_from (x )
6543 reveal_sort (perm , res , False )
@@ -70,71 +48,13 @@ def ApplyInversePermutation(perm, x):
7048 reveal_sort (perm , res , True )
7149 return res
7250
73- class SortPerm :
74- def __init__ (self , x ):
75- B = sint .Matrix (len (x ), 2 )
76- B .set_column (0 , 1 - x .get_vector ())
77- B .set_column (1 , x .get_vector ())
78- self .perm = Array .create_from (dest_comp (B ))
79- def apply (self , x ):
80- res = Array .create_from (x )
81- reveal_sort (self .perm , res , False )
82- return res
83- def unapply (self , x ):
84- res = Array .create_from (x )
85- reveal_sort (self .perm , res , True )
86- return res
87-
88- def Sort (keys , * to_sort , n_bits = None , time = False ):
89- if time :
90- start_timer (1 )
91- for k in keys :
92- assert len (k ) == len (keys [0 ])
93- n_bits = n_bits or [None ] * len (keys )
94- bs = Matrix .create_from (
95- sum ([k .get_vector ().bit_decompose (nb )
96- for k , nb in reversed (list (zip (keys , n_bits )))], []))
97- get_vec = lambda x : x [:] if isinstance (x , Array ) else x
98- res = Matrix .create_from (get_vec (x ).v if isinstance (get_vec (x ), sfix ) else x
99- for x in to_sort )
100- res = res .transpose ()
101- if time :
102- start_timer (11 )
103- radix_sort_from_matrix (bs , res )
104- if time :
105- stop_timer (11 )
106- stop_timer (1 )
107- res = res .transpose ()
108- return [sfix ._new (get_vec (x ), k = get_vec (y ).k , f = get_vec (y ).f )
109- if isinstance (get_vec (y ), sfix )
110- else x for (x , y ) in zip (res , to_sort )]
111-
11251def VectMax (key , * data , debug = False ):
11352 def reducer (x , y ):
11453 b = x [0 ]* y [1 ] > y [0 ]* x [1 ]
11554 return [b .if_else (xx , yy ) for xx , yy in zip (x , y )]
11655 res = util .tree_reduce (reducer , zip (key , * data ))
11756 return res
11857
119- def GroupSum (g , x ):
120- assert len (g ) == len (x )
121- p = PrefixSumR (x ) * g
122- pi = SortPerm (g .get_vector ().bit_not ())
123- p1 = pi .apply (p )
124- s1 = PrefixSumR_inv (p1 )
125- d1 = PrefixSum_inv (s1 )
126- d = pi .unapply (d1 ) * g
127- return PrefixSum (d )
128-
129- def GroupPrefixSum (g , x ):
130- assert len (g ) == len (x )
131- s = get_type (x ).Array (len (x ) + 1 )
132- s [0 ] = 0
133- s .assign_vector (PrefixSum (x ), base = 1 )
134- q = get_type (s ).Array (len (x ))
135- q .assign_vector (s .get_vector (size = len (x )) * g )
136- return s .get_vector (size = len (x ), base = 1 ) - GroupSum (g , q )
137-
13858def Custom_GT_Fractions (x_num , x_den , y_num , y_den , n_threads = 2 ):
13959 b = (x_num * y_den ) > (x_den * y_num )
14060 b = Array .create_from (b ).get_vector ()
@@ -223,11 +143,6 @@ def TrainLeafNodes(h, g, y, NID, Label, debug=False):
223143 assert len (g ) == len (NID )
224144 return FormatLayer (h , g , NID , Label , debug = debug )
225145
226- def GroupFirstOne (g , b ):
227- assert len (g ) == len (b )
228- s = GroupPrefixSum (g , b )
229- return s * b == 1
230-
231146class TreeTrainer :
232147 def GetInversePermutation (self , perm , n_threads = 2 ):
233148 res = Array .create_from (self .identity_permutation )
0 commit comments