Skip to content

Commit 367e96d

Browse files
siemen11nasahlpa
authored and committed
[base/hardened_memory] Add modular reduce, add, and sub
The arithmetic sharings for ECC use modular subtraction, so this commit implements helper functions that are constant-time and hardened against fault injection. It also adds a functest for the hardened arithmetic operations that checks whether they are constant-time and whether they are correct on a RISC-V platform. The functions, including the add and sub functions, are written to use RISC-V instructions directly when built for such a platform, in order to enforce the constant-time nature of the operations. Helper functions for add, sub, and select are created for this purpose. Signed-off-by: Siemen Dhooghe <sdhooghe@google.com>
1 parent bedd585 commit 367e96d

5 files changed

Lines changed: 977 additions & 0 deletions

File tree

sw/device/lib/base/hardened_memory.c

Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,76 @@ status_t randomized_bytexor_in_place(void *restrict x, const void *restrict y,
301301
return (status_t){.value = (int32_t)launder32((uint32_t)OTCRYPTO_OK.value)};
302302
}
303303

304+
#ifdef OT_PLATFORM_RV32
305+
/**
 * Adds two 32-bit words and a carry-in with a fixed, branch-free sequence of
 * RISC-V instructions (`add`, `sltu`, `add`, `sltu`, `or`).
 *
 * The carry-in is added to `x` first and `y` is added second. Each addition
 * is followed by an unsigned compare that detects wrap-around, and the
 * carry-out is the OR of the two wrap-around flags, so it is always 0 or 1.
 *
 * @param x First input of the addition.
 * @param y Second input of the addition.
 * @param[in,out] carry The carry-in; overwritten with the carry-out.
 *        Presumably expected to be 0 or 1 on entry -- confirm with callers.
 * @return The low 32 bits of `x + y + *carry`.
 */
313+
static inline uint32_t rv32_addc(uint32_t x, uint32_t y, uint32_t *carry) {
314+
uint32_t res, next_carry, c1, c2;
315+
__asm__ __volatile__(
316+
"add %[res], %[x], %[c_in]\n\t"  // res = x + carry-in
317+
"sltu %[c1], %[res], %[c_in]\n\t"  // c1 = 1 if that addition wrapped
318+
"add %[res], %[res], %[y]\n\t"  // res = res + y
319+
"sltu %[c2], %[res], %[y]\n\t"  // c2 = 1 if that addition wrapped
320+
"or %[next_c], %[c1], %[c2]"  // carry-out = c1 | c2
321+
// Every output is early-clobber ("=&r"): each one is written before the
// last input has been read, so none may share a register with an input.
: [res] "=&r"(res), [next_c] "=&r"(next_carry), [c1] "=&r"(c1),
322+
[c2] "=&r"(c2)
323+
: [x] "r"(x), [y] "r"(y), [c_in] "r"(*carry));
324+
*carry = next_carry;
325+
return res;
326+
}
327+
328+
/**
 * Subtracts a 32-bit word and a borrow-in from another 32-bit word with a
 * fixed, branch-free sequence of RISC-V instructions (`sltu`, `sub`, `sltu`,
 * `sub`, `or`).
 *
 * The borrow-in is subtracted from `x` first and `y` is subtracted second.
 * Before each subtraction an unsigned compare records whether it will
 * underflow, and the borrow-out is the OR of the two underflow flags, so it
 * is always 0 or 1.
 *
 * @param x First input of the subtraction (minuend).
 * @param y Second input of the subtraction (subtrahend).
 * @param[in,out] borrow The borrow-in; overwritten with the borrow-out.
 *        Presumably expected to be 0 or 1 on entry -- confirm with callers.
 * @return The low 32 bits of `x - y - *borrow`.
 */
336+
static inline uint32_t rv32_subc(uint32_t x, uint32_t y, uint32_t *borrow) {
337+
uint32_t res, next_borrow, b1, b2;
338+
__asm__ __volatile__(
339+
"sltu %[b1], %[x], %[b_in]\n\t"  // b1 = 1 if x < borrow-in
340+
"sub %[res], %[x], %[b_in]\n\t"  // res = x - borrow-in
341+
"sltu %[b2], %[res], %[y]\n\t"  // b2 = 1 if res < y
342+
"sub %[res], %[res], %[y]\n\t"  // res = res - y
343+
"or %[next_b], %[b1], %[b2]"  // borrow-out = b1 | b2
344+
// Every output is early-clobber ("=&r"): each one is written before the
// last input has been read, so none may share a register with an input.
: [res] "=&r"(res), [next_b] "=&r"(next_borrow), [b1] "=&r"(b1),
345+
[b2] "=&r"(b2)
346+
: [x] "r"(x), [y] "r"(y), [b_in] "r"(*borrow));
347+
*borrow = next_borrow;
348+
return res;
349+
}
350+
351+
/**
 * Selects one of two 32-bit words with a fixed, branch-free sequence of
 * RISC-V instructions (`neg`, `xor`, `and`, `xor`).
 *
 * The result is computed as `b ^ ((a ^ b) & (0 - cond))`. For any `cond`
 * other than 0 or 1 the mask `0 - cond` is neither all ones nor all zeros
 * and the result is a bitwise mix of `a` and `b`.
 *
 * @param cond The condition to select on; must be exactly 0 or 1.
 * @param a Value returned when `cond == 1`.
 * @param b Value returned when `cond == 0`.
 * @return `a` if `cond == 1`, or `b` if `cond == 0`.
 */
360+
static inline uint32_t rv32_sel(uint32_t cond, uint32_t a, uint32_t b) {
361+
uint32_t res, mask, tmp;
362+
__asm__ __volatile__(
363+
"neg %[mask], %[cond]\n\t" // mask = 0 - cond: all ones if cond == 1,
364+
// all zeros if cond == 0
365+
"xor %[tmp], %[a], %[b]\n\t" // tmp = a ^ b
366+
"and %[tmp], %[tmp], %[mask]\n\t" // tmp = (a ^ b) & mask
367+
"xor %[res], %[b], %[tmp]" // res = b ^ ((a ^ b) & mask)
368+
// `mask` and `tmp` are early-clobber because they are written while inputs
// are still live; `res` is written only by the final instruction.
: [res] "=r"(res), [mask] "=&r"(mask), [tmp] "=&r"(tmp)
369+
: [cond] "r"(cond), [a] "r"(a), [b] "r"(b));
370+
return res;
371+
}
372+
#endif
373+
304374
status_t hardened_add(const uint32_t *restrict x, const uint32_t *restrict y,
305375
size_t word_len, uint32_t *restrict dest) {
306376
// Randomize the content of the output buffer before writing to it.
@@ -310,6 +380,9 @@ status_t hardened_add(const uint32_t *restrict x, const uint32_t *restrict y,
310380
size_t count = 0;
311381

312382
for (; launderw(count) < word_len; count = launderw(count) + 1) {
383+
#ifdef OT_PLATFORM_RV32
384+
dest[count] = rv32_addc(x[count], y[count], &carry);
385+
#else
313386
uint32_t x_val = x[count];
314387
uint32_t y_val = y[count];
315388

@@ -321,6 +394,7 @@ status_t hardened_add(const uint32_t *restrict x, const uint32_t *restrict y,
321394

322395
dest[count] = res;
323396
carry = next_carry;
397+
#endif
324398
}
325399
HARDENED_CHECK_EQ(count, word_len);
326400

@@ -336,6 +410,9 @@ status_t hardened_sub(const uint32_t *restrict x, const uint32_t *restrict y,
336410
size_t count = 0;
337411

338412
for (; launderw(count) < word_len; count = launderw(count) + 1) {
413+
#ifdef OT_PLATFORM_RV32
414+
dest[count] = rv32_subc(x[count], y[count], &borrow);
415+
#else
339416
uint32_t x_val = x[count];
340417
uint32_t y_val = y[count];
341418

@@ -348,6 +425,141 @@ status_t hardened_sub(const uint32_t *restrict x, const uint32_t *restrict y,
348425

349426
dest[count] = res;
350427
borrow = next_borrow;
428+
#endif
429+
}
430+
HARDENED_CHECK_EQ(count, word_len);
431+
432+
return (status_t){.value = (int32_t)launder32((uint32_t)OTCRYPTO_OK.value)};
433+
}
434+
435+
/**
 * Computes `dest = x - y`, adding `n` back once if the subtraction borrowed.
 *
 * All operands are arrays of `word_len` 32-bit words with the least
 * significant word at index 0 (borrows and carries propagate from index 0
 * upwards). Both candidates, `x - y` and `x - y + n`, are always computed in
 * full; the final borrow of `x - y` only decides which candidate is written
 * to `dest`. This yields `(x - y) mod n` provided `x` and `y` are already
 * reduced modulo `n`.
 *
 * On `OT_PLATFORM_RV32` builds the word operations go through `rv32_subc`,
 * `rv32_addc` and `rv32_sel`; otherwise plain C expressions are used.
 *
 * NOTE(review): the `x, y < n` precondition is not checked here -- confirm it
 * is part of the documented contract.
 * NOTE(review): the two temporaries are variable-length arrays of `word_len`
 * words on the stack; `word_len == 0` would make them zero-length, which C
 * does not define -- confirm callers never pass 0.
 *
 * @param x Minuend, `word_len` words.
 * @param y Subtrahend, `word_len` words.
 * @param n Modulus, `word_len` words.
 * @param word_len Number of 32-bit words in each operand.
 * @param[out] dest Result buffer, `word_len` words; passed to
 *        `hardened_memshred` before any result word is written.
 * @return A status whose value is the laundered `OTCRYPTO_OK` value.
 */
status_t hardened_sub_mod(const uint32_t *restrict x,
436+
const uint32_t *restrict y,
437+
const uint32_t *restrict n, size_t word_len,
438+
uint32_t *restrict dest) {
439+
// Randomize the content of the output buffer before writing to it.
440+
hardened_memshred(dest, word_len);
441+
442+
// temp_sub = x - y, word by word from the least significant word, with
// `borrow` carried between words.
443+
uint32_t temp_sub[word_len];
444+
uint32_t borrow = 0;
445+
size_t count = 0;
446+
for (; launderw(count) < word_len; count = launderw(count) + 1) {
447+
#ifdef OT_PLATFORM_RV32
448+
temp_sub[count] = rv32_subc(x[count], y[count], &borrow);
449+
#else
450+
uint32_t x_val = x[count];
451+
uint32_t y_val = y[count];
452+
// Subtract the incoming borrow first, then y_val. The two underflow
// conditions cannot both hold (the first one leaves res == UINT32_MAX), so
// next_borrow is 0 or 1.
uint32_t res = x_val - borrow;
453+
uint32_t next_borrow = (x_val < borrow);
454+
next_borrow += (res < y_val);
455+
res -= y_val;
456+
temp_sub[count] = res;
457+
borrow = next_borrow;
458+
#endif
459+
}
460+
// Check that the loop ran exactly word_len times.
HARDENED_CHECK_EQ(count, word_len);
461+
462+
// temp_add = temp_sub + n. The final carry of this addition is never read.
463+
uint32_t temp_add[word_len];
464+
uint32_t carry = 0;
465+
count = 0;
466+
for (; launderw(count) < word_len; count = launderw(count) + 1) {
467+
#ifdef OT_PLATFORM_RV32
468+
temp_add[count] = rv32_addc(temp_sub[count], n[count], &carry);
469+
#else
470+
uint32_t x_val = temp_sub[count];
471+
uint32_t y_val = n[count];
472+
// Add the incoming carry first, then y_val; each step records whether the
// 32-bit sum wrapped.
uint32_t res = x_val + carry;
473+
uint32_t next_carry = (res < carry);
474+
res += y_val;
475+
next_carry += (res < y_val);
476+
temp_add[count] = res;
477+
carry = next_carry;
478+
#endif
479+
}
480+
// Check that the loop ran exactly word_len times.
HARDENED_CHECK_EQ(count, word_len);
481+
482+
// If borrow is 1, choose temp_add, otherwise choose temp_sub.
483+
uint32_t is_borrow = launder32(borrow);
484+
485+
count = 0;
486+
for (; launderw(count) < word_len; count = launderw(count) + 1) {
487+
#ifdef OT_PLATFORM_RV32
488+
dest[count] = rv32_sel(is_borrow, temp_add[count], temp_sub[count]);
489+
#else
490+
// The mask is all 1s if borrow is 1, and all 0s if borrow is 0.
491+
uint32_t mask = ~(is_borrow - 1);
492+
// Prevent optimizations of mask.
493+
mask = launder32(mask);
494+
// Bitwise select: temp_add where mask bits are set, temp_sub elsewhere.
dest[count] = (temp_add[count] & launder32(mask)) |
495+
(temp_sub[count] & launder32(~mask));
496+
#endif
497+
}
498+
// Check that the loop ran exactly word_len times.
HARDENED_CHECK_EQ(count, word_len);
499+
500+
return (status_t){.value = (int32_t)launder32((uint32_t)OTCRYPTO_OK.value)};
501+
}
502+
503+
status_t hardened_add_mod(const uint32_t *restrict x,
504+
const uint32_t *restrict y,
505+
const uint32_t *restrict n, size_t word_len,
506+
uint32_t *restrict dest) {
507+
// Randomize the content of the output buffer before writing to it.
508+
hardened_memshred(dest, word_len);
509+
510+
// temp_add = x + y
511+
uint32_t temp_add[word_len];
512+
uint32_t carry = 0;
513+
size_t count = 0;
514+
for (; launderw(count) < word_len; count = launderw(count) + 1) {
515+
#ifdef OT_PLATFORM_RV32
516+
temp_add[count] = rv32_addc(x[count], y[count], &carry);
517+
#else
518+
uint32_t x_val = x[count];
519+
uint32_t y_val = y[count];
520+
uint32_t res = x_val + carry;
521+
uint32_t next_carry = (res < carry);
522+
res += y_val;
523+
next_carry += (res < y_val);
524+
temp_add[count] = res;
525+
carry = next_carry;
526+
#endif
527+
}
528+
HARDENED_CHECK_EQ(count, word_len);
529+
530+
// temp_sub = temp_add - n
531+
uint32_t temp_sub[word_len];
532+
uint32_t borrow = 0;
533+
count = 0;
534+
for (; launderw(count) < word_len; count = launderw(count) + 1) {
535+
#ifdef OT_PLATFORM_RV32
536+
temp_sub[count] = rv32_subc(temp_add[count], n[count], &borrow);
537+
#else
538+
uint32_t x_val = temp_add[count];
539+
uint32_t y_val = n[count];
540+
uint32_t res = x_val - borrow;
541+
uint32_t next_borrow = (x_val < borrow);
542+
next_borrow += (res < y_val);
543+
res -= y_val;
544+
temp_sub[count] = res;
545+
borrow = next_borrow;
546+
#endif
547+
}
548+
HARDENED_CHECK_EQ(count, word_len);
549+
550+
uint32_t is_ge = launder32(carry) | (1 - launder32(borrow));
551+
552+
count = 0;
553+
for (; launderw(count) < word_len; count = launderw(count) + 1) {
554+
#ifdef OT_PLATFORM_RV32
555+
dest[count] = rv32_sel(is_ge, temp_sub[count], temp_add[count]);
556+
#else
557+
uint32_t mask = ~(is_ge - 1);
558+
// Prevent optimizations of mask
559+
mask = launder32(mask);
560+
dest[count] = (temp_sub[count] & launder32(mask)) |
561+
(temp_add[count] & launder32(~mask));
562+
#endif
351563
}
352564
HARDENED_CHECK_EQ(count, word_len);
353565

@@ -367,12 +579,16 @@ status_t hardened_range_check(const uint32_t *value, const uint32_t *N,
367579
// Accumulate bits to check if value is zero.
368580
is_zero_acc = launder32(is_zero_acc) | launder32(val_word);
369581

582+
#ifdef OT_PLATFORM_RV32
583+
(void)rv32_subc(val_word, n_word, &borrow);
584+
#else
370585
// Compute borrow to check if value < N.
371586
uint32_t res = val_word - borrow;
372587
uint32_t next_borrow = (val_word < borrow);
373588
next_borrow += (res < n_word);
374589

375590
borrow = next_borrow;
591+
#endif
376592
}
377593
HARDENED_CHECK_EQ(count, word_len);
378594

@@ -388,3 +604,106 @@ status_t hardened_range_check(const uint32_t *value, const uint32_t *N,
388604

389605
return (status_t){.value = (int32_t)launder32((uint32_t)OTCRYPTO_OK.value)};
390606
}
607+
608+
/**
 * Reduces `value` modulo `n` with a bit-serial shift-and-subtract loop.
 *
 * `value` and `n` are arrays of `word_len` 32-bit words with the least
 * significant word at index 0. The outer loop runs once per bit of `value`
 * (`32 * word_len` iterations) and every iteration performs the same steps
 * over a `2 * word_len`-word remainder: shift left by one bit, insert the
 * next bit of `value`, compute `r - n`, and select `r - n` or `r` depending
 * on the final borrow. The low `word_len` words of the remainder are copied
 * to `result` at the end.
 *
 * On `OT_PLATFORM_RV32` builds the subtraction and selection go through
 * `rv32_subc` and `rv32_sel`; otherwise plain C expressions are used.
 *
 * NOTE(review): with `n` all zero the subtraction never borrows, so `r - 0`
 * is selected every round and `result` ends up equal to `value`; no error is
 * reported -- confirm callers guarantee a non-zero modulus.
 * NOTE(review): `r` and `r_sub` are variable-length arrays of
 * `2 * word_len` words on the stack; `word_len == 0` would make them
 * zero-length, which C does not define -- confirm callers never pass 0.
 *
 * @param value Dividend, `word_len` words.
 * @param n Modulus, `word_len` words.
 * @param word_len Number of 32-bit words in `value`, `n` and `result`.
 * @param[out] result Receives the remainder, `word_len` words.
 * @return A status whose value is the laundered `OTCRYPTO_OK` value; the
 *         copy into `result` is wrapped in `TRY`, which presumably returns
 *         early with the status of `hardened_memcpy` if that call fails.
 */
status_t hardened_mod_reduce(const uint32_t *value, const uint32_t *n,
609+
size_t word_len, uint32_t *result) {
610+
// This function computes modular reduction (value % n).
611+
// It implements a constant-time shift-and-subtract division. It iterates
612+
// through the bits of the dividend (`value`) from MSB to LSB, shifting them
613+
// into a remainder `r`, and subtracting the divisor `n` in constant time.
614+
615+
// Remainder, twice the size of the modulus to handle the left shift.
616+
uint32_t r[2 * word_len];
617+
// Intermediate storing of (r - n).
618+
uint32_t r_sub[2 * word_len];
619+
620+
size_t i = 0;
621+
622+
// Initialize remainder arrays to zero.
623+
for (; launderw(i) < 2 * word_len; i = launderw(i) + 1) {
624+
r[i] = 0;
625+
r_sub[i] = 0;
626+
}
627+
// Check that the loop ran exactly 2 * word_len times.
HARDENED_CHECK_EQ(i, 2 * word_len);
628+
629+
// Process each bit of `value` from Most Significant Bit (MSB) down to LSB.
// `i` counts down from the total bit count to 1; the bit handled in an
// iteration is `i - 1`.
630+
i = word_len * 32;
631+
for (; launderw(i) > 0; i = launderw(i) - 1) {
632+
size_t bit_idx = i - 1;
633+
// Word that holds the bit, and the bit's position inside that word.
size_t word_idx = bit_idx >> 5;
634+
size_t bit_in_word = bit_idx % 32;
635+
636+
// Shift the current remainder `r` left by 1 bit.
// The top bit of each word becomes the bottom bit of the next word.
637+
uint32_t carry = 0;
638+
size_t j = 0;
639+
for (; launderw(j) < 2 * word_len; j = launderw(j) + 1) {
640+
uint32_t next_carry = (r[j] >> 31);
641+
r[j] = (r[j] << 1) | carry;
642+
carry = next_carry;
643+
}
644+
HARDENED_CHECK_EQ(j, 2 * word_len);
645+
646+
// Inject the current top bit of `value` into the LSB of `r`.
647+
uint32_t bit = (value[word_idx] >> bit_in_word) & 1;
648+
// Replace bit 0 of r[0] with `bit`, leaving the other 31 bits untouched.
r[0] = r[0] ^ ((r[0] ^ bit) & 1);
649+
650+
// Compute `r_sub = r - n`.
// Lower half: subtract n word by word with borrow propagation.
651+
uint32_t borrow = 0;
652+
j = 0;
653+
for (; launderw(j) < word_len; j = launderw(j) + 1) {
654+
#ifdef OT_PLATFORM_RV32
655+
r_sub[j] = rv32_subc(r[j], n[j], &borrow);
656+
#else
657+
uint32_t r_word = r[j];
658+
uint32_t n_word = n[j];
659+
// Subtract the incoming borrow first, then n_word; either step
// underflowing sets the outgoing borrow.
uint32_t res = r_word - borrow;
660+
uint32_t next_borrow = (r_word < borrow);
661+
next_borrow |= (res < n_word);
662+
res -= n_word;
663+
r_sub[j] = res;
664+
borrow = next_borrow;
665+
#endif
666+
}
667+
HARDENED_CHECK_EQ(j, word_len);
668+
669+
// Propagate the borrow through the upper half of r
// (n has no words there, so only the borrow is subtracted).
670+
j = word_len;
671+
for (; launderw(j) < 2 * word_len; j = launderw(j) + 1) {
672+
#ifdef OT_PLATFORM_RV32
673+
r_sub[j] = rv32_subc(r[j], 0, &borrow);
674+
#else
675+
uint32_t res = r[j] - borrow;
676+
borrow = (r[j] < borrow);
677+
r_sub[j] = res;
678+
#endif
679+
}
680+
HARDENED_CHECK_EQ(j, 2 * word_len);
681+
682+
// Conditional select: `r` is overwritten with `r_sub` or left unchanged.
683+
// If r < n, the final borrow is 1. If r >= n, the final borrow is 0.
684+
#ifdef OT_PLATFORM_RV32
685+
// cond is 1 when there was no borrow (take r_sub), 0 otherwise (keep r).
uint32_t cond = launder32(1 - launder32(borrow));
686+
#else
687+
// mask is all 1s when borrow is 0 (take r_sub) and all 0s when borrow is 1
// (keep r).
uint32_t mask = borrow - 1;
688+
// Prevent compiler optimizations of the mask.
689+
mask = launder32(mask);
690+
#endif
691+
692+
j = 0;
693+
for (; launderw(j) < 2 * word_len; j = launderw(j) + 1) {
694+
#ifdef OT_PLATFORM_RV32
695+
r[j] = rv32_sel(cond, r_sub[j], r[j]);
696+
#else
697+
r[j] = (r_sub[j] & launder32(mask)) | (r[j] & launder32(~mask));
698+
#endif
699+
}
700+
HARDENED_CHECK_EQ(j, 2 * word_len);
701+
}
702+
// Check that the bit loop counted all the way down to zero.
HARDENED_CHECK_EQ(i, 0);
703+
704+
// Copy the lower word_len elements of the final remainder into the result
705+
// array.
706+
TRY(hardened_memcpy(result, r, word_len));
707+
708+
return (status_t){.value = (int32_t)launder32((uint32_t)OTCRYPTO_OK.value)};
709+
}

0 commit comments

Comments
 (0)