Skip to content

Commit 8dae40b

Browse files
committed
check-contracts: Detect contracts separated from prototype by a comment
The previous regex skipped contracts with a block comment between the prototype's `)` and `__contract__`, silently missing six contracts (polyvec_basemul_acc_montgomery_cached_asm_k{2,3,4} on both aarch64 and x86_64). Strip comments before matching, and extend the exception list to cover the `_asm_k{2,3,4}` suffixes alongside `_native`, `_asm`, and `_avx2`. Also fix a typo in the `_avx2` suffix. Also align DECREASES_EXCEPTIONS handling with mldsa-native. - Ported from pq-code-package/mldsa-native#1051. Signed-off-by: Matthias J. Kannwischer <matthias@zerorisc.com>
1 parent 7da6768 commit 8dae40b

1 file changed

Lines changed: 45 additions & 11 deletions

File tree

scripts/check-contracts

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/usr/bin/env python3
22
# Copyright (c) The mlkem-native project authors
3+
# Copyright (c) The mldsa-native project authors
34
# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
45

56
#
@@ -32,14 +33,26 @@ def gen_proofs():
3233
return proofs
3334

3435

36+
def strip_comments(content):
37+
"""Replace block and line comments with whitespace, preserving line numbers."""
38+
39+
def repl(m):
40+
return re.sub(r"[^\n]", " ", m.group(0))
41+
42+
# Block comments and line comments.
43+
return re.sub(r"/\*[\s\S]*?\*/|//[^\n]*", repl, content)
44+
45+
3546
def gen_contracts():
3647
files = get_c_source_files() + get_header_files()
3748

3849
for filename in files:
3950
with open(filename, "r") as f:
4051
content = f.read()
4152

42-
contract_pattern = r"(\w+)\s*\([^)]*\)\s*\n?\s*__contract__"
53+
content = strip_comments(content)
54+
55+
contract_pattern = r"(\w+)\s*\([^)]*\)\s*__contract__"
4356
matches = re.finditer(contract_pattern, content)
4457
for m in matches:
4558
line = content.count("\n", 0, m.start())
@@ -55,7 +68,8 @@ def is_exception(funcname):
5568
if (
5669
funcname.endswith("_native")
5770
or funcname.endswith("_asm")
58-
or funcname.endswith("avx2")
71+
or funcname.endswith(("_asm_k2", "_asm_k3", "_asm_k4"))
72+
or funcname.endswith("_avx2")
5973
):
6074
# CBMC proofs are axiomatized against contracts of the backends
6175
return True
@@ -104,10 +118,11 @@ def check_contracts():
104118

105119
# Loops that only terminate probabilistically (rejection sampling)
106120
# and therefore cannot have a decreases clause.
107-
DECREASES_EXCEPTIONS = [
108-
("mlkem/src/sampling.c", "mlk_poly_rej_uniform_x4"),
109-
("mlkem/src/sampling.c", "mlk_poly_rej_uniform"),
110-
]
121+
# The value is the number of loops allowed to lack a decreases clause (default 1).
122+
DECREASES_EXCEPTIONS = {
123+
("mlkem/src/sampling.c", "mlk_poly_rej_uniform_x4"): 1,
124+
("mlkem/src/sampling.c", "mlk_poly_rej_uniform"): 1,
125+
}
111126

112127

113128
def find_enclosing_function(content, pos):
@@ -125,6 +140,9 @@ def check_decreases():
125140
files = get_c_source_files() + get_header_files()
126141
bad = []
127142

143+
# Count loops without decreases per (filename, func)
144+
missing = {}
145+
128146
for filename in files:
129147
with open(filename, "r") as f:
130148
content = f.read()
@@ -156,20 +174,36 @@ def check_decreases():
156174
continue
157175

158176
func = find_enclosing_function(content, start)
177+
key = (filename, func)
178+
missing.setdefault(key, []).append(line)
159179

160-
if (filename, func) in DECREASES_EXCEPTIONS:
180+
for (filename, func), lines in missing.items():
181+
key = (filename, func)
182+
allowed = DECREASES_EXCEPTIONS.get(key, 0)
183+
count = len(lines)
184+
185+
if count <= allowed:
186+
for line in lines:
161187
print(
162188
f"{filename}:{line}: __loop__ without decreases in "
163189
f"{func}(), but is listed as exception"
164190
)
165-
continue
191+
continue
166192

193+
if allowed > 0:
167194
print(
168-
f"{filename}:{line}: __loop__ without decreases clause "
169-
f"in {func}(). FAIL",
195+
f"{filename}: {func}() has {count} loops without decreases "
196+
f"but only {allowed} allowed. FAIL",
170197
file=sys.stderr,
171198
)
172-
bad.append((filename, line))
199+
else:
200+
for line in lines:
201+
print(
202+
f"{filename}:{line}: __loop__ without decreases clause "
203+
f"in {func}(). FAIL",
204+
file=sys.stderr,
205+
)
206+
bad.append((filename, func))
173207

174208
return len(bad) == 0
175209

0 commit comments

Comments
 (0)