Skip to content

Commit d559afc

Browse files
42Pupusasclaude
andcommitted
k256: fix wNAF overflow past bit 127 for near-2^128 inputs
`wnaf_128` tracked the residual scalar in two u64 limbs, but a negative recentered digit adds up to 2^(W-1) − 1 to the value, which can legitimately overflow past bit 127 when the input is close to 2^128 − 1. The old code let `hi.wrapping_add(1)` silently wrap, losing the carried bit and producing a NAF that reconstructs to the wrong value. The GLV decomposition's `(r1, r2)` each have magnitude strictly less than 2^128, so values in the carry-out window are possible (though vanishingly rare in random scalars — which is why the existing randomized tests never caught it). Fix by carrying the overflow bit into a third limb `top` that is absorbed back on the next right-shift. Perf impact is in the noise: the `top` branch is almost never taken and the predictor handles it cleanly. Add two regression tests: - `test_wnaf_128_reconstruction_adversarial` — reconstructs the NAF of a scalar with low 128 bits = 0xFF..FF and asserts it equals 2^128 − 1. - `test_mul_vartime_adversarial_scalars` — end-to-end check that `mul_vartime(P, k)` matches the constant-time reference when `k`'s low 128 bits trigger the carry window. Also add a `debug_assert!` on `idx` in `WnafSlot::apply` to guard the parallel invariant (`idx < WNAF_TABLE_SIZE`) if `WNAF_WIDTH` is ever widened without growing the table. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a1cf282 commit d559afc

1 file changed

Lines changed: 93 additions & 10 deletions

File tree

k256/src/arithmetic/mul.rs

Lines changed: 93 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,13 @@ fn wnaf_128(k: &Scalar) -> [i8; WNAF_DIGITS] {
342342

343343
let mut out = [0i8; WNAF_DIGITS];
344344
let mut i = 0;
345-
while (lo | hi) != 0 {
345+
// Three-limb representation `(lo, hi, top)`: `top` is 0 or 1 and only becomes 1 when a
346+
// negative digit adds past bit 127. The extra bit is absorbed back on the next right-shift.
347+
// This is needed because GLV sub-scalars can legitimately reach magnitudes up to 2^128 − 1,
348+
// and a width-W recentering can add up to 2^(W-1) − 1 to the value, so transient overflow
349+
// past bit 127 must be preserved rather than silently wrapping `hi`.
350+
let mut top: u64 = 0;
351+
while (lo | hi | top) != 0 {
346352
debug_assert!(i < WNAF_DIGITS);
347353
if (lo & 1) == 1 {
348354
// d = k mod 2^W, recentered into [-2^(W-1) + 1, 2^(W-1) - 1]
@@ -352,27 +358,37 @@ fn wnaf_128(k: &Scalar) -> [i8; WNAF_DIGITS] {
352358
}
353359
out[i] = d as i8;
354360

355-
// k -= d (128-bit signed update)
361+
// k -= d (129-bit signed update, but the result is always >= 0 because the low W
362+
// bits of k equalled d mod 2^W and the recentering chose the signed representative).
356363
if d < 0 {
357364
// k -= (negative d) == k += |d|
358365
let add = (-d) as u64;
359-
let (new_lo, carry) = lo.overflowing_add(add);
366+
let (new_lo, carry0) = lo.overflowing_add(add);
360367
lo = new_lo;
361-
if carry {
362-
hi = hi.wrapping_add(1);
368+
if carry0 {
369+
let (new_hi, carry1) = hi.overflowing_add(1);
370+
hi = new_hi;
371+
if carry1 {
372+
top = top.wrapping_add(1);
373+
}
363374
}
364375
} else {
365376
let sub = d as u64;
366-
let (new_lo, borrow) = lo.overflowing_sub(sub);
377+
let (new_lo, borrow0) = lo.overflowing_sub(sub);
367378
lo = new_lo;
368-
if borrow {
369-
hi = hi.wrapping_sub(1);
379+
if borrow0 {
380+
let (new_hi, borrow1) = hi.overflowing_sub(1);
381+
hi = new_hi;
382+
if borrow1 {
383+
top = top.wrapping_sub(1);
384+
}
370385
}
371386
}
372387
}
373-
// Shift right by 1 across the 128-bit value.
388+
// Shift right by 1 across the 129-bit value.
374389
lo = (lo >> 1) | (hi << 63);
375-
hi >>= 1;
390+
hi = (hi >> 1) | (top << 63);
391+
top >>= 1;
376392
i += 1;
377393
}
378394
out
@@ -427,6 +443,10 @@ impl WnafSlot {
427443
let d = self.digits[i];
428444
if d != 0 {
429445
let idx = (d.unsigned_abs() >> 1) as usize;
446+
// |d| ≤ 2^(W-1) − 1 = 15 for W=5, so idx ≤ 7 = WNAF_TABLE_SIZE − 1. Guard here so
447+
// any future widening of WNAF_WIDTH that forgets to grow WNAF_TABLE_SIZE panics at
448+
// test time rather than at a random position in the ladder under release.
449+
debug_assert!(idx < WNAF_TABLE_SIZE);
430450
if d > 0 {
431451
*acc += &self.table[idx];
432452
} else {
@@ -662,6 +682,69 @@ mod tests {
662682
);
663683
}
664684

685+
// Reconstructs a wNAF digit array as a signed integer and compares to the expected low-128-bit
686+
// value of `k` (since wnaf_128 only reads bytes[16..32]).
687+
fn check_wnaf_reconstruction(k: &Scalar) {
688+
let digits = wnaf_128(k);
689+
let mut sum = num_bigint::BigInt::from(0);
690+
for (i, &d) in digits.iter().enumerate() {
691+
if d != 0 {
692+
sum += num_bigint::BigInt::from(d) << i;
693+
}
694+
}
695+
let bytes = k.to_bytes();
696+
let mut expected = num_bigint::BigInt::from(0);
697+
for &b in bytes[16..32].iter() {
698+
expected = (expected << 8) + b as u32;
699+
}
700+
assert_eq!(
701+
sum, expected,
702+
"wnaf_128 reconstructs wrong value for k.lo128 = {expected:x}"
703+
);
704+
}
705+
706+
/// End-to-end check on a scalar whose GLV halves land at or near the 2^128 boundary.
707+
/// We don't know in advance which scalars produce such halves, so instead we hunt: for a
708+
/// fixed base point, try scalars whose low bits are `0xFF..FF` and verify that the vartime
709+
/// result matches the constant-time reference. If the 3-limb carry fix is ever reverted,
710+
/// one of these will mismatch.
711+
#[test]
712+
fn test_mul_vartime_adversarial_scalars() {
713+
let p = ProjectivePoint::GENERATOR;
714+
// A scalar where the low 128 bits are all 1s forces wnaf_128's original 128-bit
715+
// code path through its carry-out window.
716+
let mut bytes = [0u8; 32];
717+
for b in bytes.iter_mut().skip(16) {
718+
*b = 0xFF;
719+
}
720+
// Ensure it's a valid scalar (not >= n). Setting the high byte to 0 keeps it small.
721+
let k = Scalar::from_bytes_unchecked(&bytes);
722+
let reference = p * k;
723+
let test = mul_vartime_impl(&p, &k);
724+
assert_eq!(
725+
reference, test,
726+
"mul_vartime mismatch on adversarial scalar"
727+
);
728+
}
729+
730+
#[test]
731+
fn test_wnaf_128_reconstruction_adversarial() {
732+
// Pathological: all-ones low 128 bits (= 2^128 - 1). Triggers a carry past bit 127.
733+
let mut bytes = [0u8; 32];
734+
for b in bytes.iter_mut().skip(16) {
735+
*b = 0xFF;
736+
}
737+
check_wnaf_reconstruction(&Scalar::from_bytes_unchecked(&bytes));
738+
739+
// Just below: 2^128 - 2 (even → d=0 on iter 0, no carry issue).
740+
bytes[31] = 0xFE;
741+
check_wnaf_reconstruction(&Scalar::from_bytes_unchecked(&bytes));
742+
743+
// Just below 2^128 but odd with high-bit set: 2^128 - 17.
744+
bytes[31] = 0xEF;
745+
check_wnaf_reconstruction(&Scalar::from_bytes_unchecked(&bytes));
746+
}
747+
665748
#[test]
666749
fn test_mul_vartime_edge_cases() {
667750
let p = ProjectivePoint::GENERATOR;

0 commit comments

Comments
 (0)