diff --git a/README.md b/README.md index fa11a9c3f..0a054c079 100644 --- a/README.md +++ b/README.md @@ -239,7 +239,7 @@ $ java -cp classes com.williamfiset.algorithms.search.BinarySearch # Mathematics -- [[UNTESTED] Chinese remainder theorem](src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java) +- [Chinese remainder theorem](src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java) - [Prime number sieve (sieve of Eratosthenes)](src/main/java/com/williamfiset/algorithms/math/SieveOfEratosthenes.java) **- O(nlog(log(n)))** - [Prime number sieve (sieve of Eratosthenes, compressed)](src/main/java/com/williamfiset/algorithms/math/CompressedPrimeSieve.java) **- O(nlog(log(n)))** - [Totient function (phi function, relatively prime number count)](src/main/java/com/williamfiset/algorithms/math/EulerTotientFunction.java) **- O(√n)** diff --git a/src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java b/src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java index a00d29538..a8e237386 100644 --- a/src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java +++ b/src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java @@ -1,20 +1,29 @@ /** - * Use the chinese remainder theorem to solve a set of congruence equations. + * Solve a set of congruence equations using the Chinese Remainder Theorem (CRT). * - *

The first method (eliminateCoefficient) is used to reduce an equation of the form cx≡a(mod - * m)cx≡a(mod m) to the form x≡a_new(mod m_new)x≡anew(mod m_new), which gets rids of the - * coefficient. A value of null is returned if the coefficient cannot be eliminated. + * Given a system of simultaneous congruences: * - *

The second method (reduce) is used to reduce a set of equations so that the moduli become - * pairwise co-prime (which means that we can apply the Chinese Remainder Theorem). The input and - * output are of the form x≡a_0(mod m_0),...,x≡a_n−1(mod m_n−1)x≡a_0(mod m_0),...,x≡a_n−1(mod - * m_n−1). Note that the number of equations may change during this process. A value of null is - * returned if the set of equations cannot be reduced to co-prime moduli. + * x ≡ a_0 (mod m_0) + * x ≡ a_1 (mod m_1) + * ... + * x ≡ a_{n-1} (mod m_{n-1}) * - *

The third method (crt) is the actual Chinese Remainder Theorem. It assumes that all pairs of - * moduli are co-prime to one another. This solves a set of equations of the form x≡a_0(mod - * m_0),...,x≡v_n−1(mod m_n−1)x≡a_0(mod m_0),...,x≡v_n−1(mod m_n−1). It's output is of the form - * x≡a_new(mod m_new)x≡a_new(mod m_new). + * where all moduli m_i are pairwise coprime (gcd(m_i, m_j) = 1 for i ≠ j), the CRT guarantees a + * unique solution x modulo M = m_0 * m_1 * ... * m_{n-1}. + * + * The solution is constructed as x = sum of a_i * M_i * y_i (mod M), where M_i = M / m_i and y_i + * is the modular inverse of M_i modulo m_i (found via the extended Euclidean algorithm). Each term + * contributes a_i for the i-th congruence and vanishes (mod m_j) for all j ≠ i, so the sum + * satisfies every equation simultaneously. + * + * When moduli are not pairwise coprime, the system must first be reduced. Each modulus is split + * into prime-power factors (e.g. 12 = 4 * 3), converting one equation into several with + * prime-power moduli. Redundant equations are removed and conflicting ones detected. After + * reduction, the moduli are pairwise coprime and the standard CRT applies. + * + * The eliminateCoefficient method handles equations of the form cx ≡ a (mod m) by dividing through + * by gcd(c, m) — which is only possible when gcd(c, m) divides a — and then multiplying by the + * modular inverse of the reduced coefficient. * * @author Micah Stairs */ @@ -24,12 +33,16 @@ public class ChineseRemainderTheorem { - // eliminateCoefficient() takes cx≡a(mod m) and gives x≡a_new(mod m_new). + /** + * Reduces cx ≡ a (mod m) to x ≡ a' (mod m'). + * + * @return {a', m'} or null if unsolvable. + */ public static long[] eliminateCoefficient(long c, long a, long m) { - long d = egcd(c, m)[0]; - if (a % d != 0) return null; + if (a % d != 0) + return null; c /= d; a /= d; @@ -42,36 +55,35 @@ public static long[] eliminateCoefficient(long c, long a, long m) { return new long[] {a, m}; } - // reduce() takes a set of equations and reduces them to an equivalent - // set with pairwise co-prime moduli (or null if not solvable). + /** + * Reduces a system x ≡ a[i] (mod m[i]) to an equivalent system with pairwise coprime moduli. + * + * @return {a[], m[]} with coprime moduli, or null if the system is inconsistent. + */ public static long[][] reduce(long[] a, long[] m) { + List aNew = new ArrayList<>(); + List mNew = new ArrayList<>(); - List aNew = new ArrayList(); - List mNew = new ArrayList(); - - // Split up each equation into prime factors + // Split each modulus into prime-power factors for (int i = 0; i < a.length; i++) { List factors = primeFactorization(m[i]); Collections.sort(factors); - ListIterator iterator = factors.listIterator(); - while (iterator.hasNext()) { - long val = iterator.next(); - long total = val; - while (iterator.hasNext()) { - long nextVal = iterator.next(); - if (nextVal == val) { - total *= val; - } else { - iterator.previous(); - break; - } + + int j = 0; + while (j < factors.size()) { + long p = factors.get(j); + long pk = p; + j++; + while (j < factors.size() && factors.get(j) == p) { + pk *= p; + j++; } - aNew.add(a[i] % total); - mNew.add(total); + aNew.add(a[i] % pk); + mNew.add(pk); } } - // Throw away repeated information and look for conflicts + // Remove redundant equations and detect conflicts for (int i = 0; i < aNew.size(); i++) { for (int j = i + 1; j < aNew.size(); j++) { if (mNew.get(i) % mNew.get(j) == 0 || mNew.get(j) % mNew.get(i) == 0) { @@ -81,111 +93,161 @@ public static long[][] reduce(long[] a, long[] m) { mNew.remove(j); j--; continue; - } else return null; + } else + return null; } else { if ((aNew.get(j) % mNew.get(i)) == aNew.get(i)) { aNew.remove(i); mNew.remove(i); i--; break; - } else return null; + } else + return null; } } } } - // Put result into an array long[][] res = new long[2][aNew.size()]; for (int i = 0; i < aNew.size(); i++) { res[0][i] = aNew.get(i); res[1][i] = mNew.get(i); } - return res; } + /** + * Solves x ≡ a[i] (mod m[i]) assuming all moduli are pairwise coprime. + * + * @return {x, M} where M is the product of all moduli. + */ public static long[] crt(long[] a, long[] m) { - long M = 1; - for (int i = 0; i < m.length; i++) M *= m[i]; - - long[] inv = new long[a.length]; - for (int i = 0; i < inv.length; i++) inv[i] = egcd(M / m[i], m[i])[1]; + for (long mi : m) + M *= mi; long x = 0; for (int i = 0; i < m.length; i++) { - x += (M / m[i]) * a[i] * inv[i]; // Overflow could occur here + long Mi = M / m[i]; + long inv = egcd(Mi, m[i])[1]; + x += Mi * a[i] * inv; x = ((x % M) + M) % M; } return new long[] {x, M}; } - private static ArrayList primeFactorization(long n) { - ArrayList factors = new ArrayList(); - if (n <= 0) throw new IllegalArgumentException(); - else if (n == 1) return factors; - PriorityQueue divisorQueue = new PriorityQueue(); - divisorQueue.add(n); - while (!divisorQueue.isEmpty()) { - long divisor = divisorQueue.remove(); - if (isPrime(divisor)) { - factors.add(divisor); - continue; - } - long next_divisor = pollardRho(divisor); - if (next_divisor == divisor) { - divisorQueue.add(divisor); - } else { - divisorQueue.add(next_divisor); - divisorQueue.add(divisor / next_divisor); - } - } + /** Extended Euclidean algorithm. Returns {gcd(a,b), x, y} where ax + by = gcd(a,b). */ + static long[] egcd(long a, long b) { + if (b == 0) + return new long[] {a, 1, 0}; + long[] ret = egcd(b, a % b); + long tmp = ret[1] - ret[2] * (a / b); + ret[1] = ret[2]; + ret[2] = tmp; + return ret; + } + + private static List primeFactorization(long n) { + if (n <= 0) + throw new IllegalArgumentException(); + List factors = new ArrayList<>(); + factor(n, factors); return factors; } + private static void factor(long n, List factors) { + if (n == 1) + return; + if (isPrime(n)) { + factors.add(n); + return; + } + long d = pollardRho(n); + factor(d, factors); + factor(n / d, factors); + } + private static long pollardRho(long n) { - if (n % 2 == 0) return 2; - // Get a number in the range [2, 10^6] + if (n % 2 == 0) + return 2; long x = 2 + (long) (999999 * Math.random()); long c = 2 + (long) (999999 * Math.random()); long y = x; long d = 1; while (d == 1) { - x = (x * x + c) % n; - y = (y * y + c) % n; - y = (y * y + c) % n; - d = gcf(Math.abs(x - y), n); - if (d == n) break; + x = mulMod(x, x, n) + c; + if (x >= n) + x -= n; + y = mulMod(y, y, n) + c; + if (y >= n) + y -= n; + y = mulMod(y, y, n) + c; + if (y >= n) + y -= n; + d = gcd(Math.abs(x - y), n); + if (d == n) + break; } return d; } - // Extended euclidean algorithm - private static long[] egcd(long a, long b) { - if (b == 0) return new long[] {a, 1, 0}; - else { - long[] ret = egcd(b, a % b); - long tmp = ret[1] - ret[2] * (a / b); - ret[1] = ret[2]; - ret[2] = tmp; - return ret; + /** + * Deterministic Miller-Rabin primality test, correct for all long values. Uses 12 witnesses + * guaranteed correct for n < 3.317 × 10^24. + */ + private static boolean isPrime(long n) { + if (n < 2) + return false; + if (n < 4) + return true; + if (n % 2 == 0 || n % 3 == 0) + return false; + + long d = n - 1; + int r = Long.numberOfTrailingZeros(d); + d >>= r; + + for (long a : new long[] {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { + if (a >= n) + continue; + long x = powMod(a, d, n); + if (x == 1 || x == n - 1) + continue; + boolean composite = true; + for (int i = 0; i < r - 1; i++) { + x = mulMod(x, x, n); + if (x == n - 1) { + composite = false; + break; + } + } + if (composite) + return false; } + return true; } - private static long gcf(long a, long b) { - return b == 0 ? a : gcf(b, a % b); + private static long powMod(long base, long exp, long mod) { + long result = 1; + base %= mod; + while (exp > 0) { + if ((exp & 1) == 1) + result = mulMod(result, base, mod); + exp >>= 1; + base = mulMod(base, base, mod); + } + return result; } - private static boolean isPrime(long n) { - if (n < 2) return false; - if (n == 2 || n == 3) return true; - if (n % 2 == 0 || n % 3 == 0) return false; - - int limit = (int) Math.sqrt(n); - - for (int i = 5; i <= limit; i += 6) if (n % i == 0 || n % (i + 2) == 0) return false; + private static long mulMod(long a, long b, long mod) { + return java.math.BigInteger.valueOf(a) + .multiply(java.math.BigInteger.valueOf(b)) + .mod(java.math.BigInteger.valueOf(mod)) + .longValue(); + } - return true; + private static long gcd(long a, long b) { + return b == 0 ? a : gcd(b, a % b); } } diff --git a/src/test/java/com/williamfiset/algorithms/math/BUILD b/src/test/java/com/williamfiset/algorithms/math/BUILD index 138ee1134..8a33dd8de 100644 --- a/src/test/java/com/williamfiset/algorithms/math/BUILD +++ b/src/test/java/com/williamfiset/algorithms/math/BUILD @@ -47,3 +47,14 @@ java_test( runtime_deps = JUNIT5_RUNTIME_DEPS, deps = TEST_DEPS, ) + +# bazel test //src/test/java/com/williamfiset/algorithms/math:ChineseRemainderTheoremTest +java_test( + name = "ChineseRemainderTheoremTest", + srcs = ["ChineseRemainderTheoremTest.java"], + main_class = "org.junit.platform.console.ConsoleLauncher", + use_testrunner = False, + args = ["--select-class=com.williamfiset.algorithms.math.ChineseRemainderTheoremTest"], + runtime_deps = JUNIT5_RUNTIME_DEPS, + deps = TEST_DEPS, +) diff --git a/src/test/java/com/williamfiset/algorithms/math/ChineseRemainderTheoremTest.java b/src/test/java/com/williamfiset/algorithms/math/ChineseRemainderTheoremTest.java new file mode 100644 index 000000000..ed312dcb3 --- /dev/null +++ b/src/test/java/com/williamfiset/algorithms/math/ChineseRemainderTheoremTest.java @@ -0,0 +1,167 @@ +package com.williamfiset.algorithms.math; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.jupiter.api.Assertions.assertNull; + +import org.junit.jupiter.api.*; + +public class ChineseRemainderTheoremTest { + + // --- eliminateCoefficient tests --- + + @Test + public void eliminateCoefficient_simple() { + // 3x ≡ 6 (mod 9) → x ≡ 2 (mod 3) + long[] result = ChineseRemainderTheorem.eliminateCoefficient(3, 6, 9); + assertThat(result).isNotNull(); + assertThat(result[0]).isEqualTo(2); + assertThat(result[1]).isEqualTo(3); + } + + @Test + public void eliminateCoefficient_coefficientOne() { + // 1x ≡ 5 (mod 7) → x ≡ 5 (mod 7) + long[] result = ChineseRemainderTheorem.eliminateCoefficient(1, 5, 7); + assertThat(result).isNotNull(); + assertThat(result[0]).isEqualTo(5); + assertThat(result[1]).isEqualTo(7); + } + + @Test + public void eliminateCoefficient_unsolvable() { + // 2x ≡ 3 (mod 4) — no solution since gcd(2,4)=2 does not divide 3 + assertNull(ChineseRemainderTheorem.eliminateCoefficient(2, 3, 4)); + } + + @Test + public void eliminateCoefficient_coprime() { + // 3x ≡ 1 (mod 7) → x ≡ 5 (mod 7) since 3*5=15≡1 (mod 7) + long[] result = ChineseRemainderTheorem.eliminateCoefficient(3, 1, 7); + assertThat(result).isNotNull(); + assertThat(result[0]).isEqualTo(5); + assertThat(result[1]).isEqualTo(7); + } + + // --- crt tests --- + + @Test + public void crt_classicExample() { + // x ≡ 2 (mod 3), x ≡ 3 (mod 5), x ≡ 2 (mod 7) → x ≡ 23 (mod 105) + long[] result = ChineseRemainderTheorem.crt(new long[] {2, 3, 2}, new long[] {3, 5, 7}); + assertThat(result[0]).isEqualTo(23); + assertThat(result[1]).isEqualTo(105); + } + + @Test + public void crt_twoEquations() { + // x ≡ 1 (mod 2), x ≡ 2 (mod 3) → x ≡ 5 (mod 6) + long[] result = ChineseRemainderTheorem.crt(new long[] {1, 2}, new long[] {2, 3}); + assertThat(result[0]).isEqualTo(5); + assertThat(result[1]).isEqualTo(6); + } + + @Test + public void crt_singleEquation() { + long[] result = ChineseRemainderTheorem.crt(new long[] {3}, new long[] {7}); + assertThat(result[0]).isEqualTo(3); + assertThat(result[1]).isEqualTo(7); + } + + @Test + public void crt_resultSatisfiesAllCongruences() { + long[] a = {1, 2, 3}; + long[] m = {5, 7, 11}; + long[] result = ChineseRemainderTheorem.crt(a, m); + for (int i = 0; i < a.length; i++) + assertThat(result[0] % m[i]).isEqualTo(a[i]); + } + + @Test + public void crt_zeroRemainders() { + // x ≡ 0 (mod 3), x ≡ 0 (mod 5) → x ≡ 0 (mod 15) + long[] result = ChineseRemainderTheorem.crt(new long[] {0, 0}, new long[] {3, 5}); + assertThat(result[0]).isEqualTo(0); + assertThat(result[1]).isEqualTo(15); + } + + // --- reduce tests --- + + @Test + public void reduce_alreadyCoprime() { + long[][] result = ChineseRemainderTheorem.reduce(new long[] {1, 2}, new long[] {3, 5}); + assertThat(result).isNotNull(); + // Should preserve the equations since 3 and 5 are already coprime + assertThat(result[0]).asList().containsExactly(1L, 2L); + assertThat(result[1]).asList().containsExactly(3L, 5L); + } + + @Test + public void reduce_sharedPrimeFactor() { + // x ≡ 1 (mod 6), x ≡ 3 (mod 10) — share factor 2 + // 6 = 2·3, 10 = 2·5 → split to mod {2,3,2,5} + long[][] result = ChineseRemainderTheorem.reduce(new long[] {1, 3}, new long[] {6, 10}); + assertThat(result).isNotNull(); + // The reduced system should be solvable via CRT + long[] crtResult = ChineseRemainderTheorem.crt(result[0], result[1]); + // Verify solution satisfies original congruences + assertThat(crtResult[0] % 6).isEqualTo(1); + assertThat(crtResult[0] % 10).isEqualTo(3); + } + + @Test + public void reduce_inconsistent() { + // x ≡ 1 (mod 4), x ≡ 2 (mod 8) — inconsistent since 2 mod 4 = 2 ≠ 1 + assertNull(ChineseRemainderTheorem.reduce(new long[] {1, 2}, new long[] {4, 8})); + } + + @Test + public void reduce_redundantEquation() { + // x ≡ 1 (mod 2), x ≡ 1 (mod 4) — second subsumes first + long[][] result = ChineseRemainderTheorem.reduce(new long[] {1, 1}, new long[] {2, 4}); + assertThat(result).isNotNull(); + long[] crtResult = ChineseRemainderTheorem.crt(result[0], result[1]); + assertThat(crtResult[0] % 4).isEqualTo(1); + } + + // --- egcd tests --- + + @Test + public void egcd_basicProperties() { + long[] result = ChineseRemainderTheorem.egcd(35, 15); + assertThat(result[0]).isEqualTo(5); // gcd(35,15) = 5 + // Verify Bezout's identity: 35*x + 15*y = 5 + assertThat(35 * result[1] + 15 * result[2]).isEqualTo(5); + } + + @Test + public void egcd_coprime() { + long[] result = ChineseRemainderTheorem.egcd(7, 11); + assertThat(result[0]).isEqualTo(1); + assertThat(7 * result[1] + 11 * result[2]).isEqualTo(1); + } + + // --- Integration: reduce + crt --- + + @Test + public void reduceAndCrt_fullPipeline() { + // Solve x ≡ 2 (mod 12), x ≡ 8 (mod 10) + // 12 = 4·3, 10 = 2·5 — share factor 2, consistent since 2 ≡ 0 (mod 2) and 8 ≡ 0 (mod 2) + long[][] reduced = ChineseRemainderTheorem.reduce(new long[] {2, 8}, new long[] {12, 10}); + assertThat(reduced).isNotNull(); + long[] result = ChineseRemainderTheorem.crt(reduced[0], reduced[1]); + assertThat(result[0] % 12).isEqualTo(2); + assertThat(result[0] % 10).isEqualTo(8); + } + + @Test + public void reduceAndCrt_threeCoprime() { + // Already coprime — reduce should pass through, CRT solves directly + long[] a = {2, 3, 2}; + long[] m = {3, 5, 7}; + long[][] reduced = ChineseRemainderTheorem.reduce(a, m); + assertThat(reduced).isNotNull(); + long[] result = ChineseRemainderTheorem.crt(reduced[0], reduced[1]); + assertThat(result[0]).isEqualTo(23); + assertThat(result[1]).isEqualTo(105); + } +}