Skip to content

Commit aca3c3d

Browse files
committed
AArch64: Add ABI checker
This commit adds an ABI checker for all AArch64 assembly in mlkem-native. The purpose of an ABI checker is to ensure that assembly files obey the required calling conventions. For AArch64 specifically, it is required by the AAPCS that x19-x30 and d8-d15 are unmodified. It is important to test this since failure to obey the AAPCS can lead to rare but fatal bugs depending on the exact compiler and compilation settings. On a high level, the ABI checker works as follows: 1. An assembly callstub void asm_call_stub(struct register_state *input, struct register_state *output, void (*function_ptr)(void)) is implemented which calls the function specified by function_ptr on the input register state input, and returns the output register state in output. Importantly, it does _not_ assume that function_ptr obeys the AAPCS, so must be implemented in assembly. This is done for AArch64 only so far. 2. Randomized ABI checker Based on asm_call_stub, we implement an ABI checker which calls the target function with random register inputs and checks that the output register state has retained the callee saved registers. 3. Pointer inputs Most functions read/write memory specified via parameters. Calling those functions with random input would lead to segmentation faults. The ABI checker therefore does not randomize those arguments, but instead sets up valid pointers of the right size for them. For now, we always align the input pointers via MLK_ALIGN. A future extension should take alignment constraints into account and vary the alignment of the inputs within those constraints. The ABI checker gets the information about which inputs are pointers from the YAML specification of AArch64 assembly. Some implementation details: - autogen is extended to generate one test/abicheck/check_FUNCNAME.c file per function under test. All tests are declared and enumerated in the autogenerated test/abicheck/checks_all.h. - The handwritten test/abicheck/abicheck.c invokes all per-function ABI checks. - We integrate the abicheck build by extending test/mk/components.mk with specific build rules. This allows us to partly piggy back on existing Makefile infrastructure, such as setting execution wrappers or cross prefixes. On the other hand, a lot of custom hackery is currently needed to get past the preprocessor directives guarding individual assembly files. This should be improved in the future. - scripts/tests is extended to allow for tests abicheck, and tests all supports --abicheck and --no-abicheck. By default, the abicheck _is_ run with tests all. Signed-off-by: Hanno Becker <beckphan@amazon.co.uk>
1 parent 45f8f40 commit aca3c3d

26 files changed

Lines changed: 2067 additions & 2 deletions

Makefile

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
clean quickcheck check-defined-CYCLES \
1818
size_512 size_768 size_1024 size \
1919
run_size_512 run_size_768 run_size_1024 run_size \
20-
host_info
20+
host_info abicheck run_abicheck
2121

2222
SHELL := /usr/bin/env bash
2323
.DEFAULT_GOAL := build
@@ -46,7 +46,7 @@ quickcheck: test
4646
build: func kat acvp wycheproof
4747
$(Q)echo " Everything builds fine!"
4848

49-
test: run_kat run_func run_acvp run_wycheproof run_unit run_alloc run_rng_fail
49+
test: run_kat run_func run_acvp run_wycheproof run_unit run_alloc run_rng_fail run_abicheck
5050
$(Q)echo " Everything checks fine!"
5151

5252
# Detect available SHA256 command
@@ -251,6 +251,16 @@ run_size: \
251251
run_size_768 \
252252
run_size_1024
253253

254+
ifeq ($(OPT),1)
255+
abicheck: $(ABICHECK_DIR)/bin/abicheck
256+
257+
run_abicheck: abicheck
258+
$(W) $(ABICHECK_DIR)/bin/abicheck
259+
else
260+
abicheck:
261+
run_abicheck:
262+
endif
263+
254264
# Display host and compiler feature detection information
255265
# Shows which architectural features are supported by both the compiler and host CPU
256266
# Usage: make host_info [AUTO=0|1] [CROSS_PREFIX=...]

scripts/autogen

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4159,6 +4159,239 @@ def gen_test_configs():
41594159
gen_test_config(config_spec["path"], config_spec, default_config)
41604160

41614161

4162+
def extract_yaml_from_assembly(assembly_file):
4163+
"""Extract YAML metadata from assembly file."""
4164+
with open(assembly_file, "r") as f:
4165+
content = f.read()
4166+
4167+
yaml_match = re.search(r"/\*yaml\s*\n(.*?)\n\*/", content, re.DOTALL)
4168+
if not yaml_match:
4169+
raise ValueError(f"No YAML metadata found in {assembly_file}")
4170+
4171+
return yaml.safe_load(yaml_match.group(1))
4172+
4173+
4174+
def extract_arch_flags_from_assembly(assembly_file):
4175+
"""Extract architecture-specific preprocessor flags from assembly file."""
4176+
with open(assembly_file, "r") as f:
4177+
content = f.read()
4178+
4179+
arch_flags = []
4180+
if re.search(r"#if\s+defined\(__ARM_FEATURE_SHA3\)", content):
4181+
arch_flags.append("__ARM_FEATURE_SHA3")
4182+
return arch_flags
4183+
4184+
4185+
def resolve_buffer_size(reg_info, abi_data):
4186+
"""Resolve buffer size from register info, handling cross-references."""
4187+
size = reg_info.get("size_bytes")
4188+
if isinstance(size, str):
4189+
if size in abi_data:
4190+
ref_reg = abi_data[size]
4191+
if ref_reg.get("type") == "scalar" and "test_with" in ref_reg:
4192+
return ref_reg["test_with"]
4193+
raise ValueError(f"Buffer {reg_info} references non-scalar register {size}")
4194+
return int(size)
4195+
elif isinstance(size, int):
4196+
return size
4197+
raise ValueError(f"Cannot resolve buffer size from {reg_info}")
4198+
4199+
4200+
def gen_abicheck():
4201+
"""Generate ABI checker tests for all aarch64 assembly functions."""
4202+
4203+
base_path = pathlib.Path(".")
4204+
patterns = [
4205+
"mlkem/src/fips202/native/aarch64/src/*.S",
4206+
"mlkem/src/native/aarch64/src/*.S",
4207+
]
4208+
4209+
aarch64_asm_files = []
4210+
for pattern in patterns:
4211+
aarch64_asm_files.extend(base_path.glob(pattern))
4212+
4213+
if not aarch64_asm_files:
4214+
return
4215+
4216+
aarch64_asm_files = sorted(str(f) for f in aarch64_asm_files)
4217+
generated_functions = []
4218+
4219+
for assembly_file in aarch64_asm_files:
4220+
yaml_data = extract_yaml_from_assembly(assembly_file)
4221+
arch_flags = extract_arch_flags_from_assembly(assembly_file)
4222+
4223+
function_name = yaml_data.get("Name")
4224+
c_function_name = "mlk_" + function_name
4225+
4226+
def gen_c_test(
4227+
function_name=function_name,
4228+
c_function_name=c_function_name,
4229+
yaml_data=yaml_data,
4230+
arch_flags=arch_flags,
4231+
):
4232+
yield from gen_header()
4233+
yield '#include "../../mlkem/src/sys.h"'
4234+
yield ""
4235+
yield "#if defined(MLK_SYS_AARCH64)"
4236+
yield ""
4237+
4238+
for flag in arch_flags:
4239+
yield f"#if defined({flag})"
4240+
yield ""
4241+
4242+
yield "#include <stdio.h>"
4243+
yield "#include <string.h>"
4244+
yield ""
4245+
yield '#include "../notrandombytes/notrandombytes.h"'
4246+
yield '#include "abicheckutil.h"'
4247+
yield '#include "checks_all.h"'
4248+
yield ""
4249+
yield "typedef struct register_state register_state;"
4250+
yield ""
4251+
yield "#define NUM_TESTS 3"
4252+
yield ""
4253+
yield yaml_data.get("Signature") + ";"
4254+
yield ""
4255+
yield f"int check_{function_name}(void)"
4256+
yield "{"
4257+
4258+
yield " int test_iter;"
4259+
yield " register_state input_state, output_state;"
4260+
yield " int violations;"
4261+
4262+
abi_data = yaml_data.get("ABI", {})
4263+
sorted_registers = sorted(abi_data.items())
4264+
4265+
buffer_info = []
4266+
for reg_name, reg_info in sorted_registers:
4267+
if reg_info.get("type") == "buffer":
4268+
size_bytes = resolve_buffer_size(reg_info, abi_data)
4269+
buffer_name = f"buf_{reg_name}"
4270+
description = reg_info.get("description", "")
4271+
yield f" MLK_ALIGN uint8_t {buffer_name}[{size_bytes}]; /* {description} */"
4272+
buffer_info.append((reg_name, buffer_name, size_bytes))
4273+
4274+
yield ""
4275+
yield " for (test_iter = 0; test_iter < NUM_TESTS; test_iter++)"
4276+
yield " {"
4277+
yield " /* Initialize random register state */"
4278+
yield " init_random_register_state(&input_state);"
4279+
yield ""
4280+
4281+
for reg_name, buffer_name, size_bytes in buffer_info:
4282+
yield f" /* Initialize buffer for {reg_name} */"
4283+
yield f" randombytes({buffer_name}, {size_bytes});"
4284+
4285+
yield ""
4286+
yield " /* Set up register state for function arguments */"
4287+
4288+
for reg_name, reg_info in sorted_registers:
4289+
reg_num = int(reg_name[1:])
4290+
if reg_info.get("type") == "buffer":
4291+
yield f" input_state.gpr[{reg_num}] = (uint64_t){f'buf_{reg_name}'};"
4292+
elif reg_info.get("type") == "scalar":
4293+
test_with = reg_info.get("test_with")
4294+
if test_with:
4295+
yield f" input_state.gpr[{reg_num}] = {test_with};"
4296+
4297+
yield ""
4298+
yield " /* Call function through ABI test stub */"
4299+
yield f" asm_call_stub(&input_state, &output_state, (void (*)(void)){c_function_name});"
4300+
yield ""
4301+
yield " /* Check ABI compliance */"
4302+
yield " violations = check_aarch64_aapcs_compliance(&input_state, &output_state);"
4303+
yield " if (violations > 0) {"
4304+
yield f' fprintf(stderr, "ABI test FAILED for {function_name} (iteration %d): %d violations\\n",'
4305+
yield " test_iter + 1, violations);"
4306+
yield " return 1;"
4307+
yield " }"
4308+
yield " }"
4309+
yield ""
4310+
yield " return 0;"
4311+
yield "}"
4312+
yield ""
4313+
4314+
for flag in reversed(arch_flags):
4315+
yield f"#else /* !{flag} */"
4316+
yield ""
4317+
yield '#include "../../mlkem/src/common.h"'
4318+
yield f"MLK_EMPTY_CU(check_{function_name})"
4319+
yield ""
4320+
yield f"#endif /* {flag} */"
4321+
yield ""
4322+
4323+
yield "#else /* !MLK_SYS_AARCH64 */"
4324+
yield ""
4325+
yield '#include "../../mlkem/src/common.h"'
4326+
yield f"MLK_EMPTY_CU(check_{function_name})"
4327+
yield ""
4328+
yield "#endif /* MLK_SYS_AARCH64 */"
4329+
yield ""
4330+
4331+
output_file = f"test/abicheck/check_{function_name}.c"
4332+
update_file(output_file, "\n".join(gen_c_test()), force_format=True)
4333+
generated_functions.append((function_name, arch_flags))
4334+
4335+
# Generate checks_all.h
4336+
def gen_checks_all_header():
4337+
yield from gen_header()
4338+
yield ""
4339+
yield "#ifndef CHECKS_ALL_H"
4340+
yield "#define CHECKS_ALL_H"
4341+
yield ""
4342+
yield "#include <stddef.h>"
4343+
yield '#include "../../mlkem/src/sys.h"'
4344+
yield ""
4345+
yield "/* Array of all check functions */"
4346+
yield "typedef struct"
4347+
yield "{"
4348+
yield " const char *name;"
4349+
yield " int (*check_func)(void);"
4350+
yield "} abicheck_entry_t;"
4351+
yield ""
4352+
yield "#if defined(MLK_SYS_AARCH64)"
4353+
yield ""
4354+
yield "/* Function prototypes for all ABI checks (only on AArch64) */"
4355+
4356+
for func_name, arch_flags in generated_functions:
4357+
for flag in arch_flags:
4358+
yield f"#if defined({flag})"
4359+
yield f"int check_{func_name}(void);"
4360+
for flag in reversed(arch_flags):
4361+
yield f"#endif"
4362+
4363+
yield ""
4364+
yield "static const abicheck_entry_t all_checks[] = {"
4365+
4366+
for func_name, arch_flags in generated_functions:
4367+
for flag in arch_flags:
4368+
yield f"#if defined({flag})"
4369+
yield f' {{"{func_name}", check_{func_name}}},'
4370+
for flag in reversed(arch_flags):
4371+
yield f"#endif"
4372+
4373+
yield " {NULL, NULL} /* Sentinel */"
4374+
yield "};"
4375+
yield ""
4376+
yield "#else /* !MLK_SYS_AARCH64 */"
4377+
yield ""
4378+
yield "/* Empty array for non-AArch64 platforms */"
4379+
yield "static const abicheck_entry_t all_checks[] = {"
4380+
yield " {NULL, NULL} /* Sentinel */"
4381+
yield "};"
4382+
yield ""
4383+
yield "#endif /* MLK_SYS_AARCH64 */"
4384+
yield ""
4385+
yield "#endif /* CHECKS_ALL_H */"
4386+
yield ""
4387+
4388+
update_file(
4389+
"test/abicheck/checks_all.h",
4390+
"\n".join(gen_checks_all_header()),
4391+
force_format=True,
4392+
)
4393+
4394+
41624395
def _main():
41634396
slothy_choices = [
41644397
"ntt",
@@ -4277,6 +4510,7 @@ def _main():
42774510
("Generate undefs", gen_undefs),
42784511
("Generate test configs", gen_test_configs),
42794512
("Check macro typos", check_macro_typos),
4513+
("Generate ABI checker tests", gen_abicheck),
42804514
("Generate preprocessor comments", gen_preprocessor_comments),
42814515
# Formatting should be the last step
42824516
("Format files", lambda: format_files(args.dry_run)),

scripts/tests

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ class TEST_TYPES(Enum):
214214
ALLOC = 20
215215
RNG_FAIL = 21
216216
WYCHEPROOF = 22
217+
ABICHECK = 23
217218

218219
def is_benchmark(self):
219220
return self in [TEST_TYPES.BENCH, TEST_TYPES.BENCH_COMPONENTS]
@@ -294,6 +295,8 @@ class TEST_TYPES(Enum):
294295
return "Alloc Test"
295296
if self == TEST_TYPES.RNG_FAIL:
296297
return "RNG Failure Test"
298+
if self == TEST_TYPES.ABICHECK:
299+
return "ABI Compliance Test"
297300

298301
def make_dir(self):
299302
if self == TEST_TYPES.BRING_YOUR_OWN_FIPS202:
@@ -365,6 +368,8 @@ class TEST_TYPES(Enum):
365368
return "alloc"
366369
if self == TEST_TYPES.RNG_FAIL:
367370
return "rng_fail"
371+
if self == TEST_TYPES.ABICHECK:
372+
return "abicheck"
368373

369374
def make_run_target(self, scheme):
370375
t = self.make_target()
@@ -830,6 +835,17 @@ class Tests:
830835
if resultss is None:
831836
self.check_fail()
832837

838+
def abicheck(self):
839+
"""Run ABI compliance tests for assembly functions."""
840+
if not self.do_opt():
841+
return
842+
843+
self._compile_schemes(TEST_TYPES.ABICHECK, True)
844+
if self.args.run:
845+
self._run_scheme(TEST_TYPES.ABICHECK, True, None)
846+
847+
self.check_fail()
848+
833849
def all(self):
834850
func = self.args.func
835851
kat = self.args.kat
@@ -840,6 +856,7 @@ class Tests:
840856
unit = self.args.unit
841857
alloc = self.args.alloc
842858
rng_fail = self.args.rng_fail
859+
abicheck = self.args.abicheck
843860

844861
def _all(opt):
845862
if func is True:
@@ -858,6 +875,8 @@ class Tests:
858875
self._compile_schemes(TEST_TYPES.ALLOC, opt)
859876
if rng_fail is True:
860877
self._compile_schemes(TEST_TYPES.RNG_FAIL, opt)
878+
if abicheck is True and opt:
879+
self._compile_schemes(TEST_TYPES.ABICHECK, opt)
861880

862881
if self.args.check_namespace is True:
863882
p = subprocess.run(
@@ -887,6 +906,8 @@ class Tests:
887906
self._run_schemes(TEST_TYPES.ALLOC, opt)
888907
if rng_fail is True:
889908
self._run_schemes(TEST_TYPES.RNG_FAIL, opt)
909+
if abicheck is True and opt:
910+
self._run_scheme(TEST_TYPES.ABICHECK, opt, None)
890911

891912
if self.do_no_opt():
892913
_all(False)
@@ -1308,6 +1329,21 @@ def cli():
13081329
help="Do not run RNG failure tests",
13091330
)
13101331

1332+
abicheck_group = all_parser.add_mutually_exclusive_group()
1333+
abicheck_group.add_argument(
1334+
"--abicheck",
1335+
action="store_true",
1336+
dest="abicheck",
1337+
help="Run ABI compliance tests",
1338+
default=True,
1339+
)
1340+
abicheck_group.add_argument(
1341+
"--no-abicheck",
1342+
action="store_false",
1343+
dest="abicheck",
1344+
help="Do not run ABI compliance tests",
1345+
)
1346+
13111347
# acvp arguments
13121348
acvp_parser = cmd_subparsers.add_parser(
13131349
"acvp", help="Run ACVP client", parents=[common_parser]
@@ -1385,6 +1421,13 @@ def cli():
13851421
parents=[common_parser],
13861422
)
13871423

1424+
# abicheck arguments
1425+
cmd_subparsers.add_parser(
1426+
"abicheck",
1427+
help="Run ABI compliance tests for assembly functions",
1428+
parents=[common_parser],
1429+
)
1430+
13881431
# cbmc arguments
13891432
cbmc_parser = cmd_subparsers.add_parser(
13901433
"cbmc",
@@ -1576,6 +1619,8 @@ def cli():
15761619
Tests(args).alloc()
15771620
elif args.cmd == "rng_fail":
15781621
Tests(args).rng_fail()
1622+
elif args.cmd == "abicheck":
1623+
Tests(args).abicheck()
15791624

15801625

15811626
if __name__ == "__main__":

0 commit comments

Comments
 (0)