diff --git a/README.md b/README.md
index fa11a9c3f..0a054c079 100644
--- a/README.md
+++ b/README.md
@@ -239,7 +239,7 @@ $ java -cp classes com.williamfiset.algorithms.search.BinarySearch
# Mathematics
-- [[UNTESTED] Chinese remainder theorem](src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java)
+- [Chinese remainder theorem](src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java)
- [Prime number sieve (sieve of Eratosthenes)](src/main/java/com/williamfiset/algorithms/math/SieveOfEratosthenes.java) **- O(nlog(log(n)))**
- [Prime number sieve (sieve of Eratosthenes, compressed)](src/main/java/com/williamfiset/algorithms/math/CompressedPrimeSieve.java) **- O(nlog(log(n)))**
- [Totient function (phi function, relatively prime number count)](src/main/java/com/williamfiset/algorithms/math/EulerTotientFunction.java) **- O(√n)**
diff --git a/src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java b/src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java
index a00d29538..a8e237386 100644
--- a/src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java
+++ b/src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java
@@ -1,20 +1,29 @@
/**
- * Use the chinese remainder theorem to solve a set of congruence equations.
+ * Solve a set of congruence equations using the Chinese Remainder Theorem (CRT).
*
- *
The first method (eliminateCoefficient) is used to reduce an equation of the form cx≡a(mod
- * m)cx≡a(mod m) to the form x≡a_new(mod m_new)x≡anew(mod m_new), which gets rids of the
- * coefficient. A value of null is returned if the coefficient cannot be eliminated.
+ * Given a system of simultaneous congruences:
*
- *
The second method (reduce) is used to reduce a set of equations so that the moduli become
- * pairwise co-prime (which means that we can apply the Chinese Remainder Theorem). The input and
- * output are of the form x≡a_0(mod m_0),...,x≡a_n−1(mod m_n−1)x≡a_0(mod m_0),...,x≡a_n−1(mod
- * m_n−1). Note that the number of equations may change during this process. A value of null is
- * returned if the set of equations cannot be reduced to co-prime moduli.
+ * x ≡ a_0 (mod m_0)
+ * x ≡ a_1 (mod m_1)
+ * ...
+ * x ≡ a_{n-1} (mod m_{n-1})
*
- *
The third method (crt) is the actual Chinese Remainder Theorem. It assumes that all pairs of
- * moduli are co-prime to one another. This solves a set of equations of the form x≡a_0(mod
- * m_0),...,x≡v_n−1(mod m_n−1)x≡a_0(mod m_0),...,x≡v_n−1(mod m_n−1). It's output is of the form
- * x≡a_new(mod m_new)x≡a_new(mod m_new).
+ * where all moduli m_i are pairwise coprime (gcd(m_i, m_j) = 1 for i ≠ j), the CRT guarantees a
+ * unique solution x modulo M = m_0 * m_1 * ... * m_{n-1}.
+ *
+ * The solution is constructed as x = sum of a_i * M_i * y_i (mod M), where M_i = M / m_i and y_i
+ * is the modular inverse of M_i modulo m_i (found via the extended Euclidean algorithm). Each term
+ * contributes a_i for the i-th congruence and vanishes (mod m_j) for all j ≠ i, so the sum
+ * satisfies every equation simultaneously.
+ *
+ * When moduli are not pairwise coprime, the system must first be reduced. Each modulus is split
+ * into prime-power factors (e.g. 12 = 4 * 3), converting one equation into several with
+ * prime-power moduli. Redundant equations are removed and conflicting ones detected. After
+ * reduction, the moduli are pairwise coprime and the standard CRT applies.
+ *
+ * The eliminateCoefficient method handles equations of the form cx ≡ a (mod m) by dividing through
+ * by gcd(c, m) — which is only possible when gcd(c, m) divides a — and then multiplying by the
+ * modular inverse of the reduced coefficient.
*
* @author Micah Stairs
*/
@@ -24,12 +33,16 @@
public class ChineseRemainderTheorem {
- // eliminateCoefficient() takes cx≡a(mod m) and gives x≡a_new(mod m_new).
+ /**
+ * Reduces cx ≡ a (mod m) to x ≡ a' (mod m').
+ *
+ * @return {a', m'} or null if unsolvable.
+ */
public static long[] eliminateCoefficient(long c, long a, long m) {
-
long d = egcd(c, m)[0];
- if (a % d != 0) return null;
+ if (a % d != 0)
+ return null;
c /= d;
a /= d;
@@ -42,36 +55,35 @@ public static long[] eliminateCoefficient(long c, long a, long m) {
return new long[] {a, m};
}
- // reduce() takes a set of equations and reduces them to an equivalent
- // set with pairwise co-prime moduli (or null if not solvable).
+ /**
+ * Reduces a system x ≡ a[i] (mod m[i]) to an equivalent system with pairwise coprime moduli.
+ *
+ * @return {a[], m[]} with coprime moduli, or null if the system is inconsistent.
+ */
public static long[][] reduce(long[] a, long[] m) {
+ List aNew = new ArrayList<>();
+ List mNew = new ArrayList<>();
- List aNew = new ArrayList();
- List mNew = new ArrayList();
-
- // Split up each equation into prime factors
+ // Split each modulus into prime-power factors
for (int i = 0; i < a.length; i++) {
List factors = primeFactorization(m[i]);
Collections.sort(factors);
- ListIterator iterator = factors.listIterator();
- while (iterator.hasNext()) {
- long val = iterator.next();
- long total = val;
- while (iterator.hasNext()) {
- long nextVal = iterator.next();
- if (nextVal == val) {
- total *= val;
- } else {
- iterator.previous();
- break;
- }
+
+ int j = 0;
+ while (j < factors.size()) {
+ long p = factors.get(j);
+ long pk = p;
+ j++;
+ while (j < factors.size() && factors.get(j) == p) {
+ pk *= p;
+ j++;
}
- aNew.add(a[i] % total);
- mNew.add(total);
+ aNew.add(a[i] % pk);
+ mNew.add(pk);
}
}
- // Throw away repeated information and look for conflicts
+ // Remove redundant equations and detect conflicts
for (int i = 0; i < aNew.size(); i++) {
for (int j = i + 1; j < aNew.size(); j++) {
if (mNew.get(i) % mNew.get(j) == 0 || mNew.get(j) % mNew.get(i) == 0) {
@@ -81,111 +93,161 @@ public static long[][] reduce(long[] a, long[] m) {
mNew.remove(j);
j--;
continue;
- } else return null;
+ } else
+ return null;
} else {
if ((aNew.get(j) % mNew.get(i)) == aNew.get(i)) {
aNew.remove(i);
mNew.remove(i);
i--;
break;
- } else return null;
+ } else
+ return null;
}
}
}
}
- // Put result into an array
long[][] res = new long[2][aNew.size()];
for (int i = 0; i < aNew.size(); i++) {
res[0][i] = aNew.get(i);
res[1][i] = mNew.get(i);
}
-
return res;
}
+ /**
+ * Solves x ≡ a[i] (mod m[i]) assuming all moduli are pairwise coprime.
+ *
+ * @return {x, M} where M is the product of all moduli.
+ */
public static long[] crt(long[] a, long[] m) {
-
long M = 1;
- for (int i = 0; i < m.length; i++) M *= m[i];
-
- long[] inv = new long[a.length];
- for (int i = 0; i < inv.length; i++) inv[i] = egcd(M / m[i], m[i])[1];
+ for (long mi : m)
+ M *= mi;
long x = 0;
for (int i = 0; i < m.length; i++) {
- x += (M / m[i]) * a[i] * inv[i]; // Overflow could occur here
+ long Mi = M / m[i];
+ long inv = egcd(Mi, m[i])[1];
+ x += Mi * a[i] * inv;
x = ((x % M) + M) % M;
}
return new long[] {x, M};
}
- private static ArrayList primeFactorization(long n) {
- ArrayList factors = new ArrayList();
- if (n <= 0) throw new IllegalArgumentException();
- else if (n == 1) return factors;
- PriorityQueue divisorQueue = new PriorityQueue();
- divisorQueue.add(n);
- while (!divisorQueue.isEmpty()) {
- long divisor = divisorQueue.remove();
- if (isPrime(divisor)) {
- factors.add(divisor);
- continue;
- }
- long next_divisor = pollardRho(divisor);
- if (next_divisor == divisor) {
- divisorQueue.add(divisor);
- } else {
- divisorQueue.add(next_divisor);
- divisorQueue.add(divisor / next_divisor);
- }
- }
+ /** Extended Euclidean algorithm. Returns {gcd(a,b), x, y} where ax + by = gcd(a,b). */
+ static long[] egcd(long a, long b) {
+ if (b == 0)
+ return new long[] {a, 1, 0};
+ long[] ret = egcd(b, a % b);
+ long tmp = ret[1] - ret[2] * (a / b);
+ ret[1] = ret[2];
+ ret[2] = tmp;
+ return ret;
+ }
+
+ private static List primeFactorization(long n) {
+ if (n <= 0)
+ throw new IllegalArgumentException();
+ List factors = new ArrayList<>();
+ factor(n, factors);
return factors;
}
+ private static void factor(long n, List factors) {
+ if (n == 1)
+ return;
+ if (isPrime(n)) {
+ factors.add(n);
+ return;
+ }
+ long d = pollardRho(n);
+ factor(d, factors);
+ factor(n / d, factors);
+ }
+
private static long pollardRho(long n) {
- if (n % 2 == 0) return 2;
- // Get a number in the range [2, 10^6]
+ if (n % 2 == 0)
+ return 2;
long x = 2 + (long) (999999 * Math.random());
long c = 2 + (long) (999999 * Math.random());
long y = x;
long d = 1;
while (d == 1) {
- x = (x * x + c) % n;
- y = (y * y + c) % n;
- y = (y * y + c) % n;
- d = gcf(Math.abs(x - y), n);
- if (d == n) break;
+ x = mulMod(x, x, n) + c;
+ if (x >= n)
+ x -= n;
+ y = mulMod(y, y, n) + c;
+ if (y >= n)
+ y -= n;
+ y = mulMod(y, y, n) + c;
+ if (y >= n)
+ y -= n;
+ d = gcd(Math.abs(x - y), n);
+ if (d == n)
+ break;
}
return d;
}
- // Extended euclidean algorithm
- private static long[] egcd(long a, long b) {
- if (b == 0) return new long[] {a, 1, 0};
- else {
- long[] ret = egcd(b, a % b);
- long tmp = ret[1] - ret[2] * (a / b);
- ret[1] = ret[2];
- ret[2] = tmp;
- return ret;
+ /**
+ * Deterministic Miller-Rabin primality test, correct for all long values. Uses 12 witnesses
+ * guaranteed correct for n < 3.317 × 10^24.
+ */
+ private static boolean isPrime(long n) {
+ if (n < 2)
+ return false;
+ if (n < 4)
+ return true;
+ if (n % 2 == 0 || n % 3 == 0)
+ return false;
+
+ long d = n - 1;
+ int r = Long.numberOfTrailingZeros(d);
+ d >>= r;
+
+ for (long a : new long[] {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) {
+ if (a >= n)
+ continue;
+ long x = powMod(a, d, n);
+ if (x == 1 || x == n - 1)
+ continue;
+ boolean composite = true;
+ for (int i = 0; i < r - 1; i++) {
+ x = mulMod(x, x, n);
+ if (x == n - 1) {
+ composite = false;
+ break;
+ }
+ }
+ if (composite)
+ return false;
}
+ return true;
}
- private static long gcf(long a, long b) {
- return b == 0 ? a : gcf(b, a % b);
+ private static long powMod(long base, long exp, long mod) {
+ long result = 1;
+ base %= mod;
+ while (exp > 0) {
+ if ((exp & 1) == 1)
+ result = mulMod(result, base, mod);
+ exp >>= 1;
+ base = mulMod(base, base, mod);
+ }
+ return result;
}
- private static boolean isPrime(long n) {
- if (n < 2) return false;
- if (n == 2 || n == 3) return true;
- if (n % 2 == 0 || n % 3 == 0) return false;
-
- int limit = (int) Math.sqrt(n);
-
- for (int i = 5; i <= limit; i += 6) if (n % i == 0 || n % (i + 2) == 0) return false;
+ private static long mulMod(long a, long b, long mod) {
+ return java.math.BigInteger.valueOf(a)
+ .multiply(java.math.BigInteger.valueOf(b))
+ .mod(java.math.BigInteger.valueOf(mod))
+ .longValue();
+ }
- return true;
+ private static long gcd(long a, long b) {
+ return b == 0 ? a : gcd(b, a % b);
}
}
diff --git a/src/test/java/com/williamfiset/algorithms/math/BUILD b/src/test/java/com/williamfiset/algorithms/math/BUILD
index 138ee1134..8a33dd8de 100644
--- a/src/test/java/com/williamfiset/algorithms/math/BUILD
+++ b/src/test/java/com/williamfiset/algorithms/math/BUILD
@@ -47,3 +47,14 @@ java_test(
runtime_deps = JUNIT5_RUNTIME_DEPS,
deps = TEST_DEPS,
)
+
+# bazel test //src/test/java/com/williamfiset/algorithms/math:ChineseRemainderTheoremTest
+java_test(
+ name = "ChineseRemainderTheoremTest",
+ srcs = ["ChineseRemainderTheoremTest.java"],
+ main_class = "org.junit.platform.console.ConsoleLauncher",
+ use_testrunner = False,
+ args = ["--select-class=com.williamfiset.algorithms.math.ChineseRemainderTheoremTest"],
+ runtime_deps = JUNIT5_RUNTIME_DEPS,
+ deps = TEST_DEPS,
+)
diff --git a/src/test/java/com/williamfiset/algorithms/math/ChineseRemainderTheoremTest.java b/src/test/java/com/williamfiset/algorithms/math/ChineseRemainderTheoremTest.java
new file mode 100644
index 000000000..ed312dcb3
--- /dev/null
+++ b/src/test/java/com/williamfiset/algorithms/math/ChineseRemainderTheoremTest.java
@@ -0,0 +1,167 @@
+package com.williamfiset.algorithms.math;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.jupiter.api.Assertions.assertNull;
+
+import org.junit.jupiter.api.*;
+
+public class ChineseRemainderTheoremTest {
+
+ // --- eliminateCoefficient tests ---
+
+ @Test
+ public void eliminateCoefficient_simple() {
+ // 3x ≡ 6 (mod 9) → x ≡ 2 (mod 3)
+ long[] result = ChineseRemainderTheorem.eliminateCoefficient(3, 6, 9);
+ assertThat(result).isNotNull();
+ assertThat(result[0]).isEqualTo(2);
+ assertThat(result[1]).isEqualTo(3);
+ }
+
+ @Test
+ public void eliminateCoefficient_coefficientOne() {
+ // 1x ≡ 5 (mod 7) → x ≡ 5 (mod 7)
+ long[] result = ChineseRemainderTheorem.eliminateCoefficient(1, 5, 7);
+ assertThat(result).isNotNull();
+ assertThat(result[0]).isEqualTo(5);
+ assertThat(result[1]).isEqualTo(7);
+ }
+
+ @Test
+ public void eliminateCoefficient_unsolvable() {
+ // 2x ≡ 3 (mod 4) — no solution since gcd(2,4)=2 does not divide 3
+ assertNull(ChineseRemainderTheorem.eliminateCoefficient(2, 3, 4));
+ }
+
+ @Test
+ public void eliminateCoefficient_coprime() {
+ // 3x ≡ 1 (mod 7) → x ≡ 5 (mod 7) since 3*5=15≡1 (mod 7)
+ long[] result = ChineseRemainderTheorem.eliminateCoefficient(3, 1, 7);
+ assertThat(result).isNotNull();
+ assertThat(result[0]).isEqualTo(5);
+ assertThat(result[1]).isEqualTo(7);
+ }
+
+ // --- crt tests ---
+
+ @Test
+ public void crt_classicExample() {
+ // x ≡ 2 (mod 3), x ≡ 3 (mod 5), x ≡ 2 (mod 7) → x ≡ 23 (mod 105)
+ long[] result = ChineseRemainderTheorem.crt(new long[] {2, 3, 2}, new long[] {3, 5, 7});
+ assertThat(result[0]).isEqualTo(23);
+ assertThat(result[1]).isEqualTo(105);
+ }
+
+ @Test
+ public void crt_twoEquations() {
+ // x ≡ 1 (mod 2), x ≡ 2 (mod 3) → x ≡ 5 (mod 6)
+ long[] result = ChineseRemainderTheorem.crt(new long[] {1, 2}, new long[] {2, 3});
+ assertThat(result[0]).isEqualTo(5);
+ assertThat(result[1]).isEqualTo(6);
+ }
+
+ @Test
+ public void crt_singleEquation() {
+ long[] result = ChineseRemainderTheorem.crt(new long[] {3}, new long[] {7});
+ assertThat(result[0]).isEqualTo(3);
+ assertThat(result[1]).isEqualTo(7);
+ }
+
+ @Test
+ public void crt_resultSatisfiesAllCongruences() {
+ long[] a = {1, 2, 3};
+ long[] m = {5, 7, 11};
+ long[] result = ChineseRemainderTheorem.crt(a, m);
+ for (int i = 0; i < a.length; i++)
+ assertThat(result[0] % m[i]).isEqualTo(a[i]);
+ }
+
+ @Test
+ public void crt_zeroRemainders() {
+ // x ≡ 0 (mod 3), x ≡ 0 (mod 5) → x ≡ 0 (mod 15)
+ long[] result = ChineseRemainderTheorem.crt(new long[] {0, 0}, new long[] {3, 5});
+ assertThat(result[0]).isEqualTo(0);
+ assertThat(result[1]).isEqualTo(15);
+ }
+
+ // --- reduce tests ---
+
+ @Test
+ public void reduce_alreadyCoprime() {
+ long[][] result = ChineseRemainderTheorem.reduce(new long[] {1, 2}, new long[] {3, 5});
+ assertThat(result).isNotNull();
+ // Should preserve the equations since 3 and 5 are already coprime
+ assertThat(result[0]).asList().containsExactly(1L, 2L);
+ assertThat(result[1]).asList().containsExactly(3L, 5L);
+ }
+
+ @Test
+ public void reduce_sharedPrimeFactor() {
+ // x ≡ 1 (mod 6), x ≡ 3 (mod 10) — share factor 2
+ // 6 = 2·3, 10 = 2·5 → split to mod {2,3,2,5}
+ long[][] result = ChineseRemainderTheorem.reduce(new long[] {1, 3}, new long[] {6, 10});
+ assertThat(result).isNotNull();
+ // The reduced system should be solvable via CRT
+ long[] crtResult = ChineseRemainderTheorem.crt(result[0], result[1]);
+ // Verify solution satisfies original congruences
+ assertThat(crtResult[0] % 6).isEqualTo(1);
+ assertThat(crtResult[0] % 10).isEqualTo(3);
+ }
+
+ @Test
+ public void reduce_inconsistent() {
+ // x ≡ 1 (mod 4), x ≡ 2 (mod 8) — inconsistent since 2 mod 4 = 2 ≠ 1
+ assertNull(ChineseRemainderTheorem.reduce(new long[] {1, 2}, new long[] {4, 8}));
+ }
+
+ @Test
+ public void reduce_redundantEquation() {
+ // x ≡ 1 (mod 2), x ≡ 1 (mod 4) — second subsumes first
+ long[][] result = ChineseRemainderTheorem.reduce(new long[] {1, 1}, new long[] {2, 4});
+ assertThat(result).isNotNull();
+ long[] crtResult = ChineseRemainderTheorem.crt(result[0], result[1]);
+ assertThat(crtResult[0] % 4).isEqualTo(1);
+ }
+
+ // --- egcd tests ---
+
+ @Test
+ public void egcd_basicProperties() {
+ long[] result = ChineseRemainderTheorem.egcd(35, 15);
+ assertThat(result[0]).isEqualTo(5); // gcd(35,15) = 5
+ // Verify Bezout's identity: 35*x + 15*y = 5
+ assertThat(35 * result[1] + 15 * result[2]).isEqualTo(5);
+ }
+
+ @Test
+ public void egcd_coprime() {
+ long[] result = ChineseRemainderTheorem.egcd(7, 11);
+ assertThat(result[0]).isEqualTo(1);
+ assertThat(7 * result[1] + 11 * result[2]).isEqualTo(1);
+ }
+
+ // --- Integration: reduce + crt ---
+
+ @Test
+ public void reduceAndCrt_fullPipeline() {
+ // Solve x ≡ 2 (mod 12), x ≡ 8 (mod 10)
+ // 12 = 4·3, 10 = 2·5 — share factor 2, consistent since 2 ≡ 0 (mod 2) and 8 ≡ 0 (mod 2)
+ long[][] reduced = ChineseRemainderTheorem.reduce(new long[] {2, 8}, new long[] {12, 10});
+ assertThat(reduced).isNotNull();
+ long[] result = ChineseRemainderTheorem.crt(reduced[0], reduced[1]);
+ assertThat(result[0] % 12).isEqualTo(2);
+ assertThat(result[0] % 10).isEqualTo(8);
+ }
+
+ @Test
+ public void reduceAndCrt_threeCoprime() {
+ // Already coprime — reduce should pass through, CRT solves directly
+ long[] a = {2, 3, 2};
+ long[] m = {3, 5, 7};
+ long[][] reduced = ChineseRemainderTheorem.reduce(a, m);
+ assertThat(reduced).isNotNull();
+ long[] result = ChineseRemainderTheorem.crt(reduced[0], reduced[1]);
+ assertThat(result[0]).isEqualTo(23);
+ assertThat(result[1]).isEqualTo(105);
+ }
+}