Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 44 additions & 143 deletions src/main/java/com/williamfiset/algorithms/math/ModPow.java
Original file line number Diff line number Diff line change
@@ -1,172 +1,73 @@
/**
* NOTE: An issue was found with this file when dealing with negative numbers when exponentiating!
* See bug tracking progress on issue
* Computes modular exponentiation: a^n mod m.
*
* <p>An implementation of the modPow(a, n, mod) operation. This implementation is substantially
* faster than Java's BigInteger class because it only uses primitive types.
* Supports negative exponents via modular inverse (requires gcd(a, m) = 1) and negative bases.
* Uses overflow-safe modular multiplication to handle the full range of long values.
*
* <p>Time Complexity O(lg(n))
* Time Complexity: O(log(n))
*
* @author William Fiset, william.alexandre.fiset@gmail.com
*/
package com.williamfiset.algorithms.math;

import java.math.BigInteger;

public class ModPow {

// The values placed into the modPow function cannot be greater
// than MAX or less than MIN otherwise long overflow will
// happen when the values get squared (they will exceed 2^63-1)
private static final long MAX = (long) Math.sqrt(Long.MAX_VALUE);
private static final long MIN = -MAX;

// Computes the Greatest Common Divisor (GCD) of a & b
private static long gcd(long a, long b) {
return b == 0 ? (a < 0 ? -a : a) : gcd(b, a % b);
}

// This function performs the extended euclidean algorithm on two numbers a and b.
// The function returns the gcd(a,b) as well as the numbers x and y such
// that ax + by = gcd(a,b). This calculation is important in number theory
// and can be used for several things such as finding modular inverses and
// solutions to linear Diophantine equations.
private static long[] egcd(long a, long b) {
if (b == 0) return new long[] {a < 0 ? -a : a, 1L, 0L};
long[] v = egcd(b, a % b);
long tmp = v[1] - v[2] * (a / b);
v[1] = v[2];
v[2] = tmp;
return v;
}

// Returns the modular inverse of 'a' mod 'm'
// Make sure m > 0 and 'a' & 'm' are relatively prime.
private static long modInv(long a, long m) {

a = ((a % m) + m) % m;

long[] v = egcd(a, m);
long x = v[1];

return ((x % m) + m) % m;
}

// Computes a^n modulo mod very efficiently in O(lg(n)) time.
// This function supports negative exponent values and a negative
// base, however the modulus must be positive.
/**
* Computes a^n mod m.
*
* @throws ArithmeticException if mod <= 0, or if n < 0 and gcd(a, mod) != 1.
*/
public static long modPow(long a, long n, long mod) {
if (mod <= 0)
throw new ArithmeticException("mod must be > 0");

if (mod <= 0) throw new ArithmeticException("mod must be > 0");
if (a > MAX || mod > MAX)
throw new IllegalArgumentException("Long overflow is upon you, mod or base is too high!");
if (a < MIN || mod < MIN)
throw new IllegalArgumentException("Long overflow is upon you, mod or base is too low!");

// To handle negative exponents we can use the modular
// inverse of a to our advantage since: a^-n mod m = (a^-1)^n mod m
// a^-n mod m = (a^-1)^n mod m
if (n < 0) {
if (gcd(a, mod) != 1)
throw new ArithmeticException("If n < 0 then must have gcd(a, mod) = 1");
return modPow(modInv(a, mod), -n, mod);
}

if (n == 0L) return 1L;
long p = a, r = 1L;
// Normalize base into [0, mod)
a = ((a % mod) + mod) % mod;

for (long i = 0; n != 0; i++) {
long mask = 1L << i;
if ((n & mask) == mask) {
r = (((r * p) % mod) + mod) % mod;
n -= mask;
}
p = ((p * p) % mod + mod) % mod;
long result = 1;
while (n > 0) {
if ((n & 1) == 1)
result = mulMod(result, a, mod);
a = mulMod(a, a, mod);
n >>= 1;
}

return ((r % mod) + mod) % mod;
return result;
}

// Example usage
public static void main(String[] args) {

BigInteger A, N, M, r1;
long a, n, m, r2;

A = BigInteger.valueOf(3);
N = BigInteger.valueOf(4);
M = BigInteger.valueOf(1000000);
a = A.longValue();
n = N.longValue();
m = M.longValue();

// 3 ^ 4 mod 1000000
r1 = A.modPow(N, M); // 81
r2 = modPow(a, n, m); // 81
System.out.println(r1 + " " + r2);

A = BigInteger.valueOf(-45);
N = BigInteger.valueOf(12345);
M = BigInteger.valueOf(987654321);
a = A.longValue();
n = N.longValue();
m = M.longValue();

// Finds -45 ^ 12345 mod 987654321
r1 = A.modPow(N, M); // 323182557
r2 = modPow(a, n, m); // 323182557
System.out.println(r1 + " " + r2);

A = BigInteger.valueOf(6);
N = BigInteger.valueOf(-66);
M = BigInteger.valueOf(101);
a = A.longValue();
n = N.longValue();
m = M.longValue();

// Finds 6 ^ -66 mod 101
r1 = A.modPow(N, M); // 84
r2 = modPow(a, n, m); // 84
System.out.println(r1 + " " + r2);

A = BigInteger.valueOf(-5);
N = BigInteger.valueOf(-7);
M = BigInteger.valueOf(1009);
a = A.longValue();
n = N.longValue();
m = M.longValue();

// Finds -5 ^ -7 mod 1009
r1 = A.modPow(N, M); // 675
r2 = modPow(a, n, m); // 675
System.out.println(r1 + " " + r2);

for (int i = 0; i < 1000; i++) {
A = BigInteger.valueOf(a);
N = BigInteger.valueOf(n);
M = BigInteger.valueOf(m);
a = Math.random() < 0.5 ? randLong(MAX) : -randLong(MAX);
n = randLong();
m = randLong(MAX);
try {
r1 = A.modPow(N, M);
r2 = modPow(a, n, m);
if (r1.longValue() != r2)
System.out.printf("Broke with: a = %d, n = %d, m = %d\n", a, n, m);
} catch (ArithmeticException e) {
}
}
private static long modInv(long a, long m) {
a = ((a % m) + m) % m;
long x = egcd(a, m)[1];
return ((x % m) + m) % m;
}

/* TESTING RELATED METHODS */

static final java.util.Random RANDOM = new java.util.Random();
private static long[] egcd(long a, long b) {
if (b == 0)
return new long[] {a < 0 ? -a : a, 1L, 0L};
long[] v = egcd(b, a % b);
long tmp = v[1] - v[2] * (a / b);
v[1] = v[2];
v[2] = tmp;
return v;
}

// Returns long between [1, bound]
public static long randLong(long bound) {
return java.util.concurrent.ThreadLocalRandom.current().nextLong(1, bound + 1);
private static long gcd(long a, long b) {
a = Math.abs(a);
b = Math.abs(b);
return b == 0 ? a : gcd(b, a % b);
}

public static long randLong() {
return RANDOM.nextLong();
/** Overflow-safe modular multiplication: (a * b) % mod. */
private static long mulMod(long a, long b, long mod) {
return java.math.BigInteger.valueOf(a)
.multiply(java.math.BigInteger.valueOf(b))
.mod(java.math.BigInteger.valueOf(mod))
.longValue();
}
}
11 changes: 11 additions & 0 deletions src/test/java/com/williamfiset/algorithms/math/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,14 @@ java_test(
runtime_deps = JUNIT5_RUNTIME_DEPS,
deps = TEST_DEPS,
)

# bazel test //src/test/java/com/williamfiset/algorithms/math:ModPowTest
java_test(
name = "ModPowTest",
srcs = ["ModPowTest.java"],
main_class = "org.junit.platform.console.ConsoleLauncher",
use_testrunner = False,
args = ["--select-class=com.williamfiset.algorithms.math.ModPowTest"],
runtime_deps = JUNIT5_RUNTIME_DEPS,
deps = TEST_DEPS,
)
100 changes: 100 additions & 0 deletions src/test/java/com/williamfiset/algorithms/math/ModPowTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package com.williamfiset.algorithms.math;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;

import java.math.BigInteger;
import java.util.concurrent.ThreadLocalRandom;
import org.junit.jupiter.api.*;

public class ModPowTest {

@Test
public void basicPositiveExponent() {
// 3^4 mod 1000000 = 81
assertThat(ModPow.modPow(3, 4, 1000000)).isEqualTo(81);
}

@Test
public void negativeBase() {
// (-45)^12345 mod 987654321
long expected =
BigInteger.valueOf(-45)
.modPow(BigInteger.valueOf(12345), BigInteger.valueOf(987654321))
.longValue();
assertThat(ModPow.modPow(-45, 12345, 987654321)).isEqualTo(expected);
}

@Test
public void negativeExponent() {
// 6^-66 mod 101 = 84
long expected =
BigInteger.valueOf(6)
.modPow(BigInteger.valueOf(-66), BigInteger.valueOf(101))
.longValue();
assertThat(ModPow.modPow(6, -66, 101)).isEqualTo(expected);
}

@Test
public void negativeBaseAndExponent() {
// (-5)^-7 mod 1009
long expected =
BigInteger.valueOf(-5)
.modPow(BigInteger.valueOf(-7), BigInteger.valueOf(1009))
.longValue();
assertThat(ModPow.modPow(-5, -7, 1009)).isEqualTo(expected);
}

@Test
public void exponentZero() {
assertThat(ModPow.modPow(123, 0, 7)).isEqualTo(1);
assertThat(ModPow.modPow(0, 0, 5)).isEqualTo(1);
}

@Test
public void baseZero() {
assertThat(ModPow.modPow(0, 10, 7)).isEqualTo(0);
}

@Test
public void modOne() {
// Anything mod 1 = 0
assertThat(ModPow.modPow(999, 999, 1)).isEqualTo(0);
}

@Test
public void largeValues() {
// Test with values that would overflow without safe multiplication
long a = 1_000_000_000L;
long n = 1_000_000_000L;
long mod = 999_999_937L;
long expected =
BigInteger.valueOf(a).modPow(BigInteger.valueOf(n), BigInteger.valueOf(mod)).longValue();
assertThat(ModPow.modPow(a, n, mod)).isEqualTo(expected);
}

@Test
public void modNonPositiveThrows() {
assertThrows(ArithmeticException.class, () -> ModPow.modPow(2, 3, 0));
assertThrows(ArithmeticException.class, () -> ModPow.modPow(2, 3, -5));
}

@Test
public void negativeExponentNotCoprime() {
// gcd(4, 8) = 4 ≠ 1, so no modular inverse
assertThrows(ArithmeticException.class, () -> ModPow.modPow(4, -1, 8));
}

@Test
public void matchesBigIntegerRandomized() {
ThreadLocalRandom rng = ThreadLocalRandom.current();
for (int i = 0; i < 500; i++) {
long a = rng.nextLong(-1_000_000_000L, 1_000_000_000L);
long n = rng.nextLong(0, 1_000_000_000L);
long mod = rng.nextLong(1, 1_000_000_000L);
long expected =
BigInteger.valueOf(a).modPow(BigInteger.valueOf(n), BigInteger.valueOf(mod)).longValue();
assertThat(ModPow.modPow(a, n, mod)).isEqualTo(expected);
}
}
}
Loading