Skip to content

Commit 0a9d5e8

Browse files
committed
Unification of decision_tree and decision_tree_optimized in progress
1 parent 79407d8 commit 0a9d5e8

1 file changed

Lines changed: 87 additions & 112 deletions

File tree

Compiler/decision_tree.py

Lines changed: 87 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -157,32 +157,35 @@ def GroupMax(g, keys, *x):
157157
util.reveal(t), util.reveal(keys), util.reveal(x))
158158
return [GroupSum(g, t[:] * xx) for xx in [keys] + x]
159159

160-
def ModifiedGini(g, y, debug=False):
160+
def ComputeGini(g, x, y, notysum, ysum, debug=False):
    """Return Gini-score numerators, denominators and candidate thresholds.

    The score is kept as an un-divided fraction ``(res_num, res_den)``
    instead of being divided out.

    :param g: group indicator vector, same length as ``y``
        (presumably 1 marks the first row of a group -- confirm against
        the Group* helpers)
    :param x: attribute values; adjacent entries are added and compared
    :param y: label vector; its ``bit_not()`` is used as the other class
    :param notysum: per-row total that the ``bit_not`` prefix sums are
        subtracted from
    :param ysum: per-row total that the ``y`` prefix sums are
        subtracted from
    :param debug: accepted for interface compatibility; not used here
    :returns: ``(res_num, res_den, t)``; wherever the mask ``p`` is set
        the entries are ``MIN_VALUE``, ``1`` and ``MIN_VALUE``
    """
    assert len(g) == len(y)
    classes = [y.get_vector().bit_not(), y]
    prefix = [GroupPrefixSum(g, cls) for cls in classes]
    totals = [notysum, ysum]
    rest = [total - pre for total, pre in zip(totals, prefix)]
    prefix_count = sum(prefix)
    rest_count = sum(rest)
    prefix_squares = prefix[0] ** 2 + prefix[1] ** 2
    rest_squares = rest[0] ** 2 + rest[1] ** 2
    res_num = rest_count * prefix_squares + prefix_count * rest_squares
    res_den = prefix_count * rest_count

    last = len(x) - 1

    # t[i] = x[i] + x[i + 1]; the final slot has no successor.
    # NOTE(review): looks like twice the midpoint, matching the
    # ``2 * xx < Threshold`` test in ApplyTests -- confirm.
    t = get_type(x).Array(len(x))
    t[-1] = MIN_VALUE
    t.assign_vector(x.get_vector(size=last) + x.get_vector(size=last, base=1))

    # p[i] = g[i + 1] OR (x[i] == x[i + 1]); the final slot is always set.
    p = sint.Array(len(x))
    p[-1] = 1
    next_group = g.get_vector(base=1, size=last)
    same_value = x.get_vector(size=last) == x.get_vector(size=last, base=1)
    p.assign_vector(next_group.bit_or(same_value))
    break_point()

    # Masked rows get the sentinel fraction MIN_VALUE / 1 and threshold
    # MIN_VALUE.
    res_num = p[:].if_else(MIN_VALUE, res_num)
    res_den = p[:].if_else(1, res_den)
    t = p[:].if_else(MIN_VALUE, t[:])
    return res_num, res_den, t
188+
186189

187190
MIN_VALUE = -10000
188191

@@ -243,6 +246,11 @@ class TreeTrainer:
243246
.. _`Hamada et al.`: https://arxiv.org/abs/2112.12906
244247
245248
"""
249+
def GetInversePermutation(self, perm):
    """Return a copy of ``self.identity_permutation`` reordered by ``perm``.

    The copy is handed to ``reveal_sort`` together with ``perm`` and then
    returned.
    NOTE(review): presumably ``reveal_sort`` reorders its second argument
    in place so that the result is the inverse of ``perm`` -- confirm.
    """
    inverse = Array.create_from(self.identity_permutation)
    reveal_sort(perm, inverse)
    return inverse
253+
246254
def ApplyTests(self, x, AID, Threshold):
247255
m = len(x)
248256
n = len(AID)
@@ -260,96 +268,49 @@ def _(j):
260268
print_ln('threshold %s', util.reveal(Threshold))
261269
return 2 * xx < Threshold
262270

263-
def AttributeWiseTestSelection(self, g, x, y, time=False, debug=False):
264-
assert len(g) == len(x)
265-
assert len(g) == len(y)
266-
if time:
267-
start_timer(2)
268-
s = ModifiedGini(g, y, debug=debug or self.debug > 2)
269-
if time:
270-
stop_timer(2)
271-
if debug or self.debug > 1:
272-
print_ln('gini %s', s.reveal())
273-
xx = x
274-
t = get_type(x).Array(len(x))
275-
t[-1] = MIN_VALUE
276-
t.assign_vector(xx.get_vector(size=len(x) - 1) + \
277-
xx.get_vector(size=len(x) - 1, base=1))
278-
gg = g
279-
p = sint.Array(len(x))
280-
p[-1] = 1
281-
p.assign_vector(gg.get_vector(base=1, size=len(x) - 1).bit_or(
282-
xx.get_vector(size=len(x) - 1) == \
283-
xx.get_vector(size=len(x) - 1, base=1)))
284-
break_point()
285-
if debug:
286-
print_ln('attribute t=%s p=%s', util.reveal(t), util.reveal(p))
287-
s = p[:].if_else(MIN_VALUE, s)
288-
t = p[:].if_else(MIN_VALUE, t[:])
289-
if debug:
290-
print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t))
291-
if time:
292-
start_timer(3)
293-
s, t = GroupMax(gg, s, t)
294-
if time:
295-
stop_timer(3)
296-
if debug:
297-
print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t))
298-
return t, s
299-
300-
def GlobalTestSelection(self, x, y, g):
301-
assert len(y) == len(g)
271+
def TestSelection(self, g, x, y, pis, notysum, ysum, time=False):
    """Pick, for every row, the selected attribute index and threshold.

    Each attribute column is copied, reordered with its entry of ``pis``
    and scored by ``ComputeGini``; the per-attribute results are combined
    with ``VectMax`` and then ``GroupMax``.

    :param g: group indicator vector (same length as ``y``)
    :param x: attribute columns, each of length ``len(g)``
    :param y: label vector
    :param pis: one permutation per attribute, applied with
        ``reveal_sort`` (presumably the per-attribute sorting
        permutations from SetupPerm -- confirm)
    :param notysum: per-row total forwarded to ``ComputeGini``
    :param ysum: per-row total forwarded to ``ComputeGini``
    :param time: accepted for interface compatibility; not used here
    :returns: ``(aaa, ttt)`` arrays of attribute indices and thresholds;
        rows where ``notysum`` or ``ysum`` equals ``self.zeros`` get
        attribute 0 and threshold ``MIN_VALUE``
    """
    for xx in x:
        assert(len(xx) == len(g))
    assert len(g) == len(y)
    m = len(x)
    n = len(y)
    u, t = [get_type(x).Matrix(m, n) for i in range(2)]
    v = get_type(y).Matrix(m, n)
    s_num = get_type(y).Matrix(m, n)
    s_den = get_type(y).Matrix(m, n)

    notysum_arr = Array.create_from(notysum)
    ysum_arr = Array.create_from(ysum)

    @for_range_multithread(self.n_threads, 1, m)
    def _(j):
        if self.debug_selection:
            print_ln('run %s', j)
        # Work on copies so that x and y keep their current order.
        u[j].assign_vector(x[j])
        v[j].assign_vector(y)
        reveal_sort(pis[j], u[j])
        reveal_sort(pis[j], v[j])
        s_num[j][:], s_den[j][:], t[j][:] = ComputeGini(
            g, u[j], v[j], notysum_arr, ysum_arr, debug=False)

    ss_num, ss_den, tt, aa = VectMax(
        (s_num[j][:] for j in range(m)),
        (s_den[j][:] for j in range(m)),
        (t[j][:] for j in range(m)),
        range(m), debug=self.debug)

    aaa = get_type(y).Array(n)
    ttt = get_type(x).Array(n)

    # Only the threshold and attribute outputs are used below.
    _best_num, _best_den, best_t, best_a = GroupMax(
        g, ss_num, ss_den, tt, aa)

    # Rows whose notysum or ysum equals self.zeros get the fallback
    # values (attribute 0, threshold MIN_VALUE).
    f = (self.zeros.get_vector() == notysum).bit_or(
        self.zeros.get_vector() == ysum)
    aaa_vector, ttt_vector = f.if_else(0, best_a), \
        f.if_else(MIN_VALUE, best_t)

    ttt.assign_vector(ttt_vector)
    aaa.assign_vector(aaa_vector)

    return aaa, ttt
353314

354315
def SetupPerm(self, g, x, y):
355316
m = len(x)
@@ -368,6 +329,30 @@ def _():
368329
time=time))
369330
return pis
370331

332+
def UpdateState(self, g, x, y, pis, NID, b, k):
    """Reorder the training state by ``b`` and update the permutations.

    ``y``, ``NID``, ``g`` and every column of ``x`` are updated in place
    through ``SortPerm(b).apply``; each entry of ``pis`` is rebuilt.

    :param g: group indicator vector, rewritten in place
    :param x: attribute columns, rewritten in place
    :param y: label vector, rewritten in place
    :param pis: one permutation per attribute, rewritten in place
    :param NID: node identifiers; ``2 ** k * b`` is added before
        reordering
    :param b: per-row bit vector (presumably the outcome of the test
        chosen for this layer -- confirm against the caller)
    :param k: exponent used in the ``NID`` update (the caller passes
        the layer index)
    :returns: ``[g, x, y, NID, pis]``
    """
    m = len(x)
    q = SortPerm(b)

    y[:] = q.apply(y)
    NID[:] = 2 ** k * b + NID
    NID[:] = q.apply(NID)
    # The new group markers are computed from g and b before reordering.
    g[:] = GroupFirstOne(g, b.bit_not()) + GroupFirstOne(g, b)
    g[:] = q.apply(g)

    b_arith = Array.create_from(b)

    @for_range_multithread(self.n_threads, 1, m)
    def _(j):
        x[j][:] = q.apply(x[j])
        # Uses the old pis[j], so it must run before pis[j] is rewritten.
        b_permuted = ApplyPermutation(pis[j], b_arith)

        pis[j] = q.apply(pis[j])
        pis[j] = ApplyInversePermutation(pis[j], SortPerm(b_permuted).perm)

    return [g, x, y, NID, pis]
355+
371356
def TrainInternalNodes(self, k, x, y, g, NID):
372357
assert len(g) == len(y)
373358
for xx in x:
@@ -407,21 +392,11 @@ def train_layer(self, k):
407392
p = SortPerm(g.get_vector().bit_not())
408393

409394
self.nids[k], self.aids[k], self.thresholds[k]= FormatLayer_without_crop(g[:], NID, a, t, debug=self.debug)
410-
411-
if self.debug > 1:
412-
print_ln('layer %s:', k)
413-
for name, data in zip(('NID', 'AID', 'Thr'),
414-
(self.nids[k], self.aids[k],
415-
self.thresholds[k])):
416-
print_ln(' %s: %s', name, data.reveal())
417-
NID[:] = 2 ** k * b + NID
418-
b_not = b.bit_not()
419-
if self.debug > 1:
420-
print_ln('b_not=%s', b_not.reveal())
421-
g[:] = GroupFirstOne(g, b_not) + GroupFirstOne(g, b)
422-
y[:], g[:], NID[:], *xx = Sort([b], y, g, NID, *x, n_bits=[1])
423-
for i, xxx in enumerate(xx):
424-
x[i] = xxx
395+
self.g, self.x, self.y, self.NID, self.pis = self.UpdateState(g, x, y, pis, NID, b, k)
396+
397+
@if_(k >= (len(self.nids)-1))
398+
def _():
399+
self.label = Array.create_from(s0 < s1)
425400

426401
def __init__(self, x, y, h, binary=False, attr_lengths=None,
427402
n_threads=None):

0 commit comments

Comments
 (0)