33# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
44
55#
6- # Looks for CBMC contracts without proof
6+ # Looks for CBMC contracts without proof, and for loop annotations
7+ # missing a decreases clause (needed for termination proofs).
78#
89
910import re
1011import sys
1112import subprocess
1213import pathlib
1314
15+
1416def get_c_source_files ():
1517 return get_files ("mlkem/**/*.c" )
1618
19+
1720def get_header_files ():
1821 return get_files ("mlkem/**/*.h" )
1922
23+
2024def get_files (pattern ):
2125 return list (map (str , pathlib .Path ().glob (pattern )))
2226
27+
2328def gen_proofs ():
24- cmd_str = ["./proofs/cbmc/list_proofs.sh" ]
25- p = subprocess .run (cmd_str , capture_output = True , universal_newlines = False )
26- proofs = filter (lambda s : s .strip () != "" , p .stdout .decode ().split ("\n " ))
27- return proofs
29+ cmd_str = ["./proofs/cbmc/list_proofs.sh" ]
30+ p = subprocess .run (cmd_str , capture_output = True , universal_newlines = False )
31+ proofs = filter (lambda s : s .strip () != "" , p .stdout .decode ().split ("\n " ))
32+ return proofs
33+
2834
2935def gen_contracts ():
3036 files = get_c_source_files () + get_header_files ()
@@ -33,19 +39,24 @@ def gen_contracts():
3339 with open (filename , "r" ) as f :
3440 content = f .read ()
3541
36- contract_pattern = r' (\w+)\s*\([^)]*\)\s*\n?\s*__contract__'
42+ contract_pattern = r" (\w+)\s*\([^)]*\)\s*\n?\s*__contract__"
3743 matches = re .finditer (contract_pattern , content )
3844 for m in matches :
39- line = content .count (' \n ' , 0 , m .start ())
45+ line = content .count (" \n " , 0 , m .start ())
4046 yield (filename , line , m .group (1 ).removeprefix ("mlk_" ))
4147
48+
4249def is_exception (funcname ):
4350 # The functions passing this filter are known not to have a proof
4451
45- if funcname == ' poly_permute_bitrev_to_custom' :
52+ if funcname == " poly_permute_bitrev_to_custom" :
4653 return True
4754
48- if funcname .endswith ("_native" ) or funcname .endswith ("_asm" ) or funcname .endswith ("avx2" ):
55+ if (
56+ funcname .endswith ("_native" )
57+ or funcname .endswith ("_asm" )
58+ or funcname .endswith ("avx2" )
59+ ):
4960 # CBMC proofs are axiomatized against contracts of the backends
5061 return True
5162
@@ -63,31 +74,115 @@ def is_exception(funcname):
6374
6475 return False
6576
77+
6678def check_contracts ():
6779 contracts = set (gen_contracts ())
6880 proofs = set (gen_proofs ())
6981
7082 bad = []
7183
7284 # Print contracts without proofs
73- for ( filename , line , funcname ) in contracts :
85+ for filename , line , funcname in contracts :
7486 if funcname in proofs :
7587 continue
7688
7789 if is_exception (funcname ):
78- print (f"{ filename } :{ line } :{ funcname } has contract but no proof, "
79- "but is listed as exception" )
90+ print (
91+ f"{ filename } :{ line } :{ funcname } has contract but no proof, "
92+ "but is listed as exception"
93+ )
8094 continue
8195
82- print (f"{ filename } :{ line } :{ funcname } : has contract but no proof. FAIL" ,
83- file = sys .stderr )
96+ print (
97+ f"{ filename } :{ line } :{ funcname } : has contract but no proof. FAIL" ,
98+ file = sys .stderr ,
99+ )
84100 bad .append (funcname )
85101
86102 return len (bad ) == 0
87103
104+
105+ # Loops that only terminate probabilistically (rejection sampling)
106+ # and therefore cannot have a decreases clause.
107+ DECREASES_EXCEPTIONS = [
108+ ("mlkem/src/sampling.c" , "mlk_poly_rej_uniform_x4" ),
109+ ("mlkem/src/sampling.c" , "mlk_poly_rej_uniform" ),
110+ ]
111+
112+
113+ def find_enclosing_function (content , pos ):
114+ """Find the name of the function enclosing the given position."""
115+ prefix = content [:pos ]
116+ # Match function definitions: look for 'name(' at the start of a line
117+ # or after a type, followed eventually by '{' on its own line
118+ matches = list (re .finditer (r"\n\w+\s+(\w+)\s*\(" , prefix ))
119+ if matches :
120+ return matches [- 1 ].group (1 )
121+ return None
122+
123+
124+ def check_decreases ():
125+ files = get_c_source_files () + get_header_files ()
126+ bad = []
127+
128+ for filename in files :
129+ with open (filename , "r" ) as f :
130+ content = f .read ()
131+
132+ # Find __loop__( that are actual loop annotations (not macro defs)
133+ for m in re .finditer (r"__loop__\(" , content ):
134+ start = m .start ()
135+ line = content .count ("\n " , 0 , start ) + 1
136+
137+ # Skip macro definitions like #define __loop__(x)
138+ line_start = content .rfind ("\n " , 0 , start ) + 1
139+ line_text = content [line_start : content .find ("\n " , start )]
140+ if "#define" in line_text :
141+ continue
142+
143+ # Extract the full __loop__(...) content by matching parentheses
144+ depth = 0
145+ i = m .end () - 1
146+ for i in range (m .end () - 1 , len (content )):
147+ if content [i ] == "(" :
148+ depth += 1
149+ elif content [i ] == ")" :
150+ depth -= 1
151+ if depth == 0 :
152+ break
153+ loop_body = content [m .start () : i + 1 ]
154+
155+ if "decreases(" in loop_body :
156+ continue
157+
158+ func = find_enclosing_function (content , start )
159+
160+ if (filename , func ) in DECREASES_EXCEPTIONS :
161+ print (
162+ f"{ filename } :{ line } : __loop__ without decreases in "
163+ f"{ func } (), but is listed as exception"
164+ )
165+ continue
166+
167+ print (
168+ f"{ filename } :{ line } : __loop__ without decreases clause "
169+ f"in { func } (). FAIL" ,
170+ file = sys .stderr ,
171+ )
172+ bad .append ((filename , line ))
173+
174+ return len (bad ) == 0
175+
176+
88177def _main ():
89- if check_contracts () != True :
178+ ok = True
179+ if not check_contracts ():
180+ ok = False
181+ if not check_decreases ():
182+ ok = False
183+ if not ok :
90184 sys .exit (1 )
91185
186+
92187if __name__ == "__main__" :
93188 _main ()
0 commit comments