Skip to content

Commit 6dcc0a0

Browse files
Protect inner value of NonZero/Odd (#1239)
- pub(crate) `new_unchecked` and `new_ref_unchecked` methods are added to `NonZero` and `Odd` - the inner values of these wrappers are made private (with accessors), preventing accidental mutation - `as_nz_vartime` methods are added to `Uint` and `BoxedUint` for converting references (consistent with `UintRef`) --------- Signed-off-by: Andrew Whitehead <cywolf@gmail.com>
1 parent 3d0edfa commit 6dcc0a0

30 files changed

Lines changed: 413 additions & 224 deletions

src/int.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,15 +152,15 @@ impl<const LIMBS: usize> Int<LIMBS> {
152152
/// Returns some if the original value is non-zero, and false otherwise.
153153
#[must_use]
154154
pub const fn to_nz(self) -> CtOption<NonZero<Self>> {
155-
CtOption::new(NonZero(self), self.0.is_nonzero())
155+
CtOption::new(NonZero::new_unchecked(self), self.0.is_nonzero())
156156
}
157157

158158
/// Convert to a [`Odd<Int<LIMBS>>`].
159159
///
160160
/// Returns some if the original value is odd, and false otherwise.
161161
#[must_use]
162162
pub const fn to_odd(self) -> CtOption<Odd<Self>> {
163-
CtOption::new(Odd(self), self.0.is_odd())
163+
CtOption::new(Odd::new_unchecked(self), self.0.is_odd())
164164
}
165165

166166
/// Interpret the data in this object as a [`Uint`] instead.

src/int/div.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ impl<const LIMBS: usize> Int<LIMBS> {
181181
let quotient = Uint::select(&quotient, &quotient_plus_one, modify);
182182

183183
// Invert the remainder.
184-
let inv_remainder = rhs_mag.0.wrapping_sub(&remainder);
184+
let inv_remainder = rhs_mag.as_ref().wrapping_sub(&remainder);
185185
let remainder = Uint::select(&remainder, &inv_remainder, modify);
186186

187187
// Negate output when lhs and rhs have opposing signs.
@@ -291,7 +291,7 @@ impl<const LIMBS: usize> Int<LIMBS> {
291291
let quotient = Uint::select(&quotient, &quotient_plus_one, modify);
292292

293293
// Invert the remainder.
294-
let inv_remainder = rhs_mag.0.wrapping_sub(&remainder);
294+
let inv_remainder = rhs_mag.as_ref().wrapping_sub(&remainder);
295295
let remainder = Uint::select(&remainder, &inv_remainder, modify);
296296

297297
// Negate output when lhs and rhs have opposing signs.

src/limb.rs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,25 @@ impl Limb {
9494

9595
/// Convert to a [`NonZero<Limb>`].
9696
///
97-
/// Returns some if the original value is non-zero, and false otherwise.
97+
/// Returns some if the original value is non-zero, and none otherwise.
9898
#[must_use]
9999
pub const fn to_nz(self) -> CtOption<NonZero<Self>> {
100-
let is_nz = self.is_nonzero();
100+
let (nz, self_nz) = self.to_nz_or_one();
101+
CtOption::new(nz, self_nz)
102+
}
101103

102-
// Use `1` as a placeholder in the event that `self` is `Limb(0)`
103-
let nz_word = word::select(1, self.0, is_nz);
104-
CtOption::new(NonZero(Self(nz_word)), is_nz)
104+
/// Convert to a [`NonZero<Limb>`], defaulting to `Self::ONE`.
105+
///
106+
/// Returns a pair consisting of a [`NonZero<Limb>`], and a [`Choice`]
107+
/// indicating whether the original value was non-zero (and preserved).
108+
#[inline(always)]
109+
#[must_use]
110+
pub(crate) const fn to_nz_or_one(self) -> (NonZero<Self>, Choice) {
111+
let is_nz = self.is_nonzero();
112+
(
113+
NonZero::new_unchecked(Self::select(Self::ONE, self, is_nz)),
114+
is_nz,
115+
)
105116
}
106117

107118
/// Convert the least significant bit of this [`Limb`] to a [`Choice`].

src/limb/div.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,17 @@ impl Limb {
2929
/// if the divisor is non-zero, and `CtOption::none()` otherwise.
3030
#[must_use]
3131
pub const fn checked_div(self, rhs: Self) -> CtOption<Limb> {
32-
let is_nz = rhs.is_nonzero();
33-
let quo = self.div_rem(NonZero(Self::select(Limb::ONE, rhs, is_nz))).0;
32+
let (rhs_nz, is_nz) = rhs.to_nz_or_one();
33+
let quo = self.div_rem(rhs_nz).0;
3434
CtOption::new(quo, is_nz)
3535
}
3636

3737
/// Computes the checked division `self / rhs`, returning the remainder
3838
/// if the divisor is non-zero, and `CtOption::none()` otherwise.
3939
#[must_use]
4040
pub const fn checked_rem(self, rhs: Self) -> CtOption<Limb> {
41-
let is_nz = rhs.is_nonzero();
42-
let rem = self.div_rem(NonZero(Self::select(Limb::ONE, rhs, is_nz))).1;
41+
let (rhs_nz, is_nz) = rhs.to_nz_or_one();
42+
let rem = self.div_rem(rhs_nz).1;
4343
CtOption::new(rem, is_nz)
4444
}
4545
}
@@ -318,8 +318,8 @@ mod tests {
318318
assert_eq!(a / &b, c);
319319
assert_eq!(&a / b, c);
320320
assert_eq!(&a / &b, c);
321-
assert_eq!(a / &b.0, c);
322-
assert_eq!(&a / b.0, c);
321+
assert_eq!(a / b.as_ref(), c);
322+
assert_eq!(&a / b.get(), c);
323323
}
324324

325325
#[test]
@@ -335,10 +335,10 @@ mod tests {
335335
res /= &b;
336336
assert_eq!(res, c);
337337
let mut res = a;
338-
res /= b.0;
338+
res /= b.get();
339339
assert_eq!(res, c);
340340
let mut res = a;
341-
res /= &b.0;
341+
res /= b.as_ref();
342342
assert_eq!(res, c);
343343
}
344344

@@ -364,8 +364,8 @@ mod tests {
364364
assert_eq!(a % &b, c);
365365
assert_eq!(&a % b, c);
366366
assert_eq!(&a % &b, c);
367-
assert_eq!(a % &b.0, c);
368-
assert_eq!(&a % b.0, c);
367+
assert_eq!(a % b.as_ref(), c);
368+
assert_eq!(&a % b.get(), c);
369369
}
370370

371371
#[test]
@@ -381,10 +381,10 @@ mod tests {
381381
res %= &b;
382382
assert_eq!(res, c);
383383
let mut res = a;
384-
res %= b.0;
384+
res %= b.get();
385385
assert_eq!(res, c);
386386
let mut res = a;
387-
res %= &b.0;
387+
res %= b.as_ref();
388388
assert_eq!(res, c);
389389
}
390390

src/modular.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,9 @@ mod tests {
165165
// Computing xR mod modulus without Montgomery reduction
166166
let (lo, hi) = x.widening_mul(&Modulus256::PARAMS.one);
167167
let c = lo.concat(&hi);
168-
let red =
169-
c.rem_vartime(&NonZero::new(Modulus256::PARAMS.modulus.0.concat(&U256::ZERO)).unwrap());
168+
let red = c.rem_vartime(
169+
&NonZero::new(Modulus256::PARAMS.modulus.as_ref().concat(&U256::ZERO)).unwrap(),
170+
);
170171
let (lo, hi) = red.split();
171172
assert_eq!(hi, Uint::ZERO);
172173

@@ -294,8 +295,9 @@ mod tests {
294295
// Computing xR mod modulus without Montgomery reduction
295296
let (lo, hi) = x.widening_mul(&Modulus256::PARAMS.one);
296297
let c = lo.concat(&hi);
297-
let red =
298-
c.rem_vartime(&NonZero::new(Modulus256::PARAMS.modulus.0.concat(&U256::ZERO)).unwrap());
298+
let red = c.rem_vartime(
299+
&NonZero::new(Modulus256::PARAMS.modulus.as_ref().concat(&U256::ZERO)).unwrap(),
300+
);
299301
let (lo, hi) = red.split();
300302
assert_eq!(hi, Uint::ZERO);
301303

src/modular/bingcd/div_mod_2k.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
1515
k_upper_bound: u32,
1616
q: &Odd<Self>,
1717
) -> Self {
18-
let one_half_mod_q = OddUint::half_mod(q).0;
18+
let one_half_mod_q = OddUint::half_mod(q);
1919

2020
// Invariant: x = self / 2^e mod q.
2121
let (mut x, mut e) = (self, 0);
@@ -30,7 +30,8 @@ impl<const LIMBS: usize> Uint<LIMBS> {
3030
let f = u32_min(k - e, f_upper_bound);
3131

3232
// Find `s` s.t. qs + x = 0 mod 2^f
33-
let (_, s) = x.limbs[0].bounded_div2k_mod_q(f, f_upper_bound, one_half_mod_q.limbs[0]);
33+
let (_, s) =
34+
x.limbs[0].bounded_div2k_mod_q(f, f_upper_bound, one_half_mod_q.as_ref().limbs[0]);
3435

3536
// Set x <- (x + qs) / 2^f
3637
x = q.mul_add_div2k(s, &x, f);
@@ -98,7 +99,7 @@ impl<const LIMBS: usize> OddUint<LIMBS> {
9899
// = (q + 1) / 2 mod q
99100
// = (q - 1) / 2 + 1 mod q
100101
// = floor(q / 2) + 1 mod q, since q is odd.
101-
Odd(q.as_ref().shr1().wrapping_add(&Uint::ONE))
102+
Odd::new_unchecked(q.as_ref().shr1().wrapping_add(&Uint::ONE))
102103
}
103104

104105
/// Compute `((self * b) + addend) / 2^k`

src/modular/bingcd/xgcd.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ impl<const LIMBS: usize> OddUint<LIMBS> {
174174
// `rhs` is even.
175175
let rhs_is_even = rhs_.is_odd().not();
176176
let (abs_diff, rhs_gt_lhs) = lhs_.abs_diff(rhs_);
177-
let odd_rhs = Odd(Uint::select(rhs_, &abs_diff, rhs_is_even));
177+
let odd_rhs = Odd::new_unchecked(Uint::select(rhs_, &abs_diff, rhs_is_even));
178178

179179
let mut output = self.binxgcd_odd(&odd_rhs);
180180
let matrix = &mut output.matrix;

src/modular/const_monty_form.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ where
236236
D: Deserializer<'de>,
237237
{
238238
Uint::<LIMBS>::deserialize(deserializer).and_then(|montgomery_form| {
239-
if montgomery_form < MOD::PARAMS.modulus.0 {
239+
if montgomery_form < *MOD::PARAMS.modulus.as_ref() {
240240
Ok(Self {
241241
montgomery_form,
242242
phantom: PhantomData,

src/modular/div_by_2.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ pub(crate) const fn div_by_2<const LIMBS: usize>(
1717
// whose Montgomery representation is `b`.
1818

1919
let is_odd = a.is_odd();
20-
let (if_odd, carry) = a.carrying_add(&modulus.0, Limb::ZERO);
20+
let (if_odd, carry) = a.carrying_add(modulus.as_ref(), Limb::ZERO);
2121
let carry = Limb::select(Limb::ZERO, carry, is_odd);
2222
Uint::<LIMBS>::select(a, &if_odd, is_odd)
2323
.shr1()

src/modular/lincomb.rs

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,14 @@ pub const fn lincomb_const_monty_form<MOD: ConstMontyParams<LIMBS>, const LIMBS:
8080
let mut ret = Uint::ZERO;
8181
let mut remain = products.len();
8282
if remain <= max_accum {
83-
let carry =
84-
impl_longa_monty_lincomb!(products, ret.limbs, modulus.0.limbs, mod_neg_inv, LIMBS);
85-
ret.try_sub_with_carry(carry, &modulus.0).0
83+
let carry = impl_longa_monty_lincomb!(
84+
products,
85+
ret.limbs,
86+
modulus.as_ref().limbs,
87+
mod_neg_inv,
88+
LIMBS
89+
);
90+
ret.try_sub_with_carry(carry, modulus.as_ref()).0
8691
} else {
8792
let mut window;
8893
while remain > 0 {
@@ -92,9 +97,14 @@ pub const fn lincomb_const_monty_form<MOD: ConstMontyParams<LIMBS>, const LIMBS:
9297
count = max_accum;
9398
}
9499
(window, products) = products.split_at(count);
95-
let carry =
96-
impl_longa_monty_lincomb!(window, buf.limbs, modulus.0.limbs, mod_neg_inv, LIMBS);
97-
buf = buf.try_sub_with_carry(carry, &modulus.0).0;
100+
let carry = impl_longa_monty_lincomb!(
101+
window,
102+
buf.limbs,
103+
modulus.as_ref().limbs,
104+
mod_neg_inv,
105+
LIMBS
106+
);
107+
buf = buf.try_sub_with_carry(carry, modulus.as_ref()).0;
98108
ret = ret.add_mod(&buf, modulus.as_nz_ref());
99109
remain -= count;
100110
}
@@ -112,9 +122,14 @@ pub const fn lincomb_monty_form<const LIMBS: usize>(
112122
let mut ret = Uint::ZERO;
113123
let mut remain = products.len();
114124
if remain <= max_accum {
115-
let carry =
116-
impl_longa_monty_lincomb!(products, ret.limbs, modulus.0.limbs, mod_neg_inv, LIMBS);
117-
ret.try_sub_with_carry(carry, &modulus.0).0
125+
let carry = impl_longa_monty_lincomb!(
126+
products,
127+
ret.limbs,
128+
modulus.as_ref().limbs,
129+
mod_neg_inv,
130+
LIMBS
131+
);
132+
ret.try_sub_with_carry(carry, modulus.as_ref()).0
118133
} else {
119134
let mut window;
120135
while remain > 0 {
@@ -124,9 +139,14 @@ pub const fn lincomb_monty_form<const LIMBS: usize>(
124139
}
125140
(window, products) = products.split_at(count);
126141
let mut buf = Uint::ZERO;
127-
let carry =
128-
impl_longa_monty_lincomb!(window, buf.limbs, modulus.0.limbs, mod_neg_inv, LIMBS);
129-
buf = buf.try_sub_with_carry(carry, &modulus.0).0;
142+
let carry = impl_longa_monty_lincomb!(
143+
window,
144+
buf.limbs,
145+
modulus.as_ref().limbs,
146+
mod_neg_inv,
147+
LIMBS
148+
);
149+
buf = buf.try_sub_with_carry(carry, modulus.as_ref()).0;
130150
ret = ret.add_mod(&buf, modulus.as_nz_ref());
131151
remain -= count;
132152
}
@@ -142,26 +162,36 @@ pub fn lincomb_boxed_monty_form(
142162
mod_leading_zeros: u32,
143163
) -> BoxedUint {
144164
let max_accum = 1 << u32_min(mod_leading_zeros, usize::BITS - 1);
145-
let nlimbs = modulus.0.nlimbs();
146-
let mut ret = BoxedUint::zero_with_precision(modulus.0.bits_precision());
165+
let nlimbs = modulus.as_ref().nlimbs();
166+
let mut ret = BoxedUint::zero_with_precision(modulus.as_ref().bits_precision());
147167
let mut remain = products.len();
148168
if remain <= max_accum {
149-
let carry =
150-
impl_longa_monty_lincomb!(products, ret.limbs, modulus.0.limbs, mod_neg_inv, nlimbs);
151-
ret.sub_assign_mod_with_carry(carry, &modulus.0, &modulus.0);
169+
let carry = impl_longa_monty_lincomb!(
170+
products,
171+
ret.limbs,
172+
modulus.as_ref().limbs,
173+
mod_neg_inv,
174+
nlimbs
175+
);
176+
ret.sub_assign_mod_with_carry(carry, modulus.as_ref(), modulus.as_ref());
152177
} else {
153178
let mut window;
154-
let mut buf = BoxedUint::zero_with_precision(modulus.0.bits_precision());
179+
let mut buf = BoxedUint::zero_with_precision(modulus.as_ref().bits_precision());
155180
while remain > 0 {
156181
buf.limbs.fill(Limb::ZERO);
157182
let mut count = remain;
158183
if count > max_accum {
159184
count = max_accum;
160185
}
161186
(window, products) = products.split_at(count);
162-
let carry =
163-
impl_longa_monty_lincomb!(window, buf.limbs, modulus.0.limbs, mod_neg_inv, nlimbs);
164-
buf.sub_assign_mod_with_carry(carry, &modulus.0, &modulus.0);
187+
let carry = impl_longa_monty_lincomb!(
188+
window,
189+
buf.limbs,
190+
modulus.as_ref().limbs,
191+
mod_neg_inv,
192+
nlimbs
193+
);
194+
buf.sub_assign_mod_with_carry(carry, modulus.as_ref(), modulus.as_ref());
165195
ret.add_mod_assign(&buf, modulus.as_nz_ref());
166196
remain -= count;
167197
}

0 commit comments

Comments
 (0)