-
Notifications
You must be signed in to change notification settings - Fork 356
Expand file tree
/
Copy pathprogram.py
More file actions
1778 lines (1572 loc) · 64.4 KB
/
program.py
File metadata and controls
1778 lines (1572 loc) · 64.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
This module contains the building blocks of the compiler such as code
blocks and registers. Most relevant is the central :py:class:`Program`
object that holds various properties of the computation.
"""
import inspect
import itertools
import math
import os
import re
import sys
import hashlib
import random
from collections import defaultdict, deque
from functools import reduce
import Compiler.instructions
import Compiler.instructions_base
import Compiler.instructions_base as inst_base
from Compiler.config import REG_MAX, USER_MEM, COST, MEM_MAX
from Compiler.exceptions import CompilerError
from Compiler.instructions_base import RegType
from . import allocator as al
from . import util
# Numeric codes identifying the kinds of preprocessing data.
data_types = {
    "triple": 0,
    "square": 1,
    "bit": 2,
    "inverse": 3,
    "dabit": 4,
    "mixed": 5,
    "random": 6,
    "open": 7,
}

# Numeric codes identifying the computation domains.
field_types = {
    "modp": 0,
    "gf2n": 1,
    "bit": 2,
}
class defaults:
    """Fallback compiler options.

    :py:class:`Program` uses this class when it is constructed without an
    explicit options object. Each attribute is one of the option names
    that this module reads as ``options.<name>``.
    """
    debug = False
    verbose = False
    outfile = None
    ring = 0
    field = 0
    binary = 0
    garbled = False
    prime = None
    galois = 40
    budget = 1000
    mixed = False
    edabit = False
    invperm = False
    split = None
    cisc = True
    comparison = None
    merge_opens = True
    preserve_mem_order = False
    max_parallel_open = 0
    dead_code_elimination = False
    noreallocate = False
    asmoutfile = None
    stop = False
    insecure = False
    # None means "not given".
    # BasicBlock.expand_cisc tests "is not None" and otherwise calls .split(),
    # so False would crash.
    keep_cisc = None
    comparison_rabbit = False
    # Program.__init__ reads these attributes in some configurations:
    # - execute, when a prime is set;
    # - optimize_hard, when budget or cisc is falsy.
    # They are defined so that those paths do not raise AttributeError.
    execute = False
    optimize_hard = False
class Program(object):
    """A program consists of a list of tapes representing the whole
    computation.

    When compiling an :file:`.mpc` file, the single instance is
    available as :py:obj:`program`. When compiling directly
    from Python code, an instance has to be created before running any
    instructions.
    """

    def __init__(self, args, options=defaults, name=None):
        """Set up the compilation state from the compiler options.

        :param args: command-line arguments; ``args[0]`` names the
            program source unless *name* is given
        :param options: compiler options object (see :py:class:`defaults`)
        :param name: program name; derived from ``args[0]`` if
            :py:obj:`None`
        :raises CompilerError: on conflicting domain options or if the
            requested integer bit length is too large for the prime
        """
        from .non_linear import KnownPrime, Prime

        self.options = options
        self.verbose = options.verbose
        self.args = args
        self.name = name
        self.init_names(args)
        self._security = 40
        self.used_security = 0
        self.prime = None
        self.tapes = []
        if sum(x != 0 for x in (options.ring, options.field, options.binary)) > 1:
            raise CompilerError("can only use one out of -B, -R, -F")
        if options.prime and (options.ring or options.binary):
            raise CompilerError("can only use one out of -B, -R, -p")
        if options.ring:
            self.set_ring_size(int(options.ring))
        else:
            self.bit_length = int(options.binary) or int(options.field)
            if options.prime:
                self.prime = int(options.prime)
                print("WARNING: --prime/-P activates code that usually isn't "
                      "the most efficient variant. Consider using --field/-F "
                      "and set the prime only during the actual computation.")
                if not self.rabbit_gap() and self.prime > 2 ** 50:
                    print("The chosen prime is particularly inefficient. "
                          "Consider using a prime that is closer to a power "
                          "of two", end='')
                    try:
                        import gmpy2
                        bad_prime = self.prime
                        # Search from the nearest power of two (+1) towards
                        # the chosen prime for a prime that is close to it.
                        self.prime = 2 ** int(
                            round(math.log(self.prime, 2))) + 1
                        while True:
                            if self.prime > 2 ** 59:
                                # LWE compatibility
                                step = 2 ** 15
                            else:
                                step = 1
                            if self.prime < bad_prime:
                                self.prime += step
                            else:
                                self.prime -= step
                            if gmpy2.is_prime(self.prime):
                                break
                        assert self.rabbit_gap()
                        print(", for example, %d." % self.prime)
                        # Only a suggestion: keep compiling with the user's prime.
                        self.prime = bad_prime
                    except ImportError:
                        print(".")
                if options.execute:
                    print("Use '-- --prime <prime>' to specify the prime for "
                          "execution only.")
                max_bit_length = int(options.prime).bit_length() - 2
                if self.bit_length > max_bit_length:
                    raise CompilerError(
                        "integer bit length can be maximal %s" % max_bit_length
                    )
                self.bit_length = self.bit_length or max_bit_length
                self.non_linear = KnownPrime(self.prime)
            else:
                self.non_linear = Prime()
                if not self.bit_length:
                    self.bit_length = 64
        print("Default bit length for compilation:", self.bit_length)
        if not (options.binary or options.garbled):
            print("Default security parameter for compilation:", self._security)
        self.galois_length = int(options.galois)
        if self.verbose:
            print("Galois length:", self.galois_length)
        self.tape_counter = 0
        self._curr_tape = None
        self.DEBUG = options.debug
        self.allocated_mem = RegType.create_dict(lambda: USER_MEM)
        self.free_mem_blocks = defaultdict(al.BlockAllocator)
        self.later_mem_blocks = defaultdict(list)
        self.allocated_mem_blocks = {}
        self.saved = 0
        self.req_num = None
        self.tape_stack = []
        self.n_threads = 1
        self.public_input_file = None
        self.types = {}
        if self.options.budget:
            self.budget = int(self.options.budget)
        elif self.options.optimize_hard:
            self.budget = 100000
        else:
            self.budget = defaults.budget
        self.to_merge = [
            Compiler.instructions.asm_open_class,
            Compiler.instructions.gasm_open_class,
            Compiler.instructions.muls_class,
            Compiler.instructions.gmuls_class,
            Compiler.instructions.mulrs_class,
            Compiler.instructions.gmulrs,
            Compiler.instructions.dotprods_class,
            Compiler.instructions.gdotprods_class,
            Compiler.instructions.asm_input_class,
            Compiler.instructions.gasm_input_class,
            Compiler.instructions.inputfix_class,
            Compiler.instructions.inputfloat_class,
            Compiler.instructions.inputmixed_class,
            Compiler.instructions.trunc_pr_class,
            Compiler.instructions_base.Mergeable,
        ]
        import Compiler.GC.instructions as gc

        self.to_merge += [
            gc.ldmsdi,
            gc.stmsdi,
            gc.ldmsd,
            gc.stmsd,
            gc.stmsdci,
            gc.andrs,
            gc.ands,
            gc.inputb,
            gc.inputbvec,
            gc.reveal,
        ]
        # Setting whether to use special probabilistic truncation.
        self.use_trunc_pr = False
        # Setting whether to use daBits for non-linear functionality.
        self.use_dabit = options.mixed
        self._edabit = options.edabit
        self._comparison_rabbit = options.comparison_rabbit
        # Whether to use the low-level INVPERM instruction (only implemented
        # with the assumption of a semi-honest two-party environment).
        self._invperm = options.invperm
        self._split = False
        if options.split:
            self.use_split(int(options.split))
        self._square = False
        self._always_raw = False
        self._linear_rounds = False
        self.warn_about_mem = [True]
        self.relevant_opts = set()
        self.n_running_threads = None
        self.input_files = {}
        self.base_addresses = util.dict_by_id()
        self._protect_memory = False
        self.mem_protect_stack = []
        # must precede the "active" setter, which reads it
        self._always_active = True
        self.active = True
        self.prevent_breaks = False
        self.cisc_to_function = True
        if not self.options.cisc:
            self.options.cisc = not self.options.optimize_hard
        self.use_tape_calls = True
        self.force_cisc_tape = False
        self.have_warned_trunc_pr = False
        Program.prog = self
        from . import comparison, instructions, instructions_base, types

        instructions.program = self
        instructions_base.program = self
        types.program = self
        comparison.program = self
        comparison.set_variant(options)

    def get_args(self):
        """Return the command-line arguments given at construction."""
        return self.args

    def max_par_tapes(self):
        """Upper bound on number of tapes that will be run in parallel.
        (Excludes empty tapes)"""
        return self.n_threads

    def init_names(self, args):
        """Create the output directories and derive the program name.

        If no name was given at construction, the name is ``args[0]``
        without directory and ``.mpc``/``.py`` extension. The source file
        is then looked up as given, and under ``Programs/Source`` and
        ``<sys.path[0]>/Programs/Source``. The output-file prefix and any
        further arguments are then added to the name, which is used for
        output file names.

        :raises CompilerError: if none or more than one candidate source
            file exists
        """
        self.programs_dir = "Programs"
        if self.verbose:
            print("Compiling program in", self.programs_dir)

        # exist_ok avoids a check-then-create race between concurrent runs
        for dirname in (self.programs_dir, "Player-Data"):
            os.makedirs(dirname, exist_ok=True)

        # create extra directories if needed
        for dirname in ["Public-Input", "Bytecode", "Schedules", "Functions"]:
            os.makedirs(self.programs_dir + "/" + dirname, exist_ok=True)

        if self.name is None:
            self.name = args[0].split("/")[-1]
            exts = ".mpc", ".py"
            for ext in exts:
                if self.name.endswith(ext):
                    self.name = self.name[:-len(ext)]

            infiles = [args[0]]
            for x in (self.programs_dir, sys.path[0] + "/Programs"):
                for ext in exts:
                    filename = args[0]
                    if not filename.endswith(ext):
                        filename += ext
                    filename = x + "/Source/" + filename
                    if os.path.abspath(filename) not in \
                       [os.path.abspath(f) for f in infiles]:
                        infiles += [filename]
            existing = [f for f in infiles if os.path.exists(f)]
            if len(existing) == 1:
                self.infile = existing[0]
            elif len(existing) > 1:
                raise CompilerError("ambiguous input files: " +
                                    ", ".join(existing))
            else:
                raise CompilerError(
                    "found none of the potential input files: " +
                    ", ".join("'%s'" % x for x in infiles))
        # self.name is the input file name (minus extension) plus any
        # optional arguments; it is used to generate output filenames.
        if self.options.outfile:
            self.name = self.options.outfile + "-" + self.name
        if len(args) > 1:
            self.name += "-" + "-".join(re.sub("/", "_", arg) for arg in args[1:])

    def set_ring_size(self, ring_size):
        """Compile for computation modulo ``2**ring_size``.

        :raises CompilerError: if a tape already requires a different
            ring size
        """
        from .non_linear import Ring

        for tape in self.tapes:
            prev = tape.req_bit_length["p"]
            if prev and prev != ring_size:
                raise CompilerError("cannot have different ring sizes")
        self.bit_length = ring_size - 1
        self.non_linear = Ring(ring_size)
        self.options.ring = str(ring_size)

    def new_tape(self, function, args=(), name=None, single_thread=False,
                 finalize=True, **kwargs):
        """
        Create a new tape from a function. See
        :py:func:`~Compiler.library.multithread` and
        :py:func:`~Compiler.library.for_range_opt_multithread` for
        easier-to-use higher-level functionality. The following runs
        two threads defined by two different functions::

            def f():
                ...
            def g():
                ...
            tapes = [program.new_tape(x) for x in (f, g)]
            thread_numbers = program.run_tapes(tapes)
            program.join_tapes(thread_numbers)

        :param function: Python function defining the thread
        :param args: arguments to the function
        :param name: name used for files
        :param single_thread: Boolean indicating whether tape will
          never be run in parallel to itself
        :returns: tape handle
        """
        if name is None:
            name = function.__name__
        name = "%s-%s" % (self.name, name)
        # make sure there is a current tape
        self.curr_tape
        tape_index = len(self.tapes)
        self.tape_stack.append(self.curr_tape)
        self.curr_tape = Tape(name, self, **kwargs)
        self.curr_tape.singular = single_thread
        self.tapes.append(self.curr_tape)
        function(*args)
        if finalize:
            self.finalize_tape(self.curr_tape)
        if self.tape_stack:
            self.curr_tape = self.tape_stack.pop()
        return tape_index

    def run_tape(self, tape_index, arg):
        """Run a single tape with argument *arg*; return its thread number."""
        return self.run_tapes([[tape_index, arg]])[0]

    def run_tapes(self, args):
        """Run tapes in parallel. See :py:func:`new_tape` for an example.

        :param args: list of tape handles or tuples of tape handle and extra
          argument (for :py:func:`~Compiler.library.get_arg`)
        :returns: list of thread numbers
        """
        if not self.curr_tape.singular:
            raise CompilerError(
                "Compiler does not support " "recursive spawning of threads"
            )
        args = [list(util.tuplify(arg)) for arg in args]
        singular_tapes = set()
        for arg in args:
            if self.tapes[arg[0]].singular:
                if arg[0] in singular_tapes:
                    raise CompilerError("cannot run singular tape in parallel")
                singular_tapes.add(arg[0])
            assert len(arg)
            assert len(arg) <= 2
            if len(arg) == 1:
                arg += [0]
        thread_numbers = []
        while len(thread_numbers) < len(args):
            free_threads = self.curr_tape.free_threads
            self.curr_tape.ran_threads = True
            # reuse the lowest joined thread number before creating new ones
            if free_threads:
                thread_numbers.append(min(free_threads))
                free_threads.remove(thread_numbers[-1])
            else:
                thread_numbers.append(self.n_threads)
                self.n_threads += 1
        self.curr_tape.start_new_basicblock(name="pre-run_tape")
        Compiler.instructions.run_tape(
            *sum(([x] + list(y) for x, y in zip(thread_numbers, args)), [])
        )
        self.curr_tape.start_new_basicblock(name="post-run_tape")
        for arg in args:
            self.curr_block.req_node.children.append(
                self.tapes[arg[0]].req_tree)
        return thread_numbers

    def join_tape(self, thread_number):
        """Wait for completion of a single thread."""
        self.join_tapes([thread_number])

    def join_tapes(self, thread_numbers):
        """Wait for completion of tapes. See :py:func:`new_tape` for an example.

        :param thread_numbers: list of thread numbers
        """
        self.curr_tape.start_new_basicblock(name="pre-join_tape")
        for thread_number in thread_numbers:
            Compiler.instructions.join_tape(thread_number)
            self.curr_tape.free_threads.add(thread_number)
        self.curr_tape.start_new_basicblock(name="post-join_tape")

    def update_req(self, tape):
        """Add the requirement counts of *tape* to the program total."""
        if self.req_num is None:
            self.req_num = tape.req_num
        else:
            self.req_num += tape.req_num

    def write_bytes(self):
        """Write the schedule file and the bytecode of every tape."""
        tapes = list(self.tapes)
        sch_filename = self.programs_dir + "/Schedules/%s.sch" % self.name
        # "with" closes the schedule even if computing its content raises
        with open(sch_filename, "w") as sch_file:
            print("Writing to", sch_filename)
            sch_file.write(str(self.max_par_tapes()) + "\n")
            sch_file.write(str(len(tapes)) + "\n")
            sch_file.write(" ".join("%s:%d" % (tape.name, len(tape))
                                    for tape in tapes) + "\n")
            sch_file.write("1 0\n")
            sch_file.write("0\n")
            sch_file.write(" ".join(sys.argv) + "\n")
            req = max(x.req_bit_length["p"] for x in self.tapes)
            if self.options.ring:
                sch_file.write("R:%s" % self.options.ring)
            elif self.options.prime:
                sch_file.write("p:%s" % self.options.prime)
            else:
                sch_file.write("lgp:%s" % req)
            sch_file.write("\n")
            sch_file.write("opts: %s\n" % " ".join(self.relevant_opts))
            sch_file.write("sec:%d\n" % self.used_security)
        h = hashlib.sha256()
        for tape in self.tapes:
            tape.write_bytes()
            h.update(tape.hash)
        print('Hash:', h.hexdigest())

    def finalize_tape(self, tape):
        """Optimize, write, and purge *tape* unless it is already purged."""
        if not tape.purged:
            curr_tape = self.curr_tape
            # optimize() runs with *tape* as the current tape
            self.curr_tape = tape
            tape.optimize(self.options)
            self.curr_tape = curr_tape
            tape.write_bytes()
            if self.options.asmoutfile:
                tape.write_str(self.options.asmoutfile + "-" + tape.name)
            tape.purge()

    @property
    def curr_tape(self):
        """The tape that is currently running."""
        if self._curr_tape is None:
            assert not self.tapes
            self._curr_tape = Tape(self.name, self)
            self.tapes.append(self._curr_tape)
        return self._curr_tape

    @curr_tape.setter
    def curr_tape(self, value):
        self._curr_tape = value

    @property
    def curr_block(self):
        """The basic block that is currently being created."""
        return self.curr_tape.active_basicblock

    def malloc(self, size, mem_type, reg_type=None, creator_tape=None, use_freed=True):
        """Allocate memory from the top.

        A freed block of the same size is reused when *use_freed* is set.
        Outside a singular tape, one slice per running thread is
        allocated. The returned address is then computed at runtime from
        the thread argument.

        :returns: address, or :py:obj:`None` if *size* is 0
        :raises CompilerError: if *size* is not a compile-time integer,
            memory is allocated outside the main thread without known
            thread count, or the memory limit is exceeded
        """
        if not isinstance(size, int):
            raise CompilerError("size must be known at compile time")
        if size == 0:
            return
        if isinstance(mem_type, type):
            try:
                size *= math.ceil(mem_type.n / mem_type.unit)
            except AttributeError:
                pass
            self.types[mem_type.reg_type] = mem_type
            mem_type = mem_type.reg_type
        elif reg_type is not None:
            self.types[mem_type] = reg_type
        single_size = None
        if not (creator_tape or self.curr_tape).singular:
            if self.n_running_threads:
                single_size = size
                size *= self.n_running_threads
            else:
                raise CompilerError("cannot allocate memory " "outside main thread")
        blocks = self.free_mem_blocks[mem_type]
        addr = blocks.pop(size) if use_freed else None
        if addr is not None:
            self.saved += size
        else:
            addr = self.allocated_mem[mem_type]
            self.allocated_mem[mem_type] += size
            if len(str(addr)) != len(str(addr + size)) and self.verbose:
                print("Memory of type '%s' now of size %d" % (mem_type, addr + size))
            if addr + size >= MEM_MAX:
                raise CompilerError(
                    "allocation exceeded for type '%s' after adding %d" % \
                    (mem_type, size))
        self.allocated_mem_blocks[addr, mem_type] = size, self.curr_block.alloc_pool
        if single_size:
            from .library import get_arg, runtime_error_if

            # emit the per-thread offset computation at the start of the tape
            bak = self.curr_tape.active_basicblock
            self.curr_tape.active_basicblock = self.curr_tape.basicblocks[0]
            arg = get_arg()
            runtime_error_if(arg >= self.n_running_threads, "malloc")
            res = addr + single_size * arg
            self.curr_tape.active_basicblock = bak
            self.base_addresses[res] = addr
            return res
        else:
            return addr

    def free(self, addr, mem_type):
        """Free memory.

        For runtime addresses returned by :py:meth:`malloc`, the base
        address is looked up. Unless the current tape is the first tape,
        the block is only released by :py:meth:`free_later`.

        :raises CompilerError: if freed from a different function block
            than it was allocated in
        """
        now = True
        if not util.is_constant(addr):
            addr = self.base_addresses[addr]
            now = self.curr_tape == self.tapes[0]
        size, pool = self.allocated_mem_blocks[addr, mem_type]
        if self.curr_block.alloc_pool is not pool:
            raise CompilerError("Cannot free memory across function blocks")
        self.allocated_mem_blocks.pop((addr, mem_type))
        if now:
            self.free_mem_blocks[mem_type].push(addr, size)
        else:
            self.later_mem_blocks[mem_type].append((addr, size))

    def free_later(self):
        """Release all blocks whose freeing was deferred by :py:meth:`free`."""
        for mem_type in self.later_mem_blocks:
            for block in self.later_mem_blocks[mem_type]:
                self.free_mem_blocks[mem_type].push(*block)
        self.later_mem_blocks.clear()

    def finalize(self):
        """Finish compilation.

        Optimizes all tapes, emits the memory-size instructions, writes
        the schedule and bytecode (plus assembler output if requested),
        and closes or flushes the auxiliary files.
        """
        # optimize the tapes
        for tape in self.tapes:
            tape.optimize(self.options)

        if self.tapes:
            self.update_req(self.curr_tape)

        # finalize the memory
        self.finalize_memory()

        # communicate protocol compability
        Compiler.instructions.active(self._always_active)

        self.write_bytes()

        if self.options.asmoutfile:
            for tape in self.tapes:
                tape.write_str(self.options.asmoutfile + "-" + tape.name)

        # Making sure that the public_input_file has been properly closed
        if self.public_input_file is not None:
            self.public_input_file.close()

        # Push buffered binary input data to disk. The files stay open
        # because get_binary_input_file() would return them again.
        for input_file in self.input_files.values():
            input_file.flush()

    def finalize_memory(self):
        """Emit instructions declaring the used size of each memory type."""
        self.curr_tape.start_new_basicblock(None, "memory-usage",
                                            req_node=self.curr_tape.req_tree)
        # reset register counter to 0
        if not self.options.noreallocate:
            self.curr_tape.init_registers()
        for mem_type, size in sorted(self.allocated_mem.items()):
            if size and (not self.options.garbled or \
                         mem_type not in ('s', 'sg', 'c', 'cg')):
                if mem_type in self.types:
                    self.types[mem_type].load_mem(size - 1, mem_type)
                else:
                    from Compiler.types import _get_type

                    _get_type(mem_type).load_mem(size - 1, mem_type)
        if self.verbose:
            if self.saved:
                print("Saved %s memory units through reallocation" % self.saved)

    def public_input(self, x):
        """Append a value to the public input file."""
        if self.public_input_file is None:
            self.public_input_file = open(
                self.programs_dir + "/Public-Input/%s" % self.name, "w"
            )
        self.public_input_file.write("%s\n" % str(x))

    def get_binary_input_file(self, player):
        """Return the binary input file of *player*.

        On first use, ``Player-Data/Input-Binary-P<player>-0`` is opened
        for writing.
        """
        key = player, 'bin'
        if key not in self.input_files:
            filename = 'Player-Data/Input-Binary-P%d-0' % player
            print('Writing binary data to', filename)
            self.input_files[key] = open(filename, 'wb')
        return self.input_files[key]

    def set_bit_length(self, bit_length):
        """Change the integer bit length for non-linear functions."""
        self.bit_length = bit_length
        print("Changed bit length for comparisons etc. to", bit_length)

    def set_security(self, security):
        """Change the statistical security parameter, reporting changes."""
        changed = self._security != security
        self._security = security
        if changed:
            print("Changed statistical security for comparison etc. to",
                  security)

    @property
    def security(self):
        """The statistical security parameter for non-linear
        functions."""
        # record the maximum value actually used for the schedule file
        self.used_security = max(self.used_security, self._security)
        return self._security

    @security.setter
    def security(self, security):
        self.set_security(security)

    def optimize_for_gc(self):
        """Additionally merge binary XOR instructions."""
        import Compiler.GC.instructions as gc

        self.to_merge += [gc.xors]

    def get_tape_counter(self):
        """Return the current tape counter and increment it."""
        res = self.tape_counter
        self.tape_counter += 1
        return res

    @property
    def use_trunc_pr(self):
        """Setting whether to use special probabilistic truncation."""
        if not self._use_trunc_pr:
            self.relevant_opts.add("trunc_pr")
        return self._use_trunc_pr

    @use_trunc_pr.setter
    def use_trunc_pr(self, change):
        self._use_trunc_pr = change

    def trunc_pr_warning(self):
        """Print the probabilistic-truncation warning once per program."""
        if not self.have_warned_trunc_pr:
            print("WARNING: Probabilistic truncation leaks some information, "
                  "see https://eprint.iacr.org/2024/1127 for discussion. "
                  "Use 'sfix.round_nearest = True' to deactivate this for "
                  "fixed-point operations.")
        self.have_warned_trunc_pr = True

    def use_edabit(self, change=None):
        """Setting whether to use edaBits for non-linear
        functionality (default: false).

        :param change: change setting if not :py:obj:`None`
        :returns: setting if :py:obj:`change` is :py:obj:`None`
        """
        if change is None:
            if not self._edabit:
                self.relevant_opts.add("edabit")
            return self._edabit
        else:
            self._edabit = change

    def use_comparison_rabbit(self, change=None):
        """Setting whether to use the rabbit comparison protocol (default: false).

        :param change: change setting if not :py:obj:`None`
        :returns: setting if :py:obj:`change` is :py:obj:`None`
        """
        if change is None:
            if not self._comparison_rabbit:
                self.relevant_opts.add("comparison_rabbit")
            return self._comparison_rabbit
        else:
            self._comparison_rabbit = change

    def use_invperm(self, change=None):
        """ Set whether to use the low-level INVPERM instruction to inverse a permutation (see sint.inverse_permutation). The INVPERM instruction assumes a semi-honest two-party environment. If false, a general protocol implemented in the high-level language is used.

        :param change: change setting if not :py:obj:`None`
        :returns: setting if :py:obj:`change` is :py:obj:`None`
        """
        if change is None:
            if not self._invperm:
                self.relevant_opts.add("invperm")
            return self._invperm
        else:
            self._invperm = change

    def use_edabit_for(self, *args):
        """Whether to use edaBits for the given purpose; always True here."""
        return True

    def use_split(self, change=None):
        """Setting whether to use local arithmetic-binary share
        conversion for non-linear functionality (default: false).

        :param change: change setting if not :py:obj:`None`
        :returns: setting if :py:obj:`change` is :py:obj:`None`
        """
        if change is None:
            if not self._split:
                self.relevant_opts.add("split")
            return self._split
        else:
            if change and not self.options.ring:
                raise CompilerError("splitting only supported for rings")
            assert change > 1 or change is False
            self._split = change

    def use_square(self, change=None):
        """Setting whether to use preprocessed square tuples
        (default: false).

        :param change: change setting if not :py:obj:`None`
        :returns: setting if :py:obj:`change` is :py:obj:`None`
        """
        if change is None:
            return self._square
        else:
            self._square = change

    def always_raw(self, change=None):
        """Get (no argument) or set the "raw" setting."""
        if change is None:
            return self._always_raw
        else:
            self._always_raw = change

    def linear_rounds(self, change=None):
        """Get (no argument) or set the "linear_rounds" setting."""
        if change is None:
            return self._linear_rounds
        else:
            self._linear_rounds = change

    def options_from_args(self):
        """Set a number of options from the command-line arguments."""
        if "trunc_pr" in self.args:
            self.use_trunc_pr = True
        if "signed_trunc_pr" in self.args:
            self.use_trunc_pr = -1
        if "split" in self.args or "split3" in self.args:
            self.use_split(3)
        for arg in self.args:
            m = re.match(r"split([0-9]+)", arg)
            if m:
                self.use_split(int(m.group(1)))
        if "raw" in self.args:
            self.always_raw(True)
        if "edabit" in self.args:
            self.use_edabit(True)
        if "comparison_rabbit" in self.args:
            self.use_comparison_rabbit(True)
        if "invperm" in self.args:
            self.use_invperm(True)
        if "linear_rounds" in self.args:
            self.linear_rounds(True)

    def disable_memory_warnings(self):
        """Disable memory warnings for the current and subsequent blocks."""
        self.warn_about_mem.append(False)
        self.curr_block.warn_about_mem = False

    def protect_memory(self, status):
        """ Enable or disable memory protection. """
        self._protect_memory = status

    def open_memory_scope(self, key=None):
        """Enter a memory-protection scope identified by *key* (or a fresh object)."""
        self.mem_protect_stack.append(self._protect_memory)
        self.protect_memory(key or object())

    def close_memory_scope(self):
        """Restore the memory protection from before the matching open."""
        self.protect_memory(self.mem_protect_stack.pop())

    def use_cisc(self):
        """Whether to emit CISC instructions for this program."""
        return self.options.cisc and (not self.prime or self.rabbit_gap()) \
            and not self.options.max_parallel_open

    def rabbit_gap(self):
        """Whether the prime is within a relative distance of
        ``2**-security`` from the nearest power of two.

        Reading :py:attr:`security` also records it as used.
        """
        assert self.prime
        p = self.prime
        logp = int(round(math.log(p, 2)))
        return abs(p - 2 ** logp) / p < 2 ** -self.security

    @property
    def active(self):
        """ Whether to use actively secure protocols. """
        return self._active

    @active.setter
    def active(self, change):
        self._always_active &= change
        self._active = change

    def semi_honest(self):
        """Mark the program as not always using actively secure protocols."""
        self._always_active = False

    @staticmethod
    def read_schedule(schedule):
        """Return the lines of a schedule file.

        *schedule* may be a path, a program name, or an ``.mpc`` file name.
        If it is not an existing path, it is looked up in
        ``Programs/Schedules``. Exits with status 1 if the file is missing.
        """
        m = re.search(r"([^/]*)\.mpc", schedule)
        if m:
            schedule = m.group(1)
        if not os.path.exists(schedule):
            schedule = "Programs/Schedules/%s.sch" % schedule

        try:
            with open(schedule) as schedule_file:
                return schedule_file.readlines()
        except FileNotFoundError:
            print(
                "%s not found, have you compiled the program?" % schedule,
                file=sys.stderr,
            )
            sys.exit(1)

    @classmethod
    def read_tapes(cls, schedule):
        """Yield the tape names listed in a schedule."""
        lines = cls.read_schedule(schedule)
        for tapename in lines[2].split(" "):
            yield tapename.strip().split(":")[0]

    @classmethod
    def read_n_threads(cls, schedule):
        """Return the thread count from the first line of a schedule."""
        return int(cls.read_schedule(schedule)[0])

    @classmethod
    def read_domain_size(cls, schedule):
        """Return the byte size derived from the first tape's REQBL
        instruction, or :py:obj:`None` if there is none."""
        from Compiler.instructions import reqbl_class

        tapename = cls.read_schedule(schedule)[2].strip().split(":")[0]
        for inst in Tape.read_instructions(tapename):
            if inst.code == reqbl_class.code:
                bl = inst.args[0]
                # round the bit length up to whole 64-bit words, in bytes
                return (abs(bl.i) + 63) // 64 * 8
class Tape:
"""A tape contains a list of basic blocks, onto which instructions are added."""
def __init__(self, name, program, thread_pool=None):
    """Create an empty tape belonging to *program*.

    :param name: base name; the program-wide tape counter is appended
        (``<name>-<counter>``) to make it distinct
    :param program: the owning :py:class:`Program`
    :param thread_pool: set of free thread numbers to share with the
        caller; a fresh empty set is used if :py:obj:`None`
    """
    self.program = program
    # make the name unique via the program-wide counter
    name += "-%d" % program.get_tape_counter()
    self.init_names(name)
    self.init_registers()
    # root of the requirement tree, attached to the first basic block below
    self.req_tree = self.ReqNode(name)
    self.basicblocks = []
    self.purged = False
    self.block_counter = 0
    self.active_basicblock = None
    # snapshot of the program's allocation counters at tape creation
    self.old_allocated_mem = program.allocated_mem.copy()
    # NOTE(review): this runs before the attributes below are set;
    # presumably start_new_basicblock does not read them -- keep the
    # order unless verified.
    self.start_new_basicblock(req_node=self.req_tree)
    self._is_empty = False
    self.merge_opens = True
    self.if_states = []
    # requested bit length per key (e.g. "p"); 0 if nothing requested
    self.req_bit_length = defaultdict(lambda: 0)
    self.bit_length_reason = None
    self.function_basicblocks = {}
    self.functions = []
    # Program.new_tape overwrites this from its single_thread argument
    self.singular = True
    # thread numbers available for reuse by Program.run_tapes/join_tapes
    self.free_threads = set() if thread_pool is None else thread_pool
    self.loop_breaks = []
    self.warned_about_mem = False
    self.return_values = []
    # set to True by Program.run_tapes when this tape starts threads
    self.ran_threads = False
    self.unused_decorators = {}
class BasicBlock(object):
    """A straight-line list of instructions within a tape.

    It may end with an exit condition (jump) and belongs to a scope
    whose allocation pool it shares.
    """

    def __init__(self, parent, name, scope, exit_condition=None,
                 req_node=None):
        """
        :param parent: the owning tape
        :param name: block name (returned by ``str()``)
        :param scope: enclosing block or :py:obj:`None`; when given, this
            block is registered as its child and shares its allocation pool
        :param exit_condition: jump instruction ending the block, if any
        :param req_node: requirement-tree node for this block
        """
        self.parent = parent
        self.instructions = []
        self.name = name
        self.open_queue = []
        self.exit_condition = exit_condition
        self.exit_block = None
        self.previous_block = None
        self.scope = scope
        self.children = []
        if scope is not None:
            scope.children.append(self)
            self.alloc_pool = scope.alloc_pool
        else:
            self.alloc_pool = al.AllocPool()
        self.purged = False
        self.n_rounds = 0
        self.n_to_merge = 0
        self.rounds = Tape.ReqNum()
        self.warn_about_mem = parent.program.warn_about_mem[-1]
        self.req_node = req_node
        self.used_from_scope = set()

    def __len__(self):
        return len(self.instructions)

    def new_reg(self, reg_type, size=None):
        """Allocate a new register via the parent tape."""
        return self.parent.new_reg(reg_type, size=size)

    def set_return(self, previous_block, sub_block):
        """Record the blocks needed to patch the return address later."""
        self.previous_block = previous_block
        self.sub_block = sub_block

    def adjust_return(self):
        """Patch the return address stored in the previous block."""
        offset = self.sub_block.get_offset(self)
        self.previous_block.return_address_store.args[1] = offset

    def set_exit(self, condition, exit_true=None):
        """Sets the block which we start from next, depending on the condition.

        (Default is to go to next block in the list)
        """
        self.exit_condition = condition
        self.exit_block = exit_true
        # registers used by the jump condition must survive dead-code removal
        for reg in condition.get_used():
            reg.can_eliminate = False

    def add_jump(self):
        """Add the jump for this block's exit condition to list of
        instructions (must be done after merging)"""
        self.instructions.append(self.exit_condition)

    def get_offset(self, next_block):
        """Relative offset from the end of this block to *next_block*."""
        return next_block.offset - (self.offset + len(self.instructions))

    def adjust_jump(self):
        """Set the correct relative jump offset"""
        offset = self.get_offset(self.exit_block)
        self.exit_condition.set_relative_jump(offset)

    def purge(self, retain_usage=True):
        """Drop the instruction list to save memory.

        :param retain_usage: keep the instructions that contribute to the
            requirement counts so that :py:meth:`add_usage` still works
        """
        def relevant(inst):
            req_node = Tape.ReqNode("")
            req_node.num = Tape.ReqNum()
            inst.add_usage(req_node)
            return req_node.num != {}

        if retain_usage:
            self.usage_instructions = list(filter(relevant, self.instructions))
        else:
            self.usage_instructions = []
        if len(self.usage_instructions) > 1000 and \
           self.parent.program.verbose:
            print("Retaining %d instructions" % len(self.usage_instructions))
        del self.instructions
        self.purged = True

    def add_usage(self, req_node):
        """Add this block's requirement counts to *req_node*."""
        if self.purged:
            instructions = self.usage_instructions
        else:
            instructions = self.instructions
        for inst in instructions:
            inst.add_usage(req_node)
        req_node.num["all", "round"] += self.n_rounds
        req_node.num["all", "inv"] += self.n_to_merge
        req_node.num += self.rounds

    def expand_cisc(self):
        """Expand merged CISC instructions into basic instructions.

        When ``keep_cisc`` is set, LTZ, Trunc, EQZ and the listed
        comma-separated instructions are kept unexpanded.

        :returns: the first block of the expansion (``self`` if the
            expansion did not create further blocks)
        """
        keep_cisc = self.parent.program.options.keep_cisc
        # None (option not given) and False (the ``defaults`` class) both
        # mean "not set"; only a string can be split.
        if keep_cisc is not None and keep_cisc is not False:
            skip = ["LTZ", "Trunc", "EQZ"]
            skip += keep_cisc.split(",")
        else:
            skip = []
        tape = self.parent
        tape.start_new_basicblock(scope=self.scope, req_node=self.req_node,
                                  name="cisc")
        start_block = tape.basicblocks[-1]
        start_block.alloc_pool = self.alloc_pool
        for inst in self.instructions:
            inst.expand_merged(skip)
        # this block takes over the instructions of the last created block
        self.instructions = tape.active_basicblock.instructions
        if start_block == tape.basicblocks[-1]:
            res = self
        else:
            res = start_block
        tape.basicblocks[-1] = self
        return res

    def __str__(self):
        return self.name
def is_empty(self):
"""Returns True if the list of basic blocks is empty.