Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ $ java -cp classes com.williamfiset.algorithms.search.BinarySearch

# Mathematics

- [[UNTESTED] Chinese remainder theorem](src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java)
- [Chinese remainder theorem](src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java)
- [Prime number sieve (sieve of Eratosthenes)](src/main/java/com/williamfiset/algorithms/math/SieveOfEratosthenes.java) **- O(nlog(log(n)))**
- [Prime number sieve (sieve of Eratosthenes, compressed)](src/main/java/com/williamfiset/algorithms/math/CompressedPrimeSieve.java) **- O(nlog(log(n)))**
- [Totient function (phi function, relatively prime number count)](src/main/java/com/williamfiset/algorithms/math/EulerTotientFunction.java) **- O(√n)**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,29 @@
/**
* Use the chinese remainder theorem to solve a set of congruence equations.
* Solve a set of congruence equations using the Chinese Remainder Theorem (CRT).
*
* <p>The first method (eliminateCoefficient) is used to reduce an equation of the form cx≡a(mod
* m)cx≡a(mod m) to the form x≡a_new(mod m_new)x≡anew(mod m_new), which gets rids of the
* coefficient. A value of null is returned if the coefficient cannot be eliminated.
* Given a system of simultaneous congruences:
*
* <p>The second method (reduce) is used to reduce a set of equations so that the moduli become
* pairwise co-prime (which means that we can apply the Chinese Remainder Theorem). The input and
* output are of the form x≡a_0(mod m_0),...,x≡a_n−1(mod m_n−1)x≡a_0(mod m_0),...,x≡a_n−1(mod
* m_n−1). Note that the number of equations may change during this process. A value of null is
* returned if the set of equations cannot be reduced to co-prime moduli.
* x ≡ a_0 (mod m_0)
* x ≡ a_1 (mod m_1)
* ...
* x ≡ a_{n-1} (mod m_{n-1})
*
* <p>The third method (crt) is the actual Chinese Remainder Theorem. It assumes that all pairs of
* moduli are co-prime to one another. This solves a set of equations of the form x≡a_0(mod
* m_0),...,x≡v_n−1(mod m_n−1)x≡a_0(mod m_0),...,x≡v_n−1(mod m_n−1). It's output is of the form
* x≡a_new(mod m_new)x≡a_new(mod m_new).
* where all moduli m_i are pairwise coprime (gcd(m_i, m_j) = 1 for i ≠ j), the CRT guarantees a
* unique solution x modulo M = m_0 * m_1 * ... * m_{n-1}.
*
* The solution is constructed as x = sum of a_i * M_i * y_i (mod M), where M_i = M / m_i and y_i
* is the modular inverse of M_i modulo m_i (found via the extended Euclidean algorithm). Each term
* contributes a_i for the i-th congruence and vanishes (mod m_j) for all j ≠ i, so the sum
* satisfies every equation simultaneously.
*
* When moduli are not pairwise coprime, the system must first be reduced. Each modulus is split
* into prime-power factors (e.g. 12 = 4 * 3), converting one equation into several with
* prime-power moduli. Redundant equations are removed and conflicting ones detected. After
* reduction, the moduli are pairwise coprime and the standard CRT applies.
*
* The eliminateCoefficient method handles equations of the form cx ≡ a (mod m) by dividing through
* by gcd(c, m) — which is only possible when gcd(c, m) divides a — and then multiplying by the
* modular inverse of the reduced coefficient.
*
* @author Micah Stairs
*/
Expand All @@ -24,12 +33,16 @@

public class ChineseRemainderTheorem {

// eliminateCoefficient() takes cx≡a(mod m) and gives x≡a_new(mod m_new).
/**
* Reduces cx ≡ a (mod m) to x ≡ a' (mod m').
*
* @return {a', m'} or null if unsolvable.
*/
public static long[] eliminateCoefficient(long c, long a, long m) {

long d = egcd(c, m)[0];

if (a % d != 0) return null;
if (a % d != 0)
return null;

c /= d;
a /= d;
Expand All @@ -42,36 +55,35 @@ public static long[] eliminateCoefficient(long c, long a, long m) {
return new long[] {a, m};
}

// reduce() takes a set of equations and reduces them to an equivalent
// set with pairwise co-prime moduli (or null if not solvable).
/**
* Reduces a system x ≡ a[i] (mod m[i]) to an equivalent system with pairwise coprime moduli.
*
* @return {a[], m[]} with coprime moduli, or null if the system is inconsistent.
*/
public static long[][] reduce(long[] a, long[] m) {
List<Long> aNew = new ArrayList<>();
List<Long> mNew = new ArrayList<>();

List<Long> aNew = new ArrayList<Long>();
List<Long> mNew = new ArrayList<Long>();

// Split up each equation into prime factors
// Split each modulus into prime-power factors
for (int i = 0; i < a.length; i++) {
List<Long> factors = primeFactorization(m[i]);
Collections.sort(factors);
ListIterator<Long> iterator = factors.listIterator();
while (iterator.hasNext()) {
long val = iterator.next();
long total = val;
while (iterator.hasNext()) {
long nextVal = iterator.next();
if (nextVal == val) {
total *= val;
} else {
iterator.previous();
break;
}

int j = 0;
while (j < factors.size()) {
long p = factors.get(j);
long pk = p;
j++;
while (j < factors.size() && factors.get(j) == p) {
pk *= p;
j++;
}
aNew.add(a[i] % total);
mNew.add(total);
aNew.add(a[i] % pk);
mNew.add(pk);
}
}

// Throw away repeated information and look for conflicts
// Remove redundant equations and detect conflicts
for (int i = 0; i < aNew.size(); i++) {
for (int j = i + 1; j < aNew.size(); j++) {
if (mNew.get(i) % mNew.get(j) == 0 || mNew.get(j) % mNew.get(i) == 0) {
Expand All @@ -81,111 +93,161 @@ public static long[][] reduce(long[] a, long[] m) {
mNew.remove(j);
j--;
continue;
} else return null;
} else
return null;
} else {
if ((aNew.get(j) % mNew.get(i)) == aNew.get(i)) {
aNew.remove(i);
mNew.remove(i);
i--;
break;
} else return null;
} else
return null;
}
}
}
}

// Put result into an array
long[][] res = new long[2][aNew.size()];
for (int i = 0; i < aNew.size(); i++) {
res[0][i] = aNew.get(i);
res[1][i] = mNew.get(i);
}

return res;
}

/**
* Solves x ≡ a[i] (mod m[i]) assuming all moduli are pairwise coprime.
*
* @return {x, M} where M is the product of all moduli.
*/
public static long[] crt(long[] a, long[] m) {

long M = 1;
for (int i = 0; i < m.length; i++) M *= m[i];

long[] inv = new long[a.length];
for (int i = 0; i < inv.length; i++) inv[i] = egcd(M / m[i], m[i])[1];
for (long mi : m)
M *= mi;

long x = 0;
for (int i = 0; i < m.length; i++) {
x += (M / m[i]) * a[i] * inv[i]; // Overflow could occur here
long Mi = M / m[i];
long inv = egcd(Mi, m[i])[1];
x += Mi * a[i] * inv;
x = ((x % M) + M) % M;
}

return new long[] {x, M};
}

private static ArrayList<Long> primeFactorization(long n) {
ArrayList<Long> factors = new ArrayList<Long>();
if (n <= 0) throw new IllegalArgumentException();
else if (n == 1) return factors;
PriorityQueue<Long> divisorQueue = new PriorityQueue<Long>();
divisorQueue.add(n);
while (!divisorQueue.isEmpty()) {
long divisor = divisorQueue.remove();
if (isPrime(divisor)) {
factors.add(divisor);
continue;
}
long next_divisor = pollardRho(divisor);
if (next_divisor == divisor) {
divisorQueue.add(divisor);
} else {
divisorQueue.add(next_divisor);
divisorQueue.add(divisor / next_divisor);
}
}
/** Extended Euclidean algorithm. Returns {gcd(a,b), x, y} where ax + by = gcd(a,b). */
static long[] egcd(long a, long b) {
if (b == 0)
return new long[] {a, 1, 0};
long[] ret = egcd(b, a % b);
long tmp = ret[1] - ret[2] * (a / b);
ret[1] = ret[2];
ret[2] = tmp;
return ret;
}

private static List<Long> primeFactorization(long n) {
if (n <= 0)
throw new IllegalArgumentException();
List<Long> factors = new ArrayList<>();
factor(n, factors);
return factors;
}

private static void factor(long n, List<Long> factors) {
if (n == 1)
return;
if (isPrime(n)) {
factors.add(n);
return;
}
long d = pollardRho(n);
factor(d, factors);
factor(n / d, factors);
}

private static long pollardRho(long n) {
if (n % 2 == 0) return 2;
// Get a number in the range [2, 10^6]
if (n % 2 == 0)
return 2;
long x = 2 + (long) (999999 * Math.random());
long c = 2 + (long) (999999 * Math.random());
long y = x;
long d = 1;
while (d == 1) {
x = (x * x + c) % n;
y = (y * y + c) % n;
y = (y * y + c) % n;
d = gcf(Math.abs(x - y), n);
if (d == n) break;
x = mulMod(x, x, n) + c;
if (x >= n)
x -= n;
y = mulMod(y, y, n) + c;
if (y >= n)
y -= n;
y = mulMod(y, y, n) + c;
if (y >= n)
y -= n;
d = gcd(Math.abs(x - y), n);
if (d == n)
break;
}
return d;
}

// Extended euclidean algorithm
private static long[] egcd(long a, long b) {
if (b == 0) return new long[] {a, 1, 0};
else {
long[] ret = egcd(b, a % b);
long tmp = ret[1] - ret[2] * (a / b);
ret[1] = ret[2];
ret[2] = tmp;
return ret;
/**
* Deterministic Miller-Rabin primality test, correct for all long values. Uses 12 witnesses
* guaranteed correct for n < 3.317 × 10^24.
*/
private static boolean isPrime(long n) {
if (n < 2)
return false;
if (n < 4)
return true;
if (n % 2 == 0 || n % 3 == 0)
return false;

long d = n - 1;
int r = Long.numberOfTrailingZeros(d);
d >>= r;

for (long a : new long[] {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) {
if (a >= n)
continue;
long x = powMod(a, d, n);
if (x == 1 || x == n - 1)
continue;
boolean composite = true;
for (int i = 0; i < r - 1; i++) {
x = mulMod(x, x, n);
if (x == n - 1) {
composite = false;
break;
}
}
if (composite)
return false;
}
return true;
}

private static long gcf(long a, long b) {
return b == 0 ? a : gcf(b, a % b);
private static long powMod(long base, long exp, long mod) {
long result = 1;
base %= mod;
while (exp > 0) {
if ((exp & 1) == 1)
result = mulMod(result, base, mod);
exp >>= 1;
base = mulMod(base, base, mod);
}
return result;
}

private static boolean isPrime(long n) {
if (n < 2) return false;
if (n == 2 || n == 3) return true;
if (n % 2 == 0 || n % 3 == 0) return false;

int limit = (int) Math.sqrt(n);

for (int i = 5; i <= limit; i += 6) if (n % i == 0 || n % (i + 2) == 0) return false;
private static long mulMod(long a, long b, long mod) {
return java.math.BigInteger.valueOf(a)
.multiply(java.math.BigInteger.valueOf(b))
.mod(java.math.BigInteger.valueOf(mod))
.longValue();
}

return true;
private static long gcd(long a, long b) {
return b == 0 ? a : gcd(b, a % b);
}
}
11 changes: 11 additions & 0 deletions src/test/java/com/williamfiset/algorithms/math/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,14 @@ java_test(
runtime_deps = JUNIT5_RUNTIME_DEPS,
deps = TEST_DEPS,
)

# bazel test //src/test/java/com/williamfiset/algorithms/math:ChineseRemainderTheoremTest
java_test(
name = "ChineseRemainderTheoremTest",
srcs = ["ChineseRemainderTheoremTest.java"],
main_class = "org.junit.platform.console.ConsoleLauncher",
use_testrunner = False,
args = ["--select-class=com.williamfiset.algorithms.math.ChineseRemainderTheoremTest"],
runtime_deps = JUNIT5_RUNTIME_DEPS,
deps = TEST_DEPS,
)
Loading
Loading