@@ -211,7 +211,6 @@ def CropLayer(k, *v):
211211def TrainLeafNodes (h , g , y , NID ):
212212 assert len (g ) == len (y )
213213 assert len (g ) == len (NID )
214- Label = GroupSum (g , y .bit_not ()) < GroupSum (g , y )
215214 return FormatLayer (h , g , NID , Label )
216215
217216def GroupSame (g , y ):
@@ -352,6 +351,23 @@ def _():
352351 print_ln ('tt=%s' , util .reveal (tt ))
353352 return a [:], tt [:]
354353
354+ def SetupPerm (self , g , x , y ):
355+ m = len (x )
356+ n = len (y )
357+ pis = get_type (y ).Matrix (m , n )
358+ @for_range_multithread (self .n_threads , 1 , m )
359+ def _ (j ):
360+ @if_e (self .attr_lengths [j ])
361+ def _ ():
362+ pis [j ][:] = self .GetInversePermutation (GetSortPerm ([x [j ]], x [j ], y ,
363+ n_bits = [1 ], time = time ))
364+ @else_
365+ def _ ():
366+ pis [j ][:] = self .GetInversePermutation (GetSortPerm ([x [j ]], x [j ], y ,
367+ n_bits = [None ],
368+ time = time ))
369+ return pis
370+
355371 def TrainInternalNodes (self , k , x , y , g , NID ):
356372 assert len (g ) == len (y )
357373 for xx in x :
@@ -377,12 +393,21 @@ def train_layer(self, k):
377393 y = self .y
378394 g = self .g
379395 NID = self .NID
396+ pis = self .pis
380397 if self .debug > 1 :
381398 print_ln ('g=%s' , g .reveal ())
382399 print_ln ('y=%s' , y .reveal ())
383400 print_ln ('x=%s' , x .reveal_nested ())
384- self .nids [k ], self .aids [k ], self .thresholds [k ], b = \
385- self .TrainInternalNodes (k , x , y , g , NID )
401+
402+ s0 = GroupSum (g , y .get_vector ().bit_not ())
403+ s1 = GroupSum (g , y .get_vector ())
404+
405+ a , t = self .TestSelection (g , x , y , pis , s0 , s1 )
406+ b = self .ApplyTests (x , a , t )
407+ p = SortPerm (g .get_vector ().bit_not ())
408+
409+ self .nids [k ], self .aids [k ], self .thresholds [k ]= FormatLayer_without_crop (g [:], NID , a , t , debug = self .debug )
410+
386411 if self .debug > 1 :
387412 print_ln ('layer %s:' , k )
388413 for name , data in zip (('NID' , 'AID' , 'Thr' ),
@@ -422,6 +447,8 @@ def __init__(self, x, y, h, binary=False, attr_lengths=None,
422447 self .x = Matrix .create_from (x )
423448 self .nids , self .aids = [sint .Matrix (h , n ) for i in range (2 )]
424449 self .thresholds = self .x .value_type .Matrix (h , n )
450+ self .identity_permutation = sint .Array (n )
451+ self .label = sintbit .Array (n )
425452 self .n_threads = n_threads
426453 self .debug_selection = False
427454 self .debug_threading = False
@@ -431,11 +458,19 @@ def __init__(self, x, y, h, binary=False, attr_lengths=None,
431458
432459 def train (self ):
433460 """ Train and return decision tree. """
461+ n = len (self .y )
462+
463+ @for_range (n )
464+ def _ (i ):
465+ self .identity_permutation [i ] = sint (i )
466+
434467 h = len (self .nids )
468+ self .pis = self .SetupPerm (self .g , self .x , self .y )
469+
435470 @for_range (h )
436471 def _ (k ):
437472 self .train_layer (k )
438- return self .get_tree (h )
473+ return self .get_tree (h , self . label )
439474
440475 def train_with_testing (self , * test_set , output = False ):
441476 """ Train decision tree and test against test data.
@@ -459,12 +494,12 @@ def train_with_testing(self, *test_set, output=False):
459494 n_threads = self .n_threads )
460495 return tree
461496
462- def get_tree (self , h ):
497+ def get_tree (self , h , Label ):
463498 Layer = [None ] * (h + 1 )
464499 for k in range (h ):
465500 Layer [k ] = CropLayer (k , self .nids [k ], self .aids [k ],
466501 self .thresholds [k ])
467- Layer [h ] = TrainLeafNodes (h , self .g [:], self .y [:], self .NID )
502+ Layer [h ] = TrainLeafNodes (h , self .g [:], self .y [:], self .NID , Label )
468503 return Layer
469504
470505def DecisionTreeTraining (x , y , h , binary = False ):
0 commit comments