diff --git a/src/main/java/com/williamfiset/algorithms/math/ModPow.java b/src/main/java/com/williamfiset/algorithms/math/ModPow.java index 94994396d..72163afe4 100644 --- a/src/main/java/com/williamfiset/algorithms/math/ModPow.java +++ b/src/main/java/com/williamfiset/algorithms/math/ModPow.java @@ -1,172 +1,73 @@ /** - * NOTE: An issue was found with this file when dealing with negative numbers when exponentiating! - * See bug tracking progress on issue + * Computes modular exponentiation: a^n mod m. * - *
An implementation of the modPow(a, n, mod) operation. This implementation is substantially - * faster than Java's BigInteger class because it only uses primitive types. + * Supports negative exponents via modular inverse (requires gcd(a, m) = 1) and negative bases. + * Uses overflow-safe modular multiplication to handle the full range of long values. * - *
Time Complexity O(lg(n)) + * Time Complexity: O(log(n)) * * @author William Fiset, william.alexandre.fiset@gmail.com */ package com.williamfiset.algorithms.math; -import java.math.BigInteger; - public class ModPow { - // The values placed into the modPow function cannot be greater - // than MAX or less than MIN otherwise long overflow will - // happen when the values get squared (they will exceed 2^63-1) - private static final long MAX = (long) Math.sqrt(Long.MAX_VALUE); - private static final long MIN = -MAX; - - // Computes the Greatest Common Divisor (GCD) of a & b - private static long gcd(long a, long b) { - return b == 0 ? (a < 0 ? -a : a) : gcd(b, a % b); - } - - // This function performs the extended euclidean algorithm on two numbers a and b. - // The function returns the gcd(a,b) as well as the numbers x and y such - // that ax + by = gcd(a,b). This calculation is important in number theory - // and can be used for several things such as finding modular inverses and - // solutions to linear Diophantine equations. - private static long[] egcd(long a, long b) { - if (b == 0) return new long[] {a < 0 ? -a : a, 1L, 0L}; - long[] v = egcd(b, a % b); - long tmp = v[1] - v[2] * (a / b); - v[1] = v[2]; - v[2] = tmp; - return v; - } - - // Returns the modular inverse of 'a' mod 'm' - // Make sure m > 0 and 'a' & 'm' are relatively prime. - private static long modInv(long a, long m) { - - a = ((a % m) + m) % m; - - long[] v = egcd(a, m); - long x = v[1]; - - return ((x % m) + m) % m; - } - - // Computes a^n modulo mod very efficiently in O(lg(n)) time. - // This function supports negative exponent values and a negative - // base, however the modulus must be positive. + /** + * Computes a^n mod m. + * + * @throws ArithmeticException if mod <= 0, or if n < 0 and gcd(a, mod) != 1. + */ public static long modPow(long a, long n, long mod) { + if (mod <= 0) + throw new ArithmeticException("mod must be > 0"); - if (mod <= 0) throw new ArithmeticException("mod must be > 0"); - if (a > MAX || mod > MAX) - throw new IllegalArgumentException("Long overflow is upon you, mod or base is too high!"); - if (a < MIN || mod < MIN) - throw new IllegalArgumentException("Long overflow is upon you, mod or base is too low!"); - - // To handle negative exponents we can use the modular - // inverse of a to our advantage since: a^-n mod m = (a^-1)^n mod m + // a^-n mod m = (a^-1)^n mod m if (n < 0) { if (gcd(a, mod) != 1) throw new ArithmeticException("If n < 0 then must have gcd(a, mod) = 1"); return modPow(modInv(a, mod), -n, mod); } - if (n == 0L) return 1L; - long p = a, r = 1L; + // Normalize base into [0, mod) + a = ((a % mod) + mod) % mod; - for (long i = 0; n != 0; i++) { - long mask = 1L << i; - if ((n & mask) == mask) { - r = (((r * p) % mod) + mod) % mod; - n -= mask; - } - p = ((p * p) % mod + mod) % mod; + long result = 1; + while (n > 0) { + if ((n & 1) == 1) + result = mulMod(result, a, mod); + a = mulMod(a, a, mod); + n >>= 1; } - - return ((r % mod) + mod) % mod; + return result; } - // Example usage - public static void main(String[] args) { - - BigInteger A, N, M, r1; - long a, n, m, r2; - - A = BigInteger.valueOf(3); - N = BigInteger.valueOf(4); - M = BigInteger.valueOf(1000000); - a = A.longValue(); - n = N.longValue(); - m = M.longValue(); - - // 3 ^ 4 mod 1000000 - r1 = A.modPow(N, M); // 81 - r2 = modPow(a, n, m); // 81 - System.out.println(r1 + " " + r2); - - A = BigInteger.valueOf(-45); - N = BigInteger.valueOf(12345); - M = BigInteger.valueOf(987654321); - a = A.longValue(); - n = N.longValue(); - m = M.longValue(); - - // Finds -45 ^ 12345 mod 987654321 - r1 = A.modPow(N, M); // 323182557 - r2 = modPow(a, n, m); // 323182557 - System.out.println(r1 + " " + r2); - - A = BigInteger.valueOf(6); - N = BigInteger.valueOf(-66); - M = BigInteger.valueOf(101); - a = A.longValue(); - n = N.longValue(); - m = M.longValue(); - - // Finds 6 ^ -66 mod 101 - r1 = A.modPow(N, M); // 84 - r2 = modPow(a, n, m); // 84 - System.out.println(r1 + " " + r2); - - A = BigInteger.valueOf(-5); - N = BigInteger.valueOf(-7); - M = BigInteger.valueOf(1009); - a = A.longValue(); - n = N.longValue(); - m = M.longValue(); - - // Finds -5 ^ -7 mod 1009 - r1 = A.modPow(N, M); // 675 - r2 = modPow(a, n, m); // 675 - System.out.println(r1 + " " + r2); - - for (int i = 0; i < 1000; i++) { - A = BigInteger.valueOf(a); - N = BigInteger.valueOf(n); - M = BigInteger.valueOf(m); - a = Math.random() < 0.5 ? randLong(MAX) : -randLong(MAX); - n = randLong(); - m = randLong(MAX); - try { - r1 = A.modPow(N, M); - r2 = modPow(a, n, m); - if (r1.longValue() != r2) - System.out.printf("Broke with: a = %d, n = %d, m = %d\n", a, n, m); - } catch (ArithmeticException e) { - } - } + private static long modInv(long a, long m) { + a = ((a % m) + m) % m; + long x = egcd(a, m)[1]; + return ((x % m) + m) % m; } - /* TESTING RELATED METHODS */ - - static final java.util.Random RANDOM = new java.util.Random(); + private static long[] egcd(long a, long b) { + if (b == 0) + return new long[] {a < 0 ? -a : a, 1L, 0L}; + long[] v = egcd(b, a % b); + long tmp = v[1] - v[2] * (a / b); + v[1] = v[2]; + v[2] = tmp; + return v; + } - // Returns long between [1, bound] - public static long randLong(long bound) { - return java.util.concurrent.ThreadLocalRandom.current().nextLong(1, bound + 1); + private static long gcd(long a, long b) { + a = Math.abs(a); + b = Math.abs(b); + return b == 0 ? a : gcd(b, a % b); } - public static long randLong() { - return RANDOM.nextLong(); + /** Overflow-safe modular multiplication: (a * b) % mod. */ + private static long mulMod(long a, long b, long mod) { + return java.math.BigInteger.valueOf(a) + .multiply(java.math.BigInteger.valueOf(b)) + .mod(java.math.BigInteger.valueOf(mod)) + .longValue(); } } diff --git a/src/test/java/com/williamfiset/algorithms/math/BUILD b/src/test/java/com/williamfiset/algorithms/math/BUILD index 8a33dd8de..0842a265d 100644 --- a/src/test/java/com/williamfiset/algorithms/math/BUILD +++ b/src/test/java/com/williamfiset/algorithms/math/BUILD @@ -58,3 +58,14 @@ java_test( runtime_deps = JUNIT5_RUNTIME_DEPS, deps = TEST_DEPS, ) + +# bazel test //src/test/java/com/williamfiset/algorithms/math:ModPowTest +java_test( + name = "ModPowTest", + srcs = ["ModPowTest.java"], + main_class = "org.junit.platform.console.ConsoleLauncher", + use_testrunner = False, + args = ["--select-class=com.williamfiset.algorithms.math.ModPowTest"], + runtime_deps = JUNIT5_RUNTIME_DEPS, + deps = TEST_DEPS, +) diff --git a/src/test/java/com/williamfiset/algorithms/math/ModPowTest.java b/src/test/java/com/williamfiset/algorithms/math/ModPowTest.java new file mode 100644 index 000000000..d69e9ad3d --- /dev/null +++ b/src/test/java/com/williamfiset/algorithms/math/ModPowTest.java @@ -0,0 +1,100 @@ +package com.williamfiset.algorithms.math; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.math.BigInteger; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.jupiter.api.*; + +public class ModPowTest { + + @Test + public void basicPositiveExponent() { + // 3^4 mod 1000000 = 81 + assertThat(ModPow.modPow(3, 4, 1000000)).isEqualTo(81); + } + + @Test + public void negativeBase() { + // (-45)^12345 mod 987654321 + long expected = + BigInteger.valueOf(-45) + .modPow(BigInteger.valueOf(12345), BigInteger.valueOf(987654321)) + .longValue(); + assertThat(ModPow.modPow(-45, 12345, 987654321)).isEqualTo(expected); + } + + @Test + public void negativeExponent() { + // 6^-66 mod 101 = 84 + long expected = + BigInteger.valueOf(6) + .modPow(BigInteger.valueOf(-66), BigInteger.valueOf(101)) + .longValue(); + assertThat(ModPow.modPow(6, -66, 101)).isEqualTo(expected); + } + + @Test + public void negativeBaseAndExponent() { + // (-5)^-7 mod 1009 + long expected = + BigInteger.valueOf(-5) + .modPow(BigInteger.valueOf(-7), BigInteger.valueOf(1009)) + .longValue(); + assertThat(ModPow.modPow(-5, -7, 1009)).isEqualTo(expected); + } + + @Test + public void exponentZero() { + assertThat(ModPow.modPow(123, 0, 7)).isEqualTo(1); + assertThat(ModPow.modPow(0, 0, 5)).isEqualTo(1); + } + + @Test + public void baseZero() { + assertThat(ModPow.modPow(0, 10, 7)).isEqualTo(0); + } + + @Test + public void modOne() { + // Anything mod 1 = 0 + assertThat(ModPow.modPow(999, 999, 1)).isEqualTo(0); + } + + @Test + public void largeValues() { + // Test with values that would overflow without safe multiplication + long a = 1_000_000_000L; + long n = 1_000_000_000L; + long mod = 999_999_937L; + long expected = + BigInteger.valueOf(a).modPow(BigInteger.valueOf(n), BigInteger.valueOf(mod)).longValue(); + assertThat(ModPow.modPow(a, n, mod)).isEqualTo(expected); + } + + @Test + public void modNonPositiveThrows() { + assertThrows(ArithmeticException.class, () -> ModPow.modPow(2, 3, 0)); + assertThrows(ArithmeticException.class, () -> ModPow.modPow(2, 3, -5)); + } + + @Test + public void negativeExponentNotCoprime() { + // gcd(4, 8) = 4 ≠ 1, so no modular inverse + assertThrows(ArithmeticException.class, () -> ModPow.modPow(4, -1, 8)); + } + + @Test + public void matchesBigIntegerRandomized() { + ThreadLocalRandom rng = ThreadLocalRandom.current(); + for (int i = 0; i < 500; i++) { + long a = rng.nextLong(-1_000_000_000L, 1_000_000_000L); + long n = rng.nextLong(0, 1_000_000_000L); + long mod = rng.nextLong(1, 1_000_000_000L); + long expected = + BigInteger.valueOf(a).modPow(BigInteger.valueOf(n), BigInteger.valueOf(mod)).longValue(); + assertThat(ModPow.modPow(a, n, mod)).isEqualTo(expected); + } + } +}