@@ -157,32 +157,35 @@ def GroupMax(g, keys, *x):
157157 util .reveal (t ), util .reveal (keys ), util .reveal (x ))
158158 return [GroupSum (g , t [:] * xx ) for xx in [keys ] + x ]
159159
160- def ModifiedGini (g , y , debug = False ):
160+ def ComputeGini (g , x , y , notysum , ysum , debug = False ):
161161 assert len (g ) == len (y )
162162 y = [y .get_vector ().bit_not (), y ]
163163 u = [GroupPrefixSum (g , yy ) for yy in y ]
164- s = [GroupSum ( g , yy ) for yy in y ]
164+ s = [notysum , ysum ]
165165 w = [ss - uu for ss , uu in zip (s , u )]
166166 us = sum (u )
167167 ws = sum (w )
168168 uqs = u [0 ] ** 2 + u [1 ] ** 2
169169 wqs = w [0 ] ** 2 + w [1 ] ** 2
170- res = sfix (uqs ) / us + sfix (wqs ) / ws
171- if debug :
172- print_ln ('g=%s y=%s s=%s' ,
173- util .reveal (g ), util .reveal (y ),
174- util .reveal (s ))
175- print_ln ('u0=%s' , util .reveal (u [0 ]))
176- print_ln ('u0=%s' , util .reveal (u [1 ]))
177- print_ln ('us=%s' , util .reveal (us ))
178- print_ln ('w0=%s' , util .reveal (w [0 ]))
179- print_ln ('w1=%s' , util .reveal (w [1 ]))
180- print_ln ('ws=%s' , util .reveal (ws ))
181- print_ln ('uqs=%s' , util .reveal (uqs ))
182- print_ln ('wqs=%s' , util .reveal (wqs ))
183- if debug :
184- print_ln ('gini %s %s' , type (res ), util .reveal (res ))
185- return res
170+ res_num = ws * uqs + us * wqs
171+ res_den = us * ws
172+ xx = x
173+ t = get_type (x ).Array (len (x ))
174+ t [- 1 ] = MIN_VALUE
175+ t .assign_vector (xx .get_vector (size = len (x ) - 1 ) + \
176+ xx .get_vector (size = len (x ) - 1 , base = 1 ))
177+ gg = g
178+ p = sint .Array (len (x ))
179+ p [- 1 ] = 1
180+ p .assign_vector (gg .get_vector (base = 1 , size = len (x ) - 1 ).bit_or (
181+ xx .get_vector (size = len (x ) - 1 ) == \
182+ xx .get_vector (size = len (x ) - 1 , base = 1 )))
183+ break_point ()
184+ res_num = p [:].if_else (MIN_VALUE , res_num )
185+ res_den = p [:].if_else (1 , res_den )
186+ t = p [:].if_else (MIN_VALUE , t [:])
187+ return res_num , res_den , t
188+
186189
187190MIN_VALUE = - 10000
188191
@@ -243,6 +246,11 @@ class TreeTrainer:
243246 .. _`Hamada et al.`: https://arxiv.org/abs/2112.12906
244247
245248 """
249+ def GetInversePermutation (self , perm ):
250+ res = Array .create_from (self .identity_permutation )
251+ reveal_sort (perm , res )
252+ return res
253+
246254 def ApplyTests (self , x , AID , Threshold ):
247255 m = len (x )
248256 n = len (AID )
@@ -260,96 +268,49 @@ def _(j):
260268 print_ln ('threshold %s' , util .reveal (Threshold ))
261269 return 2 * xx < Threshold
262270
263- def AttributeWiseTestSelection (self , g , x , y , time = False , debug = False ):
264- assert len (g ) == len (x )
265- assert len (g ) == len (y )
266- if time :
267- start_timer (2 )
268- s = ModifiedGini (g , y , debug = debug or self .debug > 2 )
269- if time :
270- stop_timer (2 )
271- if debug or self .debug > 1 :
272- print_ln ('gini %s' , s .reveal ())
273- xx = x
274- t = get_type (x ).Array (len (x ))
275- t [- 1 ] = MIN_VALUE
276- t .assign_vector (xx .get_vector (size = len (x ) - 1 ) + \
277- xx .get_vector (size = len (x ) - 1 , base = 1 ))
278- gg = g
279- p = sint .Array (len (x ))
280- p [- 1 ] = 1
281- p .assign_vector (gg .get_vector (base = 1 , size = len (x ) - 1 ).bit_or (
282- xx .get_vector (size = len (x ) - 1 ) == \
283- xx .get_vector (size = len (x ) - 1 , base = 1 )))
284- break_point ()
285- if debug :
286- print_ln ('attribute t=%s p=%s' , util .reveal (t ), util .reveal (p ))
287- s = p [:].if_else (MIN_VALUE , s )
288- t = p [:].if_else (MIN_VALUE , t [:])
289- if debug :
290- print_ln ('attribute s=%s t=%s' , util .reveal (s ), util .reveal (t ))
291- if time :
292- start_timer (3 )
293- s , t = GroupMax (gg , s , t )
294- if time :
295- stop_timer (3 )
296- if debug :
297- print_ln ('attribute s=%s t=%s' , util .reveal (s ), util .reveal (t ))
298- return t , s
299-
300- def GlobalTestSelection (self , x , y , g ):
301- assert len (y ) == len (g )
271+ def TestSelection (self , g , x , y , pis , notysum , ysum , time = False ):
302272 for xx in x :
303273 assert (len (xx ) == len (g ))
274+ assert len (g ) == len (y )
304275 m = len (x )
305276 n = len (y )
277+ gg = g
306278 u , t = [get_type (x ).Matrix (m , n ) for i in range (2 )]
307279 v = get_type (y ).Matrix (m , n )
308- s = sfix .Matrix (m , n )
280+ s_num = get_type (y ).Matrix (m , n )
281+ s_den = get_type (y ).Matrix (m , n )
282+ a = sint .Array (n )
283+
284+ notysum_arr = Array .create_from (notysum )
285+ ysum_arr = Array .create_from (ysum )
286+
309287 @for_range_multithread (self .n_threads , 1 , m )
310288 def _ (j ):
311289 single = not self .n_threads or self .n_threads == 1
312290 time = self .time and single
313- if debug :
291+ if self . debug_selection :
314292 print_ln ('run %s' , j )
315- @if_e (self .attr_lengths [j ])
316- def _ ():
317- u [j ][:], v [j ][:] = Sort ((PrefixSum (g ), x [j ]), x [j ], y ,
318- n_bits = [util .log2 (n ), 1 ], time = time )
319- @else_
320- def _ ():
321- u [j ][:], v [j ][:] = Sort ((PrefixSum (g ), x [j ]), x [j ], y ,
322- n_bits = [util .log2 (n ), None ],
323- time = time )
324- if self .debug_threading :
325- print_ln ('global sort %s %s %s' , j , util .reveal (u [j ]),
326- util .reveal (v [j ]))
327- t [j ][:], s [j ][:] = self .AttributeWiseTestSelection (
328- g , u [j ], v [j ], time = time , debug = self .debug_selection )
329- if self .debug_threading :
330- print_ln ('global attribute %s %s %s' , j , util .reveal (t [j ]),
331- util .reveal (s [j ]))
332- n = len (g )
333- a = sint .Array (n )
334- if self .debug_threading :
335- print_ln ('global s=%s' , util .reveal (s ))
336- if self .debug_gini :
337- print_ln ('Gini indices ' + ' ' .join (str (i ) + ':%s' for i in range (m )),
338- * (ss [0 ].reveal () for ss in s ))
339- if self .time :
340- start_timer (4 )
341- if self .debug > 1 :
342- print_ln ('s=%s' , s .reveal_nested ())
343- print_ln ('t=%s' , t .reveal_nested ())
344- a [:], tt = VectMax ((s [j ][:] for j in range (m )), range (m ),
345- (t [j ][:] for j in range (m )), debug = self .debug > 1 )
346- tt = Array .create_from (tt )
347- if self .time :
348- stop_timer (4 )
349- if self .debug > 1 :
350- print_ln ('a=%s' , util .reveal (a ))
351- print_ln ('tt=%s' , util .reveal (tt ))
352- return a [:], tt [:]
293+ u [j ].assign_vector (x [j ])
294+ v [j ].assign_vector (y )
295+ reveal_sort (pis [j ], u [j ])
296+ reveal_sort (pis [j ], v [j ])
297+ s_num [j ][:], s_den [j ][:], t [j ][:] = ComputeGini (g , u [j ], v [j ], notysum_arr , ysum_arr , debug = False )
298+
299+ ss_num , ss_den , tt , aa = VectMax ((s_num [j ][:] for j in range (m )), (s_den [j ][:] for j in range (m )), (t [j ][:] for j in range (m )), range (m ), debug = self .debug )
300+
301+ aaa = get_type (y ).Array (n )
302+ ttt = get_type (x ).Array (n )
303+
304+ GroupMax_num , GroupMax_den , GroupMax_ttt , GroupMax_aaa = GroupMax (g , ss_num , ss_den , tt , aa )
305+
306+ f = sint .Array (n )
307+ f = (self .zeros .get_vector () == notysum ).bit_or (self .zeros .get_vector () == ysum )
308+ aaa_vector , ttt_vector = f .if_else (0 , GroupMax_aaa ), f .if_else (MIN_VALUE , GroupMax_ttt )
309+
310+ ttt .assign_vector (ttt_vector )
311+ aaa .assign_vector (aaa_vector )
312+
313+ return aaa , ttt
353314
354315 def SetupPerm (self , g , x , y ):
355316 m = len (x )
@@ -368,6 +329,30 @@ def _():
368329 time = time ))
369330 return pis
370331
332+ def UpdateState (self , g , x , y , pis , NID , b , k ):
333+ m = len (x )
334+ n = len (y )
335+ q = SortPerm (b )
336+
337+ y [:] = q .apply (y )
338+ NID [:] = 2 ** k * b + NID
339+ NID [:] = q .apply (NID )
340+ g [:] = GroupFirstOne (g , b .bit_not ()) + GroupFirstOne (g , b )
341+ g [:] = q .apply (g )
342+
343+ b_arith = sint .Array (n )
344+ b_arith = Array .create_from (b )
345+
346+ @for_range_multithread (self .n_threads , 1 , m )
347+ def _ (j ):
348+ x [j ][:] = q .apply (x [j ])
349+ b_permuted = ApplyPermutation (pis [j ], b_arith )
350+
351+ pis [j ] = q .apply (pis [j ])
352+ pis [j ] = ApplyInversePermutation (pis [j ], SortPerm (b_permuted ).perm )
353+
354+ return [g , x , y , NID , pis ]
355+
371356 def TrainInternalNodes (self , k , x , y , g , NID ):
372357 assert len (g ) == len (y )
373358 for xx in x :
@@ -407,21 +392,11 @@ def train_layer(self, k):
407392 p = SortPerm (g .get_vector ().bit_not ())
408393
409394 self .nids [k ], self .aids [k ], self .thresholds [k ]= FormatLayer_without_crop (g [:], NID , a , t , debug = self .debug )
410-
411- if self .debug > 1 :
412- print_ln ('layer %s:' , k )
413- for name , data in zip (('NID' , 'AID' , 'Thr' ),
414- (self .nids [k ], self .aids [k ],
415- self .thresholds [k ])):
416- print_ln (' %s: %s' , name , data .reveal ())
417- NID [:] = 2 ** k * b + NID
418- b_not = b .bit_not ()
419- if self .debug > 1 :
420- print_ln ('b_not=%s' , b_not .reveal ())
421- g [:] = GroupFirstOne (g , b_not ) + GroupFirstOne (g , b )
422- y [:], g [:], NID [:], * xx = Sort ([b ], y , g , NID , * x , n_bits = [1 ])
423- for i , xxx in enumerate (xx ):
424- x [i ] = xxx
395+ self .g , self .x , self .y , self .NID , self .pis = self .UpdateState (g , x , y , pis , NID , b , k )
396+
397+ @if_ (k >= (len (self .nids )- 1 ))
398+ def _ ():
399+ self .label = Array .create_from (s0 < s1 )
425400
426401 def __init__ (self , x , y , h , binary = False , attr_lengths = None ,
427402 n_threads = None ):
0 commit comments