33// SPDX-License-Identifier: Apache-2.0
44
55#include " modexp.hpp"
6+ #include " mulmod.hpp"
67#include < evmmax/evmmax.hpp>
78#include < bit>
89#include < memory_resource>
@@ -27,53 +28,6 @@ constexpr bool add(std::span<uint64_t> x, std::span<const uint64_t> y) noexcept
2728 return carry;
2829}
2930
30- // / Subtracts y from x: x[] -= y[]. The result is truncated to the size of x.
31- constexpr void sub (std::span<uint64_t > x, std::span<const uint64_t > y) noexcept
32- {
33- assert (x.size () >= y.size ());
34-
35- bool borrow = false ;
36- for (size_t i = 0 ; i < y.size (); ++i)
37- std::tie (x[i], borrow) = subc (x[i], y[i], borrow);
38- for (size_t i = y.size (); borrow && i < x.size (); ++i)
39- std::tie (x[i], borrow) = subc (x[i], uint64_t {0 }, borrow);
40- }
41-
42- // / Multiplies multi-word x by single word y: r[] = x[] * y. Returns the carry word.
43- constexpr uint64_t mul (std::span<uint64_t > r, std::span<const uint64_t > x, uint64_t y) noexcept
44- {
45- assert (r.size () == x.size ());
46-
47- uint64_t c = 0 ;
48- #pragma GCC unroll 4
49- for (size_t i = 0 ; i != x.size (); ++i)
50- {
51- const auto p = umul (x[i], y) + c;
52- r[i] = p[0 ];
53- c = p[1 ];
54- }
55- return c;
56- }
57-
58- // / Multiplies each word of x by y and adds the matching word of p, propagating a carry to the next
59- // / word. Starts with initial carry c. Stores the result in r. Returns the final carry.
60- // / r[] = p[] + x[] * y (+ c).
61- constexpr uint64_t addmul (std::span<uint64_t > r, std::span<const uint64_t > p,
62- std::span<const uint64_t > x, uint64_t y, uint64_t c = 0 ) noexcept
63- {
64- assert (r.size () == p.size ());
65- assert (r.size () == x.size ());
66-
67- #pragma GCC unroll 4
68- for (size_t i = 0 ; i != x.size (); ++i)
69- {
70- const auto t = umul (x[i], y) + p[i] + c;
71- r[i] = t[0 ];
72- c = t[1 ];
73- }
74- return c;
75- }
76-
7731// / Computes multiplication of x times y and truncates the result to the size of r:
7832// / r[] = x[] * y[].
7933constexpr void mul (
@@ -89,7 +43,7 @@ constexpr void mul(
8943 std::swap (x, y);
9044
9145 // First iteration: use mul (not addmul) since r is uninitialized.
92- const auto hi0 = mul (r.first (x.size ()), x, y[0 ]);
46+ const auto hi0 = crypto:: mul (r.first (x.size ()), x, y[0 ]);
9347 if (r.size () > x.size ())
9448 r[x.size ()] = hi0;
9549
@@ -349,8 +303,9 @@ class Exponent
349303// /
350304// / Computes r = x * y * R^-1 mod m (Almost Montgomery Multiplication).
351305// / r must not alias x or y.
352- void mul_amm (std::span<uint64_t > r, std::span<const uint64_t > x, std::span<const uint64_t > y,
353- std::span<const uint64_t > mod, uint64_t mod_inv) noexcept
306+ template <size_t N = std::dynamic_extent>
307+ void mul_amm (std::span<uint64_t , N> r, std::span<const uint64_t , N> x,
308+ std::span<const uint64_t , N> y, std::span<const uint64_t , N> mod, uint64_t mod_inv) noexcept
354309{
355310 // Use Coarsely Integrated Operand Scanning (CIOS) method with the "almost" reduction.
356311 const auto n = r.size ();
@@ -368,7 +323,7 @@ void mul_amm(std::span<uint64_t> r, std::span<const uint64_t> x, std::span<const
368323 // First iteration: r is uninitialized, so use mul instead of addmul.
369324 bool r_carry = false ;
370325 {
371- const auto c1 = mul (r, x, y[0 ]);
326+ const auto c1 = crypto:: mul (r, x, y[0 ]);
372327
373328 const auto m = r[0 ] * mod_inv;
374329 const auto c2 = (umul (mod[0 ], m) + r[0 ])[1 ];
@@ -397,6 +352,15 @@ void mul_amm(std::span<uint64_t> r, std::span<const uint64_t> x, std::span<const
397352 sub (r, mod);
398353}
399354
355+ // / Almost Montgomery Multiplication specialized for 4-word (256-bit) operands.
356+ // / Delegates to mul_amm_256 in mulmod.cpp.
357+ template <>
358+ [[gnu::always_inline]] void mul_amm<4 >(std::span<uint64_t , 4 > r, std::span<const uint64_t , 4 > x,
359+ std::span<const uint64_t , 4 > y, std::span<const uint64_t , 4 > mod, uint64_t mod_inv) noexcept
360+ {
361+ mul_amm_256 (r, x, y, mod, mod_inv);
362+ }
363+
400364// / Computes result[] = base[]^exp % mod[] for odd mod[] (mod[0] % 2 != 0).
401365// / Scratch space required: 4n + 3*base.size() + 2 words, where n = mod.size().
402366void modexp_odd (std::span<uint64_t > result, std::span<const uint64_t > base, Exponent exp,
@@ -424,34 +388,43 @@ void modexp_odd(std::span<uint64_t> result, std::span<const uint64_t> base, Expo
424388 std::ranges::copy (base, u.subspan (n).begin ());
425389 rem (base_mont, u, mod, rem_scratch);
426390
427- // Double-buffer: r1 always holds the current value, r2 is scratch for mul_amm output.
428- auto r_cur = result;
429- auto r_tmp = u.first (n); // Reuse u scratch space.
430- std::ranges::copy (base_mont, r_cur.begin ());
391+ // Double-buffer exponentiation loop, parameterized by mul_amm size.
392+ const auto exp_loop = [&]<size_t N>() {
393+ auto r_cur = std::span<uint64_t , N>{result};
394+ auto r_tmp = std::span<uint64_t , N>{u.first (n)};
395+ const auto bm = std::span<const uint64_t , N>{base_mont};
396+ const auto m = std::span<const uint64_t , N>{mod};
431397
432- for (auto i = exp.bit_width () - 1 ; i != 0 ; --i)
433- {
434- mul_amm (r_tmp, r_cur, r_cur, mod, mod_inv); // Square: r2 = r1 * r1.
435- if (exp[i - 1 ])
436- mul_amm (r_cur, r_tmp, base_mont, mod, mod_inv); // Multiply: r1 = r2 * base_mont.
437- else
438- std::swap (r_cur, r_tmp); // No multiply: adopt r2 as r1.
439- }
398+ std::ranges::copy (bm, r_cur.begin ());
399+ for (auto i = exp.bit_width () - 1 ; i != 0 ; --i)
400+ {
401+ mul_amm<N>(r_tmp, r_cur, r_cur, m, mod_inv); // Square.
402+ if (exp[i - 1 ])
403+ mul_amm<N>(r_cur, r_tmp, bm, m, mod_inv); // Multiply.
404+ else
405+ std::swap (r_cur, r_tmp);
406+ }
440407
441- // Convert from Montgomery form: multiply by 1.
442- std::ranges::fill (base_mont, uint64_t {0 });
443- base_mont[0 ] = 1 ;
444- mul_amm (r_tmp, r_cur, base_mont, mod , mod_inv);
445- std::swap (r_cur, r_tmp);
408+ // Convert from Montgomery form: multiply by 1.
409+ std::ranges::fill (base_mont, uint64_t {0 });
410+ base_mont[0 ] = 1 ;
411+ mul_amm<N> (r_tmp, r_cur, std::span< const uint64_t , N>{ base_mont}, m , mod_inv);
412+ std::swap (r_cur, r_tmp);
446413
447- // Reduce if necessary: AMM can produce mod <= r < 2*mod.
448- if (!less (r_cur, mod))
449- sub (r_cur, mod);
450- assert (less (r_cur, mod));
414+ // If the result ended up in scratch, copy to result.
415+ if (r_cur.data () != result.data ())
416+ std::ranges::copy (r_cur, result.begin ());
417+ };
418+
419+ if (n == 4 )
420+ exp_loop.operator ()<4 >();
421+ else
422+ exp_loop.operator ()<std::dynamic_extent>();
451423
452- // If the result ended up in the scratch buffer, copy to result.
453- if (r_cur.data () != result.data ())
454- std::ranges::copy (r_cur, result.begin ());
424+ // Reduce if necessary: AMM can produce mod <= r < 2*mod.
425+ if (!less (result, mod))
426+ sub (result, mod);
427+ assert (less (result, mod));
455428}
456429
457430// / Trims the multi-word number x[] to k bits.
0 commit comments