-
Notifications
You must be signed in to change notification settings - Fork 45
Expand file tree
/
Copy pathrounding.h
More file actions
249 lines (230 loc) · 8.77 KB
/
rounding.h
File metadata and controls
249 lines (230 loc) · 8.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
/*
* Copyright (c) The mldsa-native project authors
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
*/
/* References
* ==========
*
* - [FIPS204]
* FIPS 204 Module-Lattice-Based Digital Signature Standard
* National Institute of Standards and Technology
* https://csrc.nist.gov/pubs/fips/204/final
*/
#ifndef MLD_ROUNDING_H
#define MLD_ROUNDING_H
#include "cbmc.h"
#include "common.h"
#include "ct.h"
#include "debug.h"
/* Parameter set namespacing
* This is to facilitate building multiple instances
* of mldsa-native (e.g. with varying parameter sets)
* within a single compilation unit. */
#define mld_power2round MLD_ADD_PARAM_SET(mld_power2round)
#define mld_decompose MLD_ADD_PARAM_SET(mld_decompose)
#define mld_make_hint MLD_ADD_PARAM_SET(mld_make_hint)
#define mld_use_hint MLD_ADD_PARAM_SET(mld_use_hint)
/* End of parameter set namespacing */
/* 2^MLDSA_D: the power of two used by mld_power2round to split its input
* into a high part (multiple of 2^MLDSA_D) and a centered low part. */
#define MLD_2_POW_D (1 << MLDSA_D)
/*************************************************
 * Name:        mld_power2round
 *
 * Description: Split a finite field element a into a high part a1 and a
 *              low part a0 such that
 *              a mod^+ MLDSA_Q = a1*2^MLDSA_D + a0 with
 *              -2^{MLDSA_D-1} < a0 <= 2^{MLDSA_D-1}.
 *              Assumes a to be standard representative.
 *
 * Arguments:   - int32_t *a0: pointer to output element a0 (low part)
 *              - int32_t *a1: pointer to output element a1 (high part)
 *              - int32_t a: input element
 *
 * Reference:   In the reference implementation, a1 is passed as a
 *              return value instead.
 **************************************************/
static MLD_INLINE void mld_power2round(int32_t *a0, int32_t *a1, int32_t a)
__contract__(
  requires(memory_no_alias(a0, sizeof(int32_t)))
  requires(memory_no_alias(a1, sizeof(int32_t)))
  requires(a >= 0 && a < MLDSA_Q)
  assigns(memory_slice(a0, sizeof(int32_t)))
  assigns(memory_slice(a1, sizeof(int32_t)))
  ensures(*a0 > -(MLD_2_POW_D/2) && *a0 <= (MLD_2_POW_D/2))
  ensures(*a1 >= 0 && *a1 <= (MLDSA_Q - 1) / MLD_2_POW_D)
  ensures((*a1 * MLD_2_POW_D + *a0 - a) % MLDSA_Q == 0)
)
{
  /* Adding 2^(MLDSA_D-1) - 1 before the shift rounds a / 2^MLDSA_D to the
   * nearest integer, with ties going down. This is what keeps the low part
   * within (-2^(MLDSA_D-1), 2^(MLDSA_D-1)]. */
  const int32_t rounding_offset = (1 << (MLDSA_D - 1)) - 1;
  const int32_t high = (a + rounding_offset) >> MLDSA_D;

  *a1 = high;
  /* The low part is whatever remains after removing high * 2^MLDSA_D. */
  *a0 = a - (high << MLDSA_D);
}
/*************************************************
* Name: mld_decompose
*
* Description: For finite field element a, compute low and high bits a0, a1
* such that a mod^+ MLDSA_Q = a1* 2 * MLDSA_GAMMA2 + a0 with
* -MLDSA_GAMMA2 < a0 <= MLDSA_GAMMA2 except
* if a1 = (MLDSA_Q-1)/(MLDSA_GAMMA2*2) where we set a1 = 0 and
* -MLDSA_GAMMA2 <= a0 = a mod^+ MLDSA_Q - MLDSA_Q < 0.
* Assumes a to be standard representative.
*
* Arguments: - int32_t *a0: pointer to output element a0 (low bits)
* - int32_t *a1: pointer to output element a1 (high bits)
* - int32_t a: input element
*
* Reference: In the reference implementation, a1 is passed as a
* return value instead.
**************************************************/
static MLD_INLINE void mld_decompose(int32_t *a0, int32_t *a1, int32_t a)
__contract__(
requires(memory_no_alias(a0, sizeof(int32_t)))
requires(memory_no_alias(a1, sizeof(int32_t)))
requires(a >= 0 && a < MLDSA_Q)
assigns(memory_slice(a0, sizeof(int32_t)))
assigns(memory_slice(a1, sizeof(int32_t)))
/* a0 = -MLDSA_GAMMA2 can only occur in the wrap-around case, i.e. when the
* rounded quotient equals (MLDSA_Q-1)/(2*MLDSA_GAMMA2); a1 is then reset to 0
* and a0 = a - MLDSA_Q (@[FIPS204, Algorithm 36 (Decompose)]) */
ensures(*a0 >= -MLDSA_GAMMA2 && *a0 <= MLDSA_GAMMA2)
ensures(*a1 >= 0 && *a1 < (MLDSA_Q-1)/(2*MLDSA_GAMMA2))
ensures((*a1 * 2 * MLDSA_GAMMA2 + *a0 - a) % MLDSA_Q == 0)
)
{
/*
* Notation used in the comments of this function: f is the input a,
* f1' is the intermediate value ceil(f / 128) stored in *a1 by the first
* statement, and f1 is the final high part stored in *a1.
*
* The goal is to compute f1 = round-(f / (2*GAMMA2)), which can be computed
* alternatively as round-(f / (128B)) = round-(ceil(f / 128) / B) where
* B = 2*GAMMA2 / 128. Here round-() denotes "round half down".
*
* The equality round-(f / (128B)) = round-(ceil(f / 128) / B) can be deduced
* as follows. Since changing f to align-up(f, 128) can move f onto but not
* across a rounding boundary for division by 128*B (note that we need B to be
* even for this to work), and round- rounds down on the boundary, we have
*
* round-(f / (128B)) = round-(align-up(f, 128) / (128B))
* = round-((align-up(f, 128) / 128) / B)
* = round-(ceil(f / 128) / B).
*/
/* f1' = ceil(a / 128); a is non-negative, so the shift is a floor division. */
*a1 = (a + 127) >> 7;
/* We know a >= 0 and a < MLDSA_Q, so... */
/* check-magic: 65472 == round((MLDSA_Q-1)/128) */
mld_assert(*a1 >= 0 && *a1 <= 65472);
#if MLD_CONFIG_PARAMETER_SET == 44
/* check-magic: 1488 == 2 * intdiv(intdiv(MLDSA_Q - 1, 88), 128) */
/* check-magic: 11275 == floor(2**24 / 1488) */
/* check-magic: 1560281088 == 1 / (1 / 1488 - 11275 / 2**24) */
/*
* Compute f1 = round-(f1' / B) ≈ round(f1' * 11275 / 2^24). This is exact for
* 0 <= f1' < 2^16.
*
* To see this, consider the (signed) error f1' * (1 / B - 11275 / 2^24)
* between f1' / B and the (under-)approximation f1' * 11275 / 2^24. Because
* eps := 1 / B - 11275 / 2^24 is 1 / 1560281088 ≈ 2^(-30.54) < 2^(-30), we
* have 0 <= f1' * eps < 2^16 * 2^(-30) = 1 / 2^14 < 1 / 2^11 < 1 / B (note
* that f1' is non-negative).
*
* On the other hand, 1 / B is the spacing between the integral multiples
* of 1 / B, which includes all rounding boundaries n + 0.5 (since B is even).
* Hence, if f1' / B is not of the form n + 0.5, then it is at least 1 / B
* away from the nearest rounding boundary, so moving from f1' / B to
* f1' * 11275 / 2^24 does not affect the rounding result, no matter the type
* of rounding used in either side. In particular, we have round-(f1' / B) =
* round(f1' * 11275 / 2^24) as claimed.
*
* As for the remaining case where f1' / B _is_ of the form n + 0.5, because
* f1' * 11275 / 2^24 is slightly but strictly below f1' / B = n + 0.5 (note
* that f1' and thus the error f1' * eps cannot be 0 here), it is always
* rounded down to n. More precisely, we have round-(f1' / B) =
* round(f1' * 11275 / 2^24), where the round-down on the LHS is essential,
* and on the RHS the type of rounding again does not matter. This concludes
* the proof.
*
* See proofs/isabelle/compress for a formalization of the above argument.
*/
/* Adding 2^23 before shifting by 24 rounds the scaled product to nearest. */
*a1 = (*a1 * 11275 + (1 << 23)) >> 24;
mld_assert(*a1 >= 0 && *a1 <= 44);
/* Wrap-around: the top value 44 is mapped to 0, all other values are kept
* (see the assertions before and after this statement).
* NOTE(review): this assumes mld_ct_sel_int32(x, y, m) yields x when m is set
* and y otherwise, and mld_ct_cmask_neg_i32(v) is set iff v < 0; both helpers
* are presumably constant-time -- confirm against ct.h. */
*a1 = mld_ct_sel_int32(0, *a1, mld_ct_cmask_neg_i32(43 - *a1));
mld_assert(*a1 >= 0 && *a1 <= 43);
#else /* MLD_CONFIG_PARAMETER_SET == 44 */
/* check-magic: 4092 == 2 * intdiv(intdiv(MLDSA_Q - 1, 32), 128) */
/* check-magic: 1025 == floor(2**22 / 4092) */
/* check-magic: 4290772992 == 1 / (1 / 4092 - 1025 / 2**22) */
/*
* Compute f1 = round-(f1' / B) ≈ round(f1' * 1025 / 2^22). This is exact for
* 0 <= f1' < 2^16. Following the same argument above, it suffices to show
* that f1' * eps < 1 / B, where eps := 1 / B - 1025 / 2^22. Indeed, we have
* eps = 1 / 4290772992 ≈ 2^(-31.99) < 2^(-31), therefore f1' * eps <
* 2^16 * 2^(-31) = 1 / 2^15 < 1 / 2^12 < 1 / B.
*/
/* Adding 2^21 before shifting by 22 rounds the scaled product to nearest. */
*a1 = (*a1 * 1025 + (1 << 21)) >> 22;
mld_assert(*a1 >= 0 && *a1 <= 16);
/* Wrap-around: masking with 15 maps the top value 16 to 0 and leaves
* 0..15 unchanged. */
*a1 &= 15;
mld_assert(*a1 >= 0 && *a1 <= 15);
#endif /* MLD_CONFIG_PARAMETER_SET != 44 */
/* Low part relative to the (possibly wrapped) high part. */
*a0 = a - *a1 * 2 * MLDSA_GAMMA2;
/* If a1 was wrapped to 0 above, a0 still equals a and exceeds
* (MLDSA_Q - 1) / 2; subtracting MLDSA_Q then yields the negative
* representative described in the function header. Otherwise a0 is kept.
* NOTE(review): relies on the same ct.h helper semantics as assumed for the
* wrap-around of a1 -- confirm against ct.h. */
*a0 = mld_ct_sel_int32(*a0 - MLDSA_Q, *a0,
mld_ct_cmask_neg_i32((MLDSA_Q - 1) / 2 - *a0));
}
/*************************************************
 * Name:        mld_make_hint
 *
 * Description: Compute the hint bit that tells whether the low bits of
 *              the input element overflow into the high bits.
 *
 * Arguments:   - int32_t a0: low bits of input element
 *              - int32_t a1: high bits of input element
 *
 * Returns 1 if overflow, 0 otherwise
 **************************************************/
MLD_MUST_CHECK_RETURN_VALUE
static MLD_INLINE unsigned int mld_make_hint(int32_t a0, int32_t a1)
__contract__(
  ensures(return_value >= 0 && return_value <= 1)
)
{
  /* The low bits are acceptable while they stay in
   * [-MLDSA_GAMMA2, MLDSA_GAMMA2] ... */
  const int within_bounds = (a0 <= MLDSA_GAMMA2) && (a0 >= -MLDSA_GAMMA2);
  /* ... except on the lower boundary, which is only tolerated when the
   * high bits are zero. */
  const int boundary_overflow = (a0 == -MLDSA_GAMMA2) && (a1 != 0);

  return (within_bounds && !boundary_overflow) ? 0u : 1u;
}
/*************************************************
 * Name:        mld_use_hint
 *
 * Description: Correct high bits according to hint. The input is split
 *              with mld_decompose; if the hint bit is set, the high bits
 *              are moved by one step, wrapping around at the ends of
 *              their range. The step is +1 when the low bits are
 *              positive and -1 otherwise.
 *
 * Arguments:   - int32_t a: input element
 *              - int32_t hint: hint bit
 *
 * Returns corrected high bits.
 **************************************************/
MLD_MUST_CHECK_RETURN_VALUE
static MLD_INLINE int32_t mld_use_hint(int32_t a, int32_t hint)
__contract__(
  requires(hint >= 0 && hint <= 1)
  requires(a >= 0 && a < MLDSA_Q)
  ensures(return_value >= 0 && return_value < (MLDSA_Q-1)/(2*MLDSA_GAMMA2))
)
{
  int32_t low, high;

  mld_decompose(&low, &high, a);

  /* Without a hint the high bits are returned unchanged. */
  if (hint != 0)
  {
#if MLD_CONFIG_PARAMETER_SET == 44
    /* High bits range over 0..43: step with explicit wrap-around. */
    if (low > 0)
    {
      high = (high == 43) ? 0 : high + 1;
    }
    else
    {
      high = (high == 0) ? 43 : high - 1;
    }
#else /* MLD_CONFIG_PARAMETER_SET == 44 */
    /* High bits range over 0..15: masking with 15 performs the wrap-around. */
    high = (low > 0) ? ((high + 1) & 15) : ((high - 1) & 15);
#endif /* MLD_CONFIG_PARAMETER_SET != 44 */
  }

  return high;
}
#endif /* !MLD_ROUNDING_H */