Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 45 additions & 11 deletions scripts/check-contracts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) The mlkem-native project authors
# Copyright (c) The mldsa-native project authors
# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT

#
Expand Down Expand Up @@ -32,14 +33,26 @@ def gen_proofs():
return proofs


def strip_comments(content):
"""Replace block and line comments with whitespace, preserving line numbers."""

def repl(m):
return re.sub(r"[^\n]", " ", m.group(0))

# Block comments and line comments.
return re.sub(r"/\*[\s\S]*?\*/|//[^\n]*", repl, content)


def gen_contracts():
files = get_c_source_files() + get_header_files()

for filename in files:
with open(filename, "r") as f:
content = f.read()

contract_pattern = r"(\w+)\s*\([^)]*\)\s*\n?\s*__contract__"
content = strip_comments(content)

contract_pattern = r"(\w+)\s*\([^)]*\)\s*__contract__"
matches = re.finditer(contract_pattern, content)
for m in matches:
line = content.count("\n", 0, m.start())
Expand All @@ -55,7 +68,8 @@ def is_exception(funcname):
if (
funcname.endswith("_native")
or funcname.endswith("_asm")
or funcname.endswith("avx2")
or funcname.endswith(("_asm_k2", "_asm_k3", "_asm_k4"))
or funcname.endswith("_avx2")
Comment thread
mkannwischer marked this conversation as resolved.
):
# CBMC proofs are axiomatized against contracts of the backends
return True
Expand Down Expand Up @@ -104,10 +118,11 @@ def check_contracts():

# Loops that only terminate probabilistically (rejection sampling)
# and therefore cannot have a decreases clause.
DECREASES_EXCEPTIONS = [
("mlkem/src/sampling.c", "mlk_poly_rej_uniform_x4"),
("mlkem/src/sampling.c", "mlk_poly_rej_uniform"),
]
# The value is the number of loops allowed to lack a decreases clause (default 1).
DECREASES_EXCEPTIONS = {
("mlkem/src/sampling.c", "mlk_poly_rej_uniform_x4"): 1,
("mlkem/src/sampling.c", "mlk_poly_rej_uniform"): 1,
}


def find_enclosing_function(content, pos):
Expand All @@ -125,6 +140,9 @@ def check_decreases():
files = get_c_source_files() + get_header_files()
bad = []

# Count loops without decreases per (filename, func)
missing = {}

for filename in files:
with open(filename, "r") as f:
content = f.read()
Expand Down Expand Up @@ -156,20 +174,36 @@ def check_decreases():
continue

func = find_enclosing_function(content, start)
key = (filename, func)
missing.setdefault(key, []).append(line)

if (filename, func) in DECREASES_EXCEPTIONS:
for (filename, func), lines in missing.items():
key = (filename, func)
allowed = DECREASES_EXCEPTIONS.get(key, 0)
count = len(lines)

if count <= allowed:
for line in lines:
print(
f"{filename}:{line}: __loop__ without decreases in "
f"{func}(), but is listed as exception"
)
continue
continue

if allowed > 0:
print(
f"{filename}:{line}: __loop__ without decreases clause "
f"in {func}(). FAIL",
f"{filename}: {func}() has {count} loops without decreases "
f"but only {allowed} allowed. FAIL",
file=sys.stderr,
)
bad.append((filename, line))
else:
for line in lines:
print(
f"{filename}:{line}: __loop__ without decreases clause "
f"in {func}(). FAIL",
file=sys.stderr,
)
bad.append((filename, func))

return len(bad) == 0

Expand Down
Loading