@@ -425,6 +425,28 @@ def _(i):
425425 def _ (k ):
426426 self .train_layer (k )
427427 return self .get_tree (h , self .label )
428+
def train_with_testing(self, *test_set, output=False):
    """ Train decision tree and test against test data.

    After every level, the tree trained so far is tested against the
    training data and, if given, against the test data.

    :param test_set: optional test data, passed on to
      :py:func:`test_decision_tree` after the tree, that is, labels
      followed by sample data (by attribute)
    :param output: output tree after every level
    :returns: tree after the last level

    """
    # NOTE: tree is only bound inside the loop, so at least one level
    # has to be trained for the return statement to succeed
    for k in range(len(self.nids)):
        self.train_layer(k)
        tree = self.get_tree(k + 1, self.label)
        if output:
            output_decision_tree(tree)
        test_decision_tree('train', tree, self.y, self.x,
                           n_threads=self.n_threads)
        if test_set:
            test_decision_tree('test', tree, *test_set,
                               n_threads=self.n_threads)
    return tree
428450
429451 def get_tree (self , h , Label ):
430452 Layer = [None ] * (h + 1 )
@@ -449,6 +471,69 @@ def output_decision_tree(layers):
449471 for j , x in enumerate (('NID' , 'result' )):
450472 print_ln (' %s: %s' , x , util .reveal (layers [- 1 ][j ]))
451473
def pick(bits, x):
    """ Pick from values according to selection bits.

    The result is the sum of the element-wise products of ``bits`` and
    ``x``. With a one-hot ``bits``, this is the entry of ``x`` at the
    position of the set bit.

    :param bits: selection bits (list or vector)
    :param x: values to pick from (list or vector)
    :returns: sum of ``bits[i] * x[i]``

    """
    if len(bits) == 1:
        return bits[0] * x[0]
    try:
        # use the dot product of the element type if there is one
        return x[0].dot_product(bits, x)
    except Exception:
        # Generic fallback, for example for element types without
        # dot_product. This does not use a bare except in order not to
        # swallow KeyboardInterrupt or SystemExit.
        return sum(aa * bb for aa, bb in zip(bits, x))
482+
def run_decision_tree(layers, data):
    """ Run decision tree against sample data.

    The tree is walked from the top, selecting one node per level,
    and the result stored for the node reached in the last level
    is returned.

    :param layers: tree output by :py:class:`TreeTrainer`
    :param data: sample data (:py:class:`~Compiler.types.Array`)
    :returns: binary label

    """
    height = len(layers) - 1
    node = 1
    for depth, level in enumerate(layers[:-1]):
        # every inner level holds three vectors (node IDs, attribute
        # indices, thresholds) of at most 2^depth entries each
        assert len(level) == 3
        for part in level:
            assert len(part) <= 2 ** depth
        # indicator of the current node among the node IDs of the level
        here = level[0].equal(node, depth)
        threshold = pick(here, level[2])
        attribute = pick(here, level[1])
        if attribute.is_clear:
            value = data[attribute]
        else:
            # non-clear index: expand it to a one-hot vector and
            # pick the value without direct indexing
            n_bits = util.log2(len(data))
            one_hot = oram.demux(attribute.bit_decompose(n_bits))
            value = pick(one_hot, data)
        # twice the value is compared to the threshold
        # NOTE(review): presumably thresholds are stored doubled by
        # the trainer -- confirm
        branch = 2 * value < threshold
        node += branch * 2 ** depth
    last = layers[height]
    return pick(last[0].equal(node, height), last[1])
509+
def test_decision_tree(name, layers, y, x, n_threads=None, time=False):
    """ Run decision tree on a data set and output the accuracy.

    Labels, sample data, and the tree are all revealed. The output
    line contains the number of correct predictions in total and
    separately for labels 0 and 1.

    :param name: name of the data set used in the output
    :param layers: tree as accepted by :py:func:`run_decision_tree`
    :param y: binary labels
    :param x: sample data (by attribute)
    :param n_threads: number of threads used for running the tree
    :param time: run timer 100 for the duration of the test

    """
    if time:
        start_timer(100)
    n = len(y)
    # sample-wise access requires transposing the by-attribute data
    samples = x.transpose().reveal()
    labels = y.reveal()
    guess = regint.Array(n)
    truth = regint.Array(n)
    # counters indexed by the true label
    correct = regint.Array(2)
    parts = regint.Array(2)
    clear_layers = [[Array.create_from(util.reveal(part)) for part in layer]
                    for layer in layers]

    @for_range_multithread(n_threads, 1, n)
    def _(i):
        tree = [[part[:] for part in layer] for layer in clear_layers]
        guess[i] = run_decision_tree(tree, samples[i]).reveal()
        truth[i] = labels[i].reveal()

    @for_range(n)
    def _(i):
        parts[truth[i]] += 1
        # one if prediction and truth agree, zero otherwise
        hit = guess[i].bit_xor(truth[i]).bit_not()
        correct[truth[i]] += hit

    print_ln('%s for height %s: %s/%s (%s/%s, %s/%s)', name, len(layers) - 1,
             sum(correct), n, correct[0], parts[0], correct[1], parts[1])
    if time:
        stop_timer(100)
536+
452537class TreeClassifier :
453538 """ Tree classification that uses
454539 :py:class:`TreeTrainer` internally.
@@ -482,3 +567,39 @@ def fit(self, X, y, attr_types=None):
482567
def output(self):
    """ Output the tree held in ``self.tree`` using
    :py:func:`output_decision_tree`. """
    tree = self.tree
    output_decision_tree(tree)
570+
def fit_with_testing(self, X_train, y_train, X_test, y_test,
                     attr_types=None, output_trees=False, debug=False):
    """ Train tree with accuracy output after every level.

    The resulting tree is stored in ``self.tree``.

    :param X_train: training data with row-wise samples (sint/sfix matrix)
    :param y_train: training binary labels (sint list/array)
    :param X_test: testing data with row-wise samples (sint/sfix matrix)
    :param y_test: testing binary labels (sint list/array)
    :param attr_types: attributes types (list of 'b'/'c' for
      binary/continuous; default is all continuous)
    :param output_trees: output tree after every level
    :param debug: output debugging information

    """
    # the trainer takes the data by attribute, not by sample
    train_by_attribute = X_train.transpose()
    attr_lengths = self.get_attr_lengths(attr_types)
    trainer = TreeTrainer(train_by_attribute, y_train, self.max_depth,
                          attr_lengths=attr_lengths,
                          n_threads=self.n_threads)
    trainer.debug = trainer.debug_gini = debug
    # threading debugging is only activated for values above one
    trainer.debug_threading = debug > 1
    test_by_attribute = X_test.transpose()
    self.tree = trainer.train_with_testing(y_test, test_by_attribute,
                                           output=output_trees)
593+
def predict(self, X):
    """ Use tree for prediction.

    Runs the tree in ``self.tree`` on every sample.

    :param X: sample data with row-wise samples (sint/sfix matrix)
    :returns: sint array

    """
    n_samples = len(X)
    labels = sint.Array(n_samples)

    @for_range(n_samples)
    def _(i):
        labels[i] = run_decision_tree(self.tree, X[i])

    return labels
0 commit comments