Skip to content

Commit bcf3828

Browse files
mkannwischerwillieyz
authored andcommitted
Lint: Check that all __loop__ annotations have decreases clauses
Add a check to scripts/check-contracts that every __loop__ annotation contains a decreases clause for termination proofs. Rejection sampling loops in sampling.c are excepted as they only terminate probabilistically. Signed-off-by: Matthias J. Kannwischer <matthias@kannwischer.eu>
1 parent d590779 commit bcf3828

File tree

1 file changed

+110
-15
lines changed

1 file changed

+110
-15
lines changed

scripts/check-contracts

Lines changed: 110 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,34 @@
33
# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
44

55
#
6-
# Looks for CBMC contracts without proof
6+
# Looks for CBMC contracts without proof, and for loop annotations
7+
# missing a decreases clause (needed for termination proofs).
78
#
89

910
import re
1011
import sys
1112
import subprocess
1213
import pathlib
1314

15+
1416
def get_c_source_files():
1517
return get_files("mlkem/**/*.c")
1618

19+
1720
def get_header_files():
1821
return get_files("mlkem/**/*.h")
1922

23+
2024
def get_files(pattern):
2125
return list(map(str, pathlib.Path().glob(pattern)))
2226

27+
2328
def gen_proofs():
24-
cmd_str = ["./proofs/cbmc/list_proofs.sh"]
25-
p = subprocess.run(cmd_str, capture_output=True, universal_newlines=False)
26-
proofs = filter(lambda s: s.strip() != "", p.stdout.decode().split("\n"))
27-
return proofs
29+
cmd_str = ["./proofs/cbmc/list_proofs.sh"]
30+
p = subprocess.run(cmd_str, capture_output=True, universal_newlines=False)
31+
proofs = filter(lambda s: s.strip() != "", p.stdout.decode().split("\n"))
32+
return proofs
33+
2834

2935
def gen_contracts():
3036
files = get_c_source_files() + get_header_files()
@@ -33,19 +39,24 @@ def gen_contracts():
3339
with open(filename, "r") as f:
3440
content = f.read()
3541

36-
contract_pattern = r'(\w+)\s*\([^)]*\)\s*\n?\s*__contract__'
42+
contract_pattern = r"(\w+)\s*\([^)]*\)\s*\n?\s*__contract__"
3743
matches = re.finditer(contract_pattern, content)
3844
for m in matches:
39-
line = content.count('\n', 0, m.start())
45+
line = content.count("\n", 0, m.start())
4046
yield (filename, line, m.group(1).removeprefix("mlk_"))
4147

48+
4249
def is_exception(funcname):
4350
# The functions passing this filter are known not to have a proof
4451

45-
if funcname == 'poly_permute_bitrev_to_custom':
52+
if funcname == "poly_permute_bitrev_to_custom":
4653
return True
4754

48-
if funcname.endswith("_native") or funcname.endswith("_asm") or funcname.endswith("avx2"):
55+
if (
56+
funcname.endswith("_native")
57+
or funcname.endswith("_asm")
58+
or funcname.endswith("avx2")
59+
):
4960
# CBMC proofs are axiomatized against contracts of the backends
5061
return True
5162

@@ -63,31 +74,115 @@ def is_exception(funcname):
6374

6475
return False
6576

77+
6678
def check_contracts():
6779
contracts = set(gen_contracts())
6880
proofs = set(gen_proofs())
6981

7082
bad = []
7183

7284
# Print contracts without proofs
73-
for (filename, line, funcname) in contracts:
85+
for filename, line, funcname in contracts:
7486
if funcname in proofs:
7587
continue
7688

7789
if is_exception(funcname):
78-
print(f"{filename}:{line}:{funcname} has contract but no proof, "
79-
"but is listed as exception")
90+
print(
91+
f"{filename}:{line}:{funcname} has contract but no proof, "
92+
"but is listed as exception"
93+
)
8094
continue
8195

82-
print(f"{filename}:{line}:{funcname}: has contract but no proof. FAIL",
83-
file=sys.stderr)
96+
print(
97+
f"{filename}:{line}:{funcname}: has contract but no proof. FAIL",
98+
file=sys.stderr,
99+
)
84100
bad.append(funcname)
85101

86102
return len(bad) == 0
87103

104+
105+
# Loops that only terminate probabilistically (rejection sampling)
106+
# and therefore cannot have a decreases clause.
107+
DECREASES_EXCEPTIONS = [
108+
("mlkem/src/sampling.c", "mlk_poly_rej_uniform_x4"),
109+
("mlkem/src/sampling.c", "mlk_poly_rej_uniform"),
110+
]
111+
112+
113+
def find_enclosing_function(content, pos):
114+
"""Find the name of the function enclosing the given position."""
115+
prefix = content[:pos]
116+
# Match function definitions: look for 'name(' at the start of a line
117+
# or after a type, followed eventually by '{' on its own line
118+
matches = list(re.finditer(r"\n\w+\s+(\w+)\s*\(", prefix))
119+
if matches:
120+
return matches[-1].group(1)
121+
return None
122+
123+
124+
def check_decreases():
125+
files = get_c_source_files() + get_header_files()
126+
bad = []
127+
128+
for filename in files:
129+
with open(filename, "r") as f:
130+
content = f.read()
131+
132+
# Find __loop__( that are actual loop annotations (not macro defs)
133+
for m in re.finditer(r"__loop__\(", content):
134+
start = m.start()
135+
line = content.count("\n", 0, start) + 1
136+
137+
# Skip macro definitions like #define __loop__(x)
138+
line_start = content.rfind("\n", 0, start) + 1
139+
line_text = content[line_start : content.find("\n", start)]
140+
if "#define" in line_text:
141+
continue
142+
143+
# Extract the full __loop__(...) content by matching parentheses
144+
depth = 0
145+
i = m.end() - 1
146+
for i in range(m.end() - 1, len(content)):
147+
if content[i] == "(":
148+
depth += 1
149+
elif content[i] == ")":
150+
depth -= 1
151+
if depth == 0:
152+
break
153+
loop_body = content[m.start() : i + 1]
154+
155+
if "decreases(" in loop_body:
156+
continue
157+
158+
func = find_enclosing_function(content, start)
159+
160+
if (filename, func) in DECREASES_EXCEPTIONS:
161+
print(
162+
f"{filename}:{line}: __loop__ without decreases in "
163+
f"{func}(), but is listed as exception"
164+
)
165+
continue
166+
167+
print(
168+
f"{filename}:{line}: __loop__ without decreases clause "
169+
f"in {func}(). FAIL",
170+
file=sys.stderr,
171+
)
172+
bad.append((filename, line))
173+
174+
return len(bad) == 0
175+
176+
88177
def _main():
89-
if check_contracts() != True:
178+
ok = True
179+
if not check_contracts():
180+
ok = False
181+
if not check_decreases():
182+
ok = False
183+
if not ok:
90184
sys.exit(1)
91185

186+
92187
if __name__ == "__main__":
93188
_main()

0 commit comments

Comments
 (0)