Skip to content

Commit 1788401

Browse files
committed
Changed decision_tree_new to decision_tree_optimized
1 parent f7ed22e commit 1788401

3 files changed

Lines changed: 123 additions & 2 deletions

File tree

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,28 @@ def _(i):
425425
def _(k):
426426
self.train_layer(k)
427427
return self.get_tree(h, self.label)
428+
429+
def train_with_testing(self, *test_set, output=False):
430+
""" Train decision tree and test against test data.
431+
432+
:param y: binary labels (list or sint vector)
433+
:param x: sample data (by attribute, list or
434+
:py:obj:`~Compiler.types.Matrix`)
435+
:param output: output tree after every level
436+
:returns: tree
437+
438+
"""
439+
for k in range(len(self.nids)):
440+
self.train_layer(k)
441+
tree = self.get_tree(k + 1, self.label)
442+
if output:
443+
output_decision_tree(tree)
444+
test_decision_tree('train', tree, self.y, self.x,
445+
n_threads=self.n_threads)
446+
if test_set:
447+
test_decision_tree('test', tree, *test_set,
448+
n_threads=self.n_threads)
449+
return tree
428450

429451
def get_tree(self, h, Label):
430452
Layer = [None] * (h + 1)
@@ -449,6 +471,69 @@ def output_decision_tree(layers):
449471
for j, x in enumerate(('NID', 'result')):
450472
print_ln(' %s: %s', x, util.reveal(layers[-1][j]))
451473

474+
def pick(bits, x):
475+
if len(bits) == 1:
476+
return bits[0] * x[0]
477+
else:
478+
try:
479+
return x[0].dot_product(bits, x)
480+
except:
481+
return sum(aa * bb for aa, bb in zip(bits, x))
482+
483+
def run_decision_tree(layers, data):
484+
""" Run decision tree against sample data.
485+
486+
:param layers: tree output by :py:class:`TreeTrainer`
487+
:param data: sample data (:py:class:`~Compiler.types.Array`)
488+
:returns: binary label
489+
490+
"""
491+
h = len(layers) - 1
492+
index = 1
493+
for k, layer in enumerate(layers[:-1]):
494+
assert len(layer) == 3
495+
for x in layer:
496+
assert len(x) <= 2 ** k
497+
bits = layer[0].equal(index, k)
498+
threshold = pick(bits, layer[2])
499+
key_index = pick(bits, layer[1])
500+
if key_index.is_clear:
501+
key = data[key_index]
502+
else:
503+
key = pick(
504+
oram.demux(key_index.bit_decompose(util.log2(len(data)))), data)
505+
child = 2 * key < threshold
506+
index += child * 2 ** k
507+
bits = layers[h][0].equal(index, h)
508+
return pick(bits, layers[h][1])
509+
510+
def test_decision_tree(name, layers, y, x, n_threads=None, time=False):
511+
if time:
512+
start_timer(100)
513+
n = len(y)
514+
x = x.transpose().reveal()
515+
y = y.reveal()
516+
guess = regint.Array(n)
517+
truth = regint.Array(n)
518+
correct = regint.Array(2)
519+
parts = regint.Array(2)
520+
layers = [[Array.create_from(util.reveal(x)) for x in layer]
521+
for layer in layers]
522+
@for_range_multithread(n_threads, 1, n)
523+
def _(i):
524+
guess[i] = run_decision_tree([[part[:] for part in layer]
525+
for layer in layers], x[i]).reveal()
526+
truth[i] = y[i].reveal()
527+
@for_range(n)
528+
def _(i):
529+
parts[truth[i]] += 1
530+
c = (guess[i].bit_xor(truth[i]).bit_not())
531+
correct[truth[i]] += c
532+
print_ln('%s for height %s: %s/%s (%s/%s, %s/%s)', name, len(layers) - 1,
533+
sum(correct), n, correct[0], parts[0], correct[1], parts[1])
534+
if time:
535+
stop_timer(100)
536+
452537
class TreeClassifier:
453538
""" Tree classification that uses
454539
:py:class:`TreeTrainer` internally.
@@ -482,3 +567,39 @@ def fit(self, X, y, attr_types=None):
482567

483568
def output(self):
484569
output_decision_tree(self.tree)
570+
571+
def fit_with_testing(self, X_train, y_train, X_test, y_test,
572+
attr_types=None, output_trees=False, debug=False):
573+
""" Train tree with accuracy output after every level.
574+
575+
:param X_train: training data with row-wise samples (sint/sfix matrix)
576+
:param y_train: training binary labels (sint list/array)
577+
:param X_test: testing data with row-wise samples (sint/sfix matrix)
578+
:param y_test: testing binary labels (sint list/array)
579+
:param attr_types: attributes types (list of 'b'/'c' for
580+
binary/continuous; default is all continuous)
581+
:param output_trees: output tree after every level
582+
:param debug: output debugging information
583+
584+
"""
585+
trainer = TreeTrainer(X_train.transpose(), y_train, self.max_depth,
586+
attr_lengths=self.get_attr_lengths(attr_types),
587+
n_threads=self.n_threads)
588+
trainer.debug = debug
589+
trainer.debug_gini = debug
590+
trainer.debug_threading = debug > 1
591+
self.tree = trainer.train_with_testing(y_test, X_test.transpose(),
592+
output=output_trees)
593+
594+
def predict(self, X):
595+
""" Use tree for prediction.
596+
597+
:param X: sample data with row-wise samples (sint/sfix matrix)
598+
:returns: sint array
599+
600+
"""
601+
res = sint.Array(len(X))
602+
@for_range(len(X))
603+
def _(i):
604+
res[i] = run_decision_tree(self.tree, X[i])
605+
return res

Programs/Source/breast_tree.mpc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ y_test = sint.input_tensor_via(0, y_test)
1616

1717
sfix.set_precision_from_args(program)
1818

19-
from Compiler.decision_tree import TreeClassifier
19+
from Compiler.decision_tree_optimized import TreeClassifier
2020

2121
tree = TreeClassifier(max_depth=5, n_threads=2)
2222

Programs/Source/custom_data_dt.mpc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ df_y = Array.create_from(df_y[0])
1919
program.set_bit_length(32)
2020
sfix.set_precision(16, 31)
2121

22-
from Compiler.decision_tree_new import TreeClassifier
22+
from Compiler.decision_tree_optimized import TreeClassifier
2323

2424
tree = TreeClassifier(max_depth=int(program.args[3]), n_threads=4)
2525

0 commit comments

Comments
 (0)