@@ -10,89 +10,116 @@ import re
1010import math
1111import pathlib
1212
13- from sympy import simplify , sympify , Function , Rational
13+ from sympy import sympify , Rational
14+
1415
1516def get_c_source_files ():
1617 return get_files ("mlkem/**/*.c" )
1718
19+
1820def get_header_files ():
1921 return get_files ("mlkem/**/*.h" )
2022
23+
2124def get_files (pattern ):
2225 return list (map (str , pathlib .Path ().glob (pattern )))
2326
27+
2428# Standard color definitions
25- GREEN = "\033 [32m"
26- RED = "\033 [31m"
27- BLUE = "\033 [94m"
28- BOLD = "\033 [1m"
29- NORMAL = "\033 [0m"
29+ GREEN = "\033 [32m"
30+ RED = "\033 [31m"
31+ BLUE = "\033 [94m"
32+ BOLD = "\033 [1m"
33+ NORMAL = "\033 [0m"
3034
3135CHECKED = f"{ GREEN } ✓{ NORMAL } "
3236FAIL = f"{ RED } ✗{ NORMAL } "
3337REMEMBERED = f"{ BLUE } ⊢{ NORMAL } "
3438
39+
3540def check_magic_numbers ():
3641 mlkem_q = 3329
37- exceptions = [mlkem_q ,
38- 1665 , # q/2
39- 1600 , # For Keccak-F1600
40- 1023 , 1024 , 2047 , 2048 , 4095 , 4096 , 8192 , 32767 , 32768 , 65535 , 65536 ,
41- 2025 , # years
42- ]
42+ exceptions = [
43+ mlkem_q ,
44+ 1665 , # q/2
45+ 1600 , # For Keccak-F1600
46+ 1023 ,
47+ 1024 ,
48+ 2047 ,
49+ 2048 ,
50+ 4095 ,
51+ 4096 ,
52+ 8192 ,
53+ 32767 ,
54+ 32768 ,
55+ 65535 ,
56+ 65536 ,
57+ 2025 , # years
58+ ]
4359 enable_marker = "check-magic: on"
4460 disable_marker = "check-magic: off"
4561 autogen_marker = "This file is auto-generated from scripts/autogen"
4662
4763 files = get_c_source_files () + get_header_files ()
4864
49- def is_exception (filename , l , magic ):
65+ def is_exception (filename , line , magic ):
5066 return magic in exceptions
5167
52- def get_magic (l ):
53- regexp = r' /\* check-magic:\s+([-]?\d{4,})\s*==\s*(.*?) \*/'
54- m = re .search (regexp , l )
68+ def get_magic (line ):
69+ regexp = r" /\* check-magic:\s+([-]?\d{4,})\s*==\s*(.*?) \*/"
70+ m = re .search (regexp , line )
5571 if m is not None :
5672 # Remove magic annotation to avoid it being treated
5773 # as magic value itself
58- l = re .sub (regexp ,'' , l )
59- return l , (int (m .group (1 )), m .group (2 ))
60- return l , None
74+ line = re .sub (regexp , "" , line )
75+ return line , (int (m .group (1 )), m .group (2 ))
76+ return line , None
6177
62- def get_define (l ):
63- m = re .search (r' #define\s+(\w+)' , l )
78+ def get_define (line ):
79+ m = re .search (r" #define\s+(\w+)" , line )
6480 if m is not None :
6581 return m .group (1 )
6682 return None
6783
6884 def evaluate_magic (m , known_magics ):
69- def unsigned_mod (x ,y ):
85+ def unsigned_mod (x , y ):
7086 return x % y
71- def signed_mod (x ,y ):
72- r = unsigned_mod (x ,y )
87+
88+ def signed_mod (x , y ):
89+ r = unsigned_mod (x , y )
7390 if r >= y // 2 :
7491 r -= y
7592 return r
76- def pow_mod (x ,y ,m ):
93+
94+ def pow_mod (x , y , m ):
7795 x = int (x )
7896 y = int (y )
7997 m = int (m )
80- return signed_mod (pow (x ,y ,m ),m )
98+ return signed_mod (pow (x , y , m ), m )
99+
81100 def safe_round (x ):
82101 if x - math .floor (x ) == Rational (1 , 2 ):
83- raise ValueError (f"Ambiguous rounding: { x } is an odd multiple of 0.5 and it is unclear if round-up or round-down is desired" )
102+ raise ValueError (
103+ f"Ambiguous rounding: { x } is an odd multiple of 0.5 and it is unclear if round-up or round-down is desired"
104+ )
84105 return round (x )
106+
85107 def safe_floordiv (x , y ):
86108 x = int (x )
87109 y = int (y )
88110 if x % y != 0 :
89- raise ValueError (f"Non-integral division: { x } // { y } has remainder { x % y } " )
111+ raise ValueError (
112+ f"Non-integral division: { x } // { y } has remainder { x % y } "
113+ )
90114 return x // y
91- locals_dict = {'signed_mod' : signed_mod ,
92- 'unsigned_mod' : unsigned_mod ,
93- 'pow' : pow_mod ,
94- 'round' : safe_round ,
95- 'intdiv' : safe_floordiv }
115+
116+ locals_dict = {
117+ "signed_mod" : signed_mod ,
118+ "unsigned_mod" : unsigned_mod ,
119+ "pow" : pow_mod ,
120+ "round" : safe_round ,
121+ "intdiv" : safe_floordiv ,
122+ }
96123 locals_dict .update (known_magics )
97124 return sympify (m , locals = locals_dict )
98125
@@ -104,52 +131,60 @@ def check_magic_numbers():
104131 content = content .split ("\n " )
105132 # Use negative lookbefore and lookahead to exclude numbers
106133 # that occur as part of identifiers (e.g. layer12345 or 199901L)
107- pattern = r' (?<![0-9a-zA-Z/_-])([-]?\d{4,})(?![0-9a-zA-Z_-])'
134+ pattern = r" (?<![0-9a-zA-Z/_-])([-]?\d{4,})(?![0-9a-zA-Z_-])"
108135 enabled = True
109- magic_dict = {' MLKEM_Q' : mlkem_q }
136+ magic_dict = {" MLKEM_Q" : mlkem_q }
110137 magic_expr = None
111138 verified_magics = {}
112- for i , l in enumerate (content ):
113- if enabled is True and disable_marker in l :
139+ for i , line in enumerate (content ):
140+ if enabled is True and disable_marker in line :
114141 enabled = False
115142 continue
116- if enabled is False and enable_marker in l :
143+ if enabled is False and enable_marker in line :
117144 enabled = True
118145 continue
119146 if enabled is False :
120147 continue
121- l , g = get_magic (l )
148+ line , g = get_magic (line )
122149 if g is not None :
123150 magic_val , magic_expr = g
124151 magic_val_check = evaluate_magic (magic_expr , magic_dict )
125152 if magic_val != magic_val_check :
126- print (f"{ FAIL } :{ filename } :{ i } : Mismatching magic annotation: { magic_val } != { magic_expr } (= { magic_val_check } )" )
153+ print (
154+ f"{ FAIL } :{ filename } :{ i } : Mismatching magic annotation: { magic_val } != { magic_expr } (= { magic_val_check } )"
155+ )
127156 exit (1 )
128- print (f"{ REMEMBERED } :{ filename } :{ i } : Verified explanation { magic_val } == { magic_expr } " )
157+ print (
158+ f"{ REMEMBERED } :{ filename } :{ i } : Verified explanation { magic_val } == { magic_expr } "
159+ )
129160 verified_magics [magic_val ] = magic_expr
130161
131- found = next (re .finditer (pattern , l ), None )
162+ found = next (re .finditer (pattern , line ), None )
132163 if found is None :
133164 continue
134165
135166 magic = int (found .group ())
136- if is_exception (filename , l , magic ):
167+ if is_exception (filename , line , magic ):
137168 continue
138169
139170 explanation = verified_magics .get (magic , None )
140171 if explanation is None :
141172 print (f"{ FAIL } :{ filename } :{ i } : No explanation for magic value { magic } " )
142173 exit (1 )
143174
144- print (f"{ CHECKED } :{ filename } :{ i } : { magic } previously explained as { explanation } " )
175+ print (
176+ f"{ CHECKED } :{ filename } :{ i } : { magic } previously explained as { explanation } "
177+ )
145178
146179 # If this is a #define's clause, remember it
147- define = get_define (l )
180+ define = get_define (line )
148181 if define is not None :
149182 magic_dict [define ] = magic
150183
184+
151185def _main ():
152186 check_magic_numbers ()
153187
188+
154189if __name__ == "__main__" :
155190 _main ()
0 commit comments