Skip to content

Commit 79407d8

Browse files
committed
Modifying decision_tree.py
1 parent b6499cf commit 79407d8

2 files changed

Lines changed: 42 additions & 7 deletions

File tree

Compiler/decision_tree.py

Lines changed: 41 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -211,7 +211,6 @@ def CropLayer(k, *v):
211211
def TrainLeafNodes(h, g, y, NID):
212212
assert len(g) == len(y)
213213
assert len(g) == len(NID)
214-
Label = GroupSum(g, y.bit_not()) < GroupSum(g, y)
215214
return FormatLayer(h, g, NID, Label)
216215

217216
def GroupSame(g, y):
@@ -352,6 +351,23 @@ def _():
352351
print_ln('tt=%s', util.reveal(tt))
353352
return a[:], tt[:]
354353

354+
def SetupPerm(self, g, x, y):
355+
m = len(x)
356+
n = len(y)
357+
pis = get_type(y).Matrix(m, n)
358+
@for_range_multithread(self.n_threads, 1, m)
359+
def _(j):
360+
@if_e(self.attr_lengths[j])
361+
def _():
362+
pis[j][:] = self.GetInversePermutation(GetSortPerm([x[j]], x[j], y,
363+
n_bits=[1], time=time))
364+
@else_
365+
def _():
366+
pis[j][:] = self.GetInversePermutation(GetSortPerm([x[j]], x[j], y,
367+
n_bits=[None],
368+
time=time))
369+
return pis
370+
355371
def TrainInternalNodes(self, k, x, y, g, NID):
356372
assert len(g) == len(y)
357373
for xx in x:
@@ -377,12 +393,21 @@ def train_layer(self, k):
377393
y = self.y
378394
g = self.g
379395
NID = self.NID
396+
pis = self.pis
380397
if self.debug > 1:
381398
print_ln('g=%s', g.reveal())
382399
print_ln('y=%s', y.reveal())
383400
print_ln('x=%s', x.reveal_nested())
384-
self.nids[k], self.aids[k], self.thresholds[k], b = \
385-
self.TrainInternalNodes(k, x, y, g, NID)
401+
402+
s0 = GroupSum(g, y.get_vector().bit_not())
403+
s1 = GroupSum(g, y.get_vector())
404+
405+
a, t = self.TestSelection(g, x, y, pis, s0, s1)
406+
b = self.ApplyTests(x, a, t)
407+
p = SortPerm(g.get_vector().bit_not())
408+
409+
self.nids[k], self.aids[k], self.thresholds[k]= FormatLayer_without_crop(g[:], NID, a, t, debug=self.debug)
410+
386411
if self.debug > 1:
387412
print_ln('layer %s:', k)
388413
for name, data in zip(('NID', 'AID', 'Thr'),
@@ -422,6 +447,8 @@ def __init__(self, x, y, h, binary=False, attr_lengths=None,
422447
self.x = Matrix.create_from(x)
423448
self.nids, self.aids = [sint.Matrix(h, n) for i in range(2)]
424449
self.thresholds = self.x.value_type.Matrix(h, n)
450+
self.identity_permutation = sint.Array(n)
451+
self.label = sintbit.Array(n)
425452
self.n_threads = n_threads
426453
self.debug_selection = False
427454
self.debug_threading = False
@@ -431,11 +458,19 @@ def __init__(self, x, y, h, binary=False, attr_lengths=None,
431458

432459
def train(self):
433460
""" Train and return decision tree. """
461+
n = len(self.y)
462+
463+
@for_range(n)
464+
def _(i):
465+
self.identity_permutation[i] = sint(i)
466+
434467
h = len(self.nids)
468+
self.pis = self.SetupPerm(self.g, self.x, self.y)
469+
435470
@for_range(h)
436471
def _(k):
437472
self.train_layer(k)
438-
return self.get_tree(h)
473+
return self.get_tree(h, self.label)
439474

440475
def train_with_testing(self, *test_set, output=False):
441476
""" Train decision tree and test against test data.
@@ -459,12 +494,12 @@ def train_with_testing(self, *test_set, output=False):
459494
n_threads=self.n_threads)
460495
return tree
461496

462-
def get_tree(self, h):
497+
def get_tree(self, h, Label):
463498
Layer = [None] * (h + 1)
464499
for k in range(h):
465500
Layer[k] = CropLayer(k, self.nids[k], self.aids[k],
466501
self.thresholds[k])
467-
Layer[h] = TrainLeafNodes(h, self.g[:], self.y[:], self.NID)
502+
Layer[h] = TrainLeafNodes(h, self.g[:], self.y[:], self.NID, Label)
468503
return Layer
469504

470505
def DecisionTreeTraining(x, y, h, binary=False):

Programs/Source/custom_data_dt.mpc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ df_y = Array.create_from(df_y[0])
1919
program.set_bit_length(32)
2020
sfix.set_precision(16, 31)
2121

22-
from Compiler.decision_tree_optimized import TreeClassifier
22+
from Compiler.decision_tree import TreeClassifier
2323

2424
tree = TreeClassifier(max_depth=int(program.args[3]), n_threads=4)
2525

0 commit comments

Comments
 (0)