From b19fa7b378cdeafeae25f2e031537b827a011ad3 Mon Sep 17 00:00:00 2001 From: "Matthias J. Kannwischer" Date: Wed, 22 Apr 2026 12:03:25 +0800 Subject: [PATCH] check-contracts: Detect contracts separated from prototype by a comment The previous regex skipped contracts with a block comment between the prototype's `)` and `__contract__`, silently missing six contracts (polyvec_basemul_acc_montgomery_cached_asm_k{2,3,4} on both aarch64 and x86_64). Strip comments before matching, and extend the exception list to cover the `_asm_k{2,3,4}` suffixes alongside `_native`, `_asm`, and `_avx2`. Also fix a typo in the `_avx2` suffix. Also align DECREASES_EXCEPTIONS handling with mldsa-native. - Ported from pq-code-package/mldsa-native#1051. Signed-off-by: Matthias J. Kannwischer --- scripts/check-contracts | 56 +++++++++++++++++++++++++++++++++-------- 1 file changed, 45 insertions(+), 11 deletions(-) diff --git a/scripts/check-contracts b/scripts/check-contracts index 4d52a56ced..061627275e 100755 --- a/scripts/check-contracts +++ b/scripts/check-contracts @@ -1,5 +1,6 @@ #!/usr/bin/env python3 # Copyright (c) The mlkem-native project authors +# Copyright (c) The mldsa-native project authors # SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT # @@ -32,6 +33,16 @@ def gen_proofs(): return proofs +def strip_comments(content): + """Replace block and line comments with whitespace, preserving line numbers.""" + + def repl(m): + return re.sub(r"[^\n]", " ", m.group(0)) + + # Block comments and line comments. 
+ return re.sub(r"/\*[\s\S]*?\*/|//[^\n]*", repl, content) + + def gen_contracts(): files = get_c_source_files() + get_header_files() @@ -39,7 +50,9 @@ def gen_contracts(): with open(filename, "r") as f: content = f.read() - contract_pattern = r"(\w+)\s*\([^)]*\)\s*\n?\s*__contract__" + content = strip_comments(content) + + contract_pattern = r"(\w+)\s*\([^)]*\)\s*__contract__" matches = re.finditer(contract_pattern, content) for m in matches: line = content.count("\n", 0, m.start()) @@ -55,7 +68,8 @@ def is_exception(funcname): if ( funcname.endswith("_native") or funcname.endswith("_asm") - or funcname.endswith("avx2") + or funcname.endswith(("_asm_k2", "_asm_k3", "_asm_k4")) + or funcname.endswith("_avx2") ): # CBMC proofs are axiomatized against contracts of the backends return True @@ -104,10 +118,11 @@ def check_contracts(): # Loops that only terminate probabilistically (rejection sampling) # and therefore cannot have a decreases clause. -DECREASES_EXCEPTIONS = [ - ("mlkem/src/sampling.c", "mlk_poly_rej_uniform_x4"), - ("mlkem/src/sampling.c", "mlk_poly_rej_uniform"), -] +# The value is the number of loops allowed to lack a decreases clause (usually 1). 
+DECREASES_EXCEPTIONS = { + ("mlkem/src/sampling.c", "mlk_poly_rej_uniform_x4"): 1, + ("mlkem/src/sampling.c", "mlk_poly_rej_uniform"): 1, +} def find_enclosing_function(content, pos): @@ -125,6 +140,9 @@ def check_decreases(): files = get_c_source_files() + get_header_files() bad = [] + # Count loops without decreases per (filename, func) + missing = {} + for filename in files: with open(filename, "r") as f: content = f.read() @@ -156,20 +174,36 @@ def check_decreases(): continue func = find_enclosing_function(content, start) + key = (filename, func) + missing.setdefault(key, []).append(line) - if (filename, func) in DECREASES_EXCEPTIONS: + for (filename, func), lines in missing.items(): + key = (filename, func) + allowed = DECREASES_EXCEPTIONS.get(key, 0) + count = len(lines) + + if count <= allowed: + for line in lines: print( f"{filename}:{line}: __loop__ without decreases in " f"{func}(), but is listed as exception" ) - continue + continue + if allowed > 0: print( - f"{filename}:{line}: __loop__ without decreases clause " - f"in {func}(). FAIL", + f"{filename}: {func}() has {count} loops without decreases " + f"but only {allowed} allowed. FAIL", file=sys.stderr, ) - bad.append((filename, line)) + else: + for line in lines: + print( + f"{filename}:{line}: __loop__ without decreases clause " + f"in {func}(). FAIL", + file=sys.stderr, + ) + bad.append((filename, func)) return len(bad) == 0