Skip to content

Commit b1b11e9

Browse files
jakemas and claude
authored and committed
Address review comments on PR #1014
Reviewer-requested cleanup for the x86_64 rej_uniform assembly and HOL Light proof: Contract tightening (dev and mldsa copies of arith_native_x86_64.h): - requires(memory_no_alias(buf, 840)) instead of memory_no_alias(buf, MLD_AVX2_REJ_UNIFORM_BUFLEN) so the literal matches the HOL Light spec exactly. - requires(table == (const uint8_t *)mld_rej_uniform_table) pinning the table to the exported rejection-sampling table, replacing the looser memory_no_alias(table, 256 * sizeof(uint64_t)). - Clarify sync comment. vzeroupper removal: none of the other asm routines issue vzeroupper; drop it from rej_uniform for consistency. This shifts the function length by 3 bytes, so the HOL Light proof's nonoverlapping 246 / pc+245 references in rej_uniform_avx2_asm.ml become 243 / pc+242 accordingly, and the two X86_STEPS_TAC invocations that stepped the vzeroupper byte are removed. Bytecode regenerated via autogen --update-hol-light-bytecode. Autogen plumbing: register rej_uniform_avx2_asm.S in the x86_64 HOL Light asm joblist so the proofs/hol_light/x86_64/mldsa/ copy is regenerated by scripts/autogen. Add gen_avx2_hol_light_rej_uniform_table to regenerate proofs/hol_light/x86_64/proofs/mldsa_rej_uniform_table.ml alongside the C/aarch64 lookup tables (matches mlkem-native's pattern). Cross-reference comment in proofs/hol_light/x86_64/proofs/rej_uniform_avx2_asm.ml pointing at the CBMC contract. Proof runtime: ~5-6 min in the CI native build. Signed-off-by: Jake Massimo <jakemas@amazon.com> Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent b87f807 commit b1b11e9

8 files changed

Lines changed: 161 additions & 92 deletions

File tree

dev/x86_64/src/arith_native_x86_64.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,16 +74,16 @@ __contract__(
7474
);
7575

7676
#define mld_rej_uniform_avx2_asm MLD_NAMESPACE(rej_uniform_avx2_asm)
77-
/* This must be kept in sync with the HOL-Light specification
77+
/* This contract must be kept in sync with the HOL-Light specification
7878
* in proofs/hol_light/x86_64/proofs/rej_uniform_avx2_asm.ml */
7979
MLD_MUST_CHECK_RETURN_VALUE
8080
unsigned mld_rej_uniform_avx2_asm(
8181
int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_BUFLEN],
8282
const uint8_t *table)
8383
__contract__(
8484
requires(memory_no_alias(r, sizeof(int32_t) * MLDSA_N))
85-
requires(memory_no_alias(buf, MLD_AVX2_REJ_UNIFORM_BUFLEN))
86-
requires(memory_no_alias(table, 256 * sizeof(uint64_t)))
85+
requires(memory_no_alias(buf, 840))
86+
requires(table == (const uint8_t *)mld_rej_uniform_table)
8787
assigns(memory_slice(r, sizeof(int32_t) * MLDSA_N))
8888
ensures(return_value <= MLDSA_N)
8989
ensures(array_bound(r, 0, return_value, 0, MLDSA_Q))

dev/x86_64/src/rej_uniform_avx2_asm.S

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,6 @@ rej_uniform_avx2_asm_scalar:
149149
jmp rej_uniform_avx2_asm_scalar
150150

151151
rej_uniform_avx2_asm_done:
152-
vzeroupper
153152
ret
154153

155154
/* To facilitate single-compilation-unit (SCU) builds, undefine all macros.

mldsa/src/native/x86_64/src/arith_native_x86_64.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,16 +74,16 @@ __contract__(
7474
);
7575

7676
#define mld_rej_uniform_avx2_asm MLD_NAMESPACE(rej_uniform_avx2_asm)
77-
/* This must be kept in sync with the HOL-Light specification
77+
/* This contract must be kept in sync with the HOL-Light specification
7878
* in proofs/hol_light/x86_64/proofs/rej_uniform_avx2_asm.ml */
7979
MLD_MUST_CHECK_RETURN_VALUE
8080
unsigned mld_rej_uniform_avx2_asm(
8181
int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_BUFLEN],
8282
const uint8_t *table)
8383
__contract__(
8484
requires(memory_no_alias(r, sizeof(int32_t) * MLDSA_N))
85-
requires(memory_no_alias(buf, MLD_AVX2_REJ_UNIFORM_BUFLEN))
86-
requires(memory_no_alias(table, 256 * sizeof(uint64_t)))
85+
requires(memory_no_alias(buf, 840))
86+
requires(table == (const uint8_t *)mld_rej_uniform_table)
8787
assigns(memory_slice(r, sizeof(int32_t) * MLDSA_N))
8888
ensures(return_value <= MLDSA_N)
8989
ensures(array_bound(r, 0, return_value, 0, MLDSA_Q))

mldsa/src/native/x86_64/src/rej_uniform_avx2_asm.S

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ Lrej_uniform_avx2_asm_scalar:
8888
jmp Lrej_uniform_avx2_asm_scalar
8989

9090
Lrej_uniform_avx2_asm_done:
91-
vzeroupper
9291
retq
9392
.cfi_endproc
9493

proofs/hol_light/x86_64/mldsa/rej_uniform_avx2_asm.S

Lines changed: 63 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -20,85 +20,80 @@
2020

2121
/*
2222
* WARNING: This file is auto-derived from the mldsa-native source file
23-
* dev/x86_64/src/rej_uniform_avx2.S using scripts/simpasm. Do not modify it directly.
23+
* dev/x86_64/src/rej_uniform_avx2_asm.S using scripts/simpasm. Do not modify it directly.
2424
*/
2525

26-
#if defined(__ELF__)
27-
.section .note.GNU-stack,"",@progbits
28-
#endif
29-
3026
.text
3127
.balign 4
3228
#ifdef __APPLE__
33-
.global _PQCP_MLDSA_NATIVE_MLDSA44_mld_rej_uniform_avx2
34-
_PQCP_MLDSA_NATIVE_MLDSA44_mld_rej_uniform_avx2:
29+
.global _PQCP_MLDSA_NATIVE_MLDSA44_rej_uniform_avx2_asm
30+
_PQCP_MLDSA_NATIVE_MLDSA44_rej_uniform_avx2_asm:
3531
#else
36-
.global PQCP_MLDSA_NATIVE_MLDSA44_mld_rej_uniform_avx2
37-
PQCP_MLDSA_NATIVE_MLDSA44_mld_rej_uniform_avx2:
32+
.global PQCP_MLDSA_NATIVE_MLDSA44_rej_uniform_avx2_asm
33+
PQCP_MLDSA_NATIVE_MLDSA44_rej_uniform_avx2_asm:
3834
#endif
3935

4036
.cfi_startproc
4137
endbr64
42-
movabs $0xff050403ff020100,%r10
43-
44-
vmovq %r10,%xmm0
45-
movabs $0xff0b0a09ff080706,%r10
46-
47-
vpinsrq $0x1,%r10,%xmm0,%xmm0
48-
movabs $0xff090807ff060504,%r10
38+
movabsq $-0xfafbfc00fdff00, %r10 # imm = 0xFF050403FF020100
39+
vmovq %r10, %xmm0
40+
movabsq $-0xf4f5f600f7f8fa, %r10 # imm = 0xFF0B0A09FF080706
41+
vpinsrq $0x1, %r10, %xmm0, %xmm0
42+
movabsq $-0xf6f7f800f9fafc, %r10 # imm = 0xFF090807FF060504
43+
vmovq %r10, %xmm3
44+
movabsq $-0xf0f1f200f3f4f6, %r10 # imm = 0xFF0F0E0DFF0C0B0A
45+
vpinsrq $0x1, %r10, %xmm3, %xmm3
46+
vinserti128 $0x1, %xmm3, %ymm0, %ymm0
47+
movl $0x7fffff, %r8d # imm = 0x7FFFFF
48+
vmovd %r8d, %xmm1
49+
vpbroadcastd %xmm1, %ymm1
50+
movl $0x7fe001, %r8d # imm = 0x7FE001
51+
vmovd %r8d, %xmm2
52+
vpbroadcastd %xmm2, %ymm2
53+
xorl %eax, %eax
54+
xorl %ecx, %ecx
4955

50-
vmovq %r10,%xmm3
51-
movabs $0xff0f0e0dff0c0b0a,%r10
56+
Lrej_uniform_avx2_asm_loop:
57+
cmpl $0xf8, %eax
58+
ja Lrej_uniform_avx2_asm_scalar
59+
cmpl $0x328, %ecx # imm = 0x328
60+
ja Lrej_uniform_avx2_asm_scalar
61+
vmovdqu (%rsi,%rcx), %ymm3
62+
addl $0x18, %ecx
63+
vpermq $0x94, %ymm3, %ymm3 # ymm3 = ymm3[0,1,1,2]
64+
vpshufb %ymm0, %ymm3, %ymm3
65+
vpand %ymm1, %ymm3, %ymm3
66+
vpsubd %ymm2, %ymm3, %ymm4
67+
vmovmskps %ymm4, %r8d
68+
popcntl %r8d, %r9d
69+
vmovq (%rdx,%r8,8), %xmm4
70+
vpmovzxbd %xmm4, %ymm4 # ymm4 = xmm4[0],zero,zero,zero,xmm4[1],zero,zero,zero,xmm4[2],zero,zero,zero,xmm4[3],zero,zero,zero,xmm4[4],zero,zero,zero,xmm4[5],zero,zero,zero,xmm4[6],zero,zero,zero,xmm4[7],zero,zero,zero
71+
vpermd %ymm3, %ymm4, %ymm3
72+
vmovdqu %ymm3, (%rdi,%rax,4)
73+
addl %r9d, %eax
74+
jmp Lrej_uniform_avx2_asm_loop
5275

53-
vpinsrq $0x1,%r10,%xmm3,%xmm3
54-
vinserti128 $0x1,%xmm3,%ymm0,%ymm0
55-
mov $0x7fffff,%r8d
56-
vmovd %r8d,%xmm1
57-
vpbroadcastd %xmm1,%ymm1
58-
mov $0x7fe001,%r8d
59-
vmovd %r8d,%xmm2
60-
vpbroadcastd %xmm2,%ymm2
61-
xor %eax,%eax
62-
xor %ecx,%ecx
76+
Lrej_uniform_avx2_asm_scalar:
77+
cmpl $0x100, %eax # imm = 0x100
78+
jae Lrej_uniform_avx2_asm_done
79+
cmpl $0x345, %ecx # imm = 0x345
80+
ja Lrej_uniform_avx2_asm_done
81+
movzwl (%rsi,%rcx), %r8d
82+
movzbl 0x2(%rsi,%rcx), %r9d
83+
shll $0x10, %r9d
84+
orl %r9d, %r8d
85+
andl $0x7fffff, %r8d # imm = 0x7FFFFF
86+
addl $0x3, %ecx
87+
cmpl $0x7fe001, %r8d # imm = 0x7FE001
88+
jae Lrej_uniform_avx2_asm_scalar
89+
movl %r8d, (%rdi,%rax,4)
90+
addl $0x1, %eax
91+
jmp Lrej_uniform_avx2_asm_scalar
6392

64-
Lmld_rej_uniform_avx2_loop:
65-
cmp $0xf8,%eax
66-
ja Lmld_rej_uniform_avx2_scalar
67-
cmp $0x328,%ecx
68-
ja Lmld_rej_uniform_avx2_scalar
69-
vmovdqu (%rsi,%rcx,1),%ymm3
70-
add $0x18,%ecx
71-
vpermq $0x94,%ymm3,%ymm3
72-
vpshufb %ymm0,%ymm3,%ymm3
73-
vpand %ymm1,%ymm3,%ymm3
74-
vpsubd %ymm2,%ymm3,%ymm4
75-
vmovmskps %ymm4,%r8d
76-
popcnt %r8d,%r9d
77-
vmovq (%rdx,%r8,8),%xmm4
78-
vpmovzxbd %xmm4,%ymm4
79-
vpermd %ymm3,%ymm4,%ymm3
80-
vmovdqu %ymm3,(%rdi,%rax,4)
81-
add %r9d,%eax
82-
jmp Lmld_rej_uniform_avx2_loop
83-
84-
Lmld_rej_uniform_avx2_scalar:
85-
cmp $0x100,%eax
86-
jae Lmld_rej_uniform_avx2_done
87-
cmp $0x345,%ecx
88-
ja Lmld_rej_uniform_avx2_done
89-
movzwl (%rsi,%rcx,1),%r8d
90-
movzbl 0x2(%rsi,%rcx,1),%r9d
91-
shl $0x10,%r9d
92-
or %r9d,%r8d
93-
and $0x7fffff,%r8d
94-
add $0x3,%ecx
95-
cmp $0x7fe001,%r8d
96-
jae Lmld_rej_uniform_avx2_scalar
97-
mov %r8d,(%rdi,%rax,4)
98-
add $0x1,%eax
99-
jmp Lmld_rej_uniform_avx2_scalar
100-
101-
Lmld_rej_uniform_avx2_done:
102-
vzeroupper
103-
ret
93+
Lrej_uniform_avx2_asm_done:
94+
retq
10495
.cfi_endproc
96+
97+
#if defined(__ELF__)
98+
.section .note.GNU-stack,"",%progbits
99+
#endif

proofs/hol_light/x86_64/proofs/mldsa_rej_uniform_table.ml

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
(*
2-
* Copyright (c) The mldsa-native project authors
32
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
43
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0
54
*)
65

7-
(* Lookup table for ML-DSA rejection uniform sampling. *)
8-
(* Each entry is 8 bytes: permutation indices for VPERMD. *)
6+
(*
7+
* WARNING: This file is auto-generated from scripts/autogen
8+
* in the mldsa-native repository.
9+
* Do not modify it directly.
10+
*)
11+
12+
(*
13+
* Lookup table used by rejection sampling in the x86_64 AVX2
14+
* implementation. See autogen for details.
15+
*)
916

1017
let mldsa_rej_uniform_table = (REWRITE_RULE[MAP] o define)
1118
`mldsa_rej_uniform_table:byte list = MAP word [
@@ -264,5 +271,5 @@ let mldsa_rej_uniform_table = (REWRITE_RULE[MAP] o define)
264271
2; 3; 4; 5; 6; 7; 0; 0;
265272
0; 2; 3; 4; 5; 6; 7; 0;
266273
1; 2; 3; 4; 5; 6; 7; 0;
267-
0; 1; 2; 3; 4; 5; 6; 7]`
268-
;;
274+
0; 1; 2; 3; 4; 5; 6; 7
275+
]`;;

proofs/hol_light/x86_64/proofs/rej_uniform_avx2_asm.ml

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ let mldsa_rej_uniform_mc = define_assert_from_elf
101101
0x44; 0x89; 0x04; 0x87; (* MOV (Memop Doubleword (%%% (rdi,2,rax))) (% r8d) *)
102102
0x83; 0xc0; 0x01; (* ADD (% eax) (Imm8 (word 1)) *)
103103
0xeb; 0xc3; (* JMP (Imm8 (word 195)) *)
104-
0xc5; 0xf8; 0x77; (* VZEROUPPER *)
105104
0xc3 (* RET *)
106105
];;
107106
(*** BYTECODE END ***)
@@ -406,6 +405,31 @@ let CMP_MASK_CORRECT = prove(
406405
CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN
407406
CONV_TAC NUM_REDUCE_CONV);;
408407

408+
(* Pre-compute the 256 table entry values for VPERMD brute force.
409+
Each entry is an int64 value: 8 bytes from the table at offset 8*mask. *)
410+
let TABLE_ENTRY_VALS =
411+
(* The value bound here is an array of 256 theorems, one per index m in
   0..255, each produced by the EQT_ELIM call at the end of the closure. *)
(* table_expanded: theorem obtained by rewriting the term
   num_of_wordlist mldsa_rej_uniform_table with the table definition,
   num_of_wordlist and DIMINDEX_8, then reducing word and numeral
   arithmetic. Its right-hand side is presumably a single numeral
   (TODO confirm). *)
let table_expanded =
412+
(REWRITE_CONV[mldsa_rej_uniform_table; num_of_wordlist; DIMINDEX_8] THENC
413+
DEPTH_CONV WORD_NUM_RED_CONV THENC NUM_REDUCE_CONV)
414+
`num_of_wordlist mldsa_rej_uniform_table` in
415+
let table_num = rhs(concl table_expanded) in
416+
let entries = Array.init 256 (fun m ->
417+
(* tm is (table_num DIV 2 EXP (64*m)) MOD 2 EXP 64, built over the
   already-expanded table_num so that NUM_REDUCE_CONV can evaluate it;
   the right-hand side of the resulting theorem is the entry value. *)
let tm = mk_comb(mk_comb(`(MOD)`,
418+
mk_comb(mk_comb(`(DIV)`, table_num),
419+
mk_comb(mk_comb(`(EXP)`, `2`), mk_numeral(Num.num_of_int(64*m))))),
420+
mk_comb(mk_comb(`(EXP)`, `2`), `64`)) in
421+
let th = NUM_REDUCE_CONV tm in
422+
let rhs_val = rhs(concl th) in
423+
(* Prove: (num_of_wordlist table DIV 2^(64*m)) MOD 2^64 = entry_m *)
424+
(* lhs_tm has the same shape as tm but keeps the table symbolic, so the
   final theorem is stated about num_of_wordlist mldsa_rej_uniform_table
   rather than about the expanded value. *)
let lhs_tm = mk_comb(mk_comb(`(MOD)`,
425+
mk_comb(mk_comb(`(DIV)`,
426+
`num_of_wordlist mldsa_rej_uniform_table`),
427+
mk_comb(mk_comb(`(EXP)`, `2`), mk_numeral(Num.num_of_int(64*m))))),
428+
mk_comb(mk_comb(`(EXP)`, `2`), `64`)) in
429+
let eq = mk_eq(lhs_tm, rhs_val) in
430+
(* Rewriting with table_expanded replaces the symbolic table by its
   expanded value, NUM_REDUCE_CONV then evaluates the equation, and
   EQT_ELIM converts the resulting |- eq <=> T into |- eq. *)
EQT_ELIM((REWRITE_CONV[table_expanded] THENC NUM_REDUCE_CONV) eq)) in
431+
entries;;
432+
409433
(* TABLE_ENTRY_FROM_MEMORY: connect bytes64 memory read at table+8k to
410434
(table_num DIV 2^(64k)) MOD 2^64 via bigdigit/bignum_from_memory *)
411435
let TABLE_ENTRY_FROM_MEMORY = prove(
@@ -1416,9 +1440,9 @@ let VAL_RCX_ADD3_ZX = prove
14161440
let SCALAR_BODY_LEMMA = prove
14171441
(`!res buf table (inlist:(24 word)list) pc stackpointer N K i.
14181442
LENGTH inlist = 280 /\
1419-
nonoverlapping (word pc, 246) (res, 1024) /\
1420-
nonoverlapping (word pc, 246) (buf, 840) /\
1421-
nonoverlapping (word pc, 246) (table, 2048) /\
1443+
nonoverlapping (word pc, 243) (res, 1024) /\
1444+
nonoverlapping (word pc, 243) (buf, 840) /\
1445+
nonoverlapping (word pc, 243) (table, 2048) /\
14221446
nonoverlapping (res, 1024) (buf, 840) /\
14231447
nonoverlapping (res, 1024) (table, 2048) /\
14241448
24 * N <= 832 /\
@@ -2284,9 +2308,9 @@ let SCALAR_BODY_LEMMA = prove
22842308
let MLDSA_REJ_UNIFORM_CORRECT = prove
22852309
(`!res buf table (inlist:(24 word)list) pc.
22862310
LENGTH inlist = 280 /\
2287-
nonoverlapping (word pc, 246) (res, 1024) /\
2288-
nonoverlapping (word pc, 246) (buf, 840) /\
2289-
nonoverlapping (word pc, 246) (table, 2048) /\
2311+
nonoverlapping (word pc, 243) (res, 1024) /\
2312+
nonoverlapping (word pc, 243) (buf, 840) /\
2313+
nonoverlapping (word pc, 243) (table, 2048) /\
22902314
nonoverlapping (res, 1024) (buf, 840) /\
22912315
nonoverlapping (res, 1024) (table, 2048)
22922316
==> ensures x86
@@ -2296,7 +2320,7 @@ let MLDSA_REJ_UNIFORM_CORRECT = prove
22962320
read(memory :> bytes(buf,840)) s = num_of_wordlist inlist /\
22972321
read(memory :> bytes(table,2048)) s =
22982322
num_of_wordlist(mldsa_rej_uniform_table:byte list))
2299-
(\s. read RIP s = word(pc + 245) /\
2323+
(\s. read RIP s = word(pc + 242) /\
23002324
let outlist = SUB_LIST(0,256) (REJ_SAMPLE inlist) in
23012325
let outlen = LENGTH outlist in
23022326
C_RETURN s = word outlen /\
@@ -3775,7 +3799,7 @@ let MLDSA_REJ_UNIFORM_CORRECT = prove
37753799
is_eq(concl th)
37763800
then ASSUME_TAC(CONV_RULE(RAND_CONV(DEPTH_CONV WORD_NUM_RED_CONV)) th)
37773801
else failwith "not RIP") THEN
3778-
X86_STEPS_TAC MLDSA_REJ_UNIFORM_EXEC [55] THEN
3802+
(* vzeroupper removed (was step 55); RIP is already at the RET. *)
37793803
ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN
37803804
CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN
37813805
SUBGOAL_THEN `SUB_LIST (0,256) (REJ_SAMPLE (inlist:(24 word)list)) =
@@ -3843,8 +3867,7 @@ let MLDSA_REJ_UNIFORM_CORRECT = prove
38433867
let c = concl th in
38443868
if is_conj c && (try can (find_term ((=) `LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list)))`)) c with _ -> false)
38453869
then STRIP_ASSUME_TAC th else failwith "not inv") THEN
3846-
(* VZEROUPPER *)
3847-
X86_STEPS_TAC MLDSA_REJ_UNIFORM_EXEC [55] THEN
3870+
(* vzeroupper removed (was step 55); RIP is already at the RET. *)
38483871
ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN
38493872
CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN
38503873
(* The disjunct at K: either count-exit (256 <= outlen_K) or offset-exit (837 < 24*N+3*K) *)
@@ -3969,6 +3992,10 @@ let MLDSA_REJ_UNIFORM_CORRECT = prove
39693992

39703993
(* ========================================================================= *)
39713994
(* SUBROUTINE_CORRECT variants (standard x86_64 ABI). *)
3995+
(* *)
3996+
(* These specifications must be kept in sync with the CBMC contract in *)
3997+
(* dev/x86_64/src/arith_native_x86_64.h / mldsa/src/native/x86_64/src/ *)
3998+
(* arith_native_x86_64.h for mld_rej_uniform_avx2_asm. *)
39723999
(* ========================================================================= *)
39734000

39744001
let MLDSA_REJ_UNIFORM_NOIBT_SUBROUTINE_CORRECT = prove

0 commit comments

Comments
 (0)