Skip to content

Commit 3eeb06a

Browse files
feat: add OptimalBinarySearchTree algorithm
1 parent 5e06b15 commit 3eeb06a

File tree

2 files changed

+181
-0
lines changed

2 files changed

+181
-0
lines changed
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
package com.thealgorithms.dynamicprogramming;
2+
3+
import java.util.Arrays;
4+
import java.util.Comparator;
5+
6+
/**
7+
* Computes the minimum search cost of an optimal binary search tree.
8+
*
9+
* <p>The algorithm sorts the keys, preserves the corresponding search frequencies, and uses
10+
* dynamic programming with Knuth's optimization to compute the minimum weighted search cost.
11+
*
12+
* <p>Reference:
13+
* https://en.wikipedia.org/wiki/Optimal_binary_search_tree
14+
*/
15+
public final class OptimalBinarySearchTree {
16+
private OptimalBinarySearchTree() {
17+
}
18+
19+
/**
20+
* Computes the minimum weighted search cost for the given keys and search frequencies.
21+
*
22+
* @param keys the BST keys
23+
* @param frequencies the search frequencies associated with the keys
24+
* @return the minimum search cost
25+
* @throws IllegalArgumentException if the input is invalid
26+
*/
27+
public static long findOptimalCost(int[] keys, int[] frequencies) {
28+
validateInput(keys, frequencies);
29+
if (keys.length == 0) {
30+
return 0L;
31+
}
32+
33+
int[][] sortedNodes = sortNodes(keys, frequencies);
34+
int nodeCount = sortedNodes.length;
35+
long[] prefixSums = buildPrefixSums(sortedNodes);
36+
long[][] optimalCost = new long[nodeCount][nodeCount];
37+
int[][] root = new int[nodeCount][nodeCount];
38+
39+
for (int index = 0; index < nodeCount; index++) {
40+
optimalCost[index][index] = sortedNodes[index][1];
41+
root[index][index] = index;
42+
}
43+
44+
for (int length = 2; length <= nodeCount; length++) {
45+
for (int start = 0; start <= nodeCount - length; start++) {
46+
int end = start + length - 1;
47+
long frequencySum = prefixSums[end + 1] - prefixSums[start];
48+
optimalCost[start][end] = Long.MAX_VALUE;
49+
50+
int leftBoundary = root[start][end - 1];
51+
int rightBoundary = root[start + 1][end];
52+
for (int currentRoot = leftBoundary; currentRoot <= rightBoundary; currentRoot++) {
53+
long leftCost = currentRoot > start ? optimalCost[start][currentRoot - 1] : 0L;
54+
long rightCost = currentRoot < end ? optimalCost[currentRoot + 1][end] : 0L;
55+
long currentCost = frequencySum + leftCost + rightCost;
56+
57+
if (currentCost < optimalCost[start][end]) {
58+
optimalCost[start][end] = currentCost;
59+
root[start][end] = currentRoot;
60+
}
61+
}
62+
}
63+
}
64+
65+
return optimalCost[0][nodeCount - 1];
66+
}
67+
68+
private static void validateInput(int[] keys, int[] frequencies) {
69+
if (keys == null || frequencies == null) {
70+
throw new IllegalArgumentException("Keys and frequencies cannot be null");
71+
}
72+
if (keys.length != frequencies.length) {
73+
throw new IllegalArgumentException("Keys and frequencies must have the same length");
74+
}
75+
76+
for (int frequency : frequencies) {
77+
if (frequency < 0) {
78+
throw new IllegalArgumentException("Frequencies cannot be negative");
79+
}
80+
}
81+
}
82+
83+
private static int[][] sortNodes(int[] keys, int[] frequencies) {
84+
int[][] sortedNodes = new int[keys.length][2];
85+
for (int index = 0; index < keys.length; index++) {
86+
sortedNodes[index][0] = keys[index];
87+
sortedNodes[index][1] = frequencies[index];
88+
}
89+
90+
Arrays.sort(sortedNodes, Comparator.comparingInt(node -> node[0]));
91+
92+
for (int index = 1; index < sortedNodes.length; index++) {
93+
if (sortedNodes[index - 1][0] == sortedNodes[index][0]) {
94+
throw new IllegalArgumentException("Keys must be distinct");
95+
}
96+
}
97+
98+
return sortedNodes;
99+
}
100+
101+
private static long[] buildPrefixSums(int[][] sortedNodes) {
102+
long[] prefixSums = new long[sortedNodes.length + 1];
103+
for (int index = 0; index < sortedNodes.length; index++) {
104+
prefixSums[index + 1] = prefixSums[index] + sortedNodes[index][1];
105+
}
106+
return prefixSums;
107+
}
108+
}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package com.thealgorithms.dynamicprogramming;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertThrows;
5+
6+
import java.util.Arrays;
7+
import java.util.stream.Stream;
8+
import org.junit.jupiter.params.ParameterizedTest;
9+
import org.junit.jupiter.params.provider.Arguments;
10+
import org.junit.jupiter.params.provider.MethodSource;
11+
12+
class OptimalBinarySearchTreeTest {
13+
14+
@ParameterizedTest
15+
@MethodSource("validTestCases")
16+
void testFindOptimalCost(int[] keys, int[] frequencies, long expectedCost) {
17+
assertEquals(expectedCost, OptimalBinarySearchTree.findOptimalCost(keys, frequencies));
18+
}
19+
20+
private static Stream<Arguments> validTestCases() {
21+
return Stream.of(Arguments.of(new int[] {}, new int[] {}, 0L), Arguments.of(new int[] {15}, new int[] {9}, 9L), Arguments.of(new int[] {10, 12}, new int[] {34, 50}, 118L), Arguments.of(new int[] {20, 10, 30}, new int[] {50, 34, 8}, 134L),
22+
Arguments.of(new int[] {12, 10, 20, 42, 25, 37}, new int[] {8, 34, 50, 3, 40, 30}, 324L), Arguments.of(new int[] {1, 2, 3}, new int[] {0, 0, 0}, 0L));
23+
}
24+
25+
@ParameterizedTest
26+
@MethodSource("crossCheckTestCases")
27+
void testFindOptimalCostAgainstBruteForce(int[] keys, int[] frequencies) {
28+
assertEquals(bruteForceOptimalCost(keys, frequencies), OptimalBinarySearchTree.findOptimalCost(keys, frequencies));
29+
}
30+
31+
private static Stream<Arguments> crossCheckTestCases() {
32+
return Stream.of(Arguments.of(new int[] {3, 1, 2}, new int[] {4, 2, 6}), Arguments.of(new int[] {5, 2, 8, 6}, new int[] {3, 7, 1, 4}), Arguments.of(new int[] {9, 4, 11, 2}, new int[] {1, 8, 2, 5}));
33+
}
34+
35+
@ParameterizedTest
36+
@MethodSource("invalidTestCases")
37+
void testFindOptimalCostInvalidInput(int[] keys, int[] frequencies) {
38+
assertThrows(IllegalArgumentException.class, () -> OptimalBinarySearchTree.findOptimalCost(keys, frequencies));
39+
}
40+
41+
private static Stream<Arguments> invalidTestCases() {
42+
return Stream.of(Arguments.of(null, new int[] {}), Arguments.of(new int[] {}, null), Arguments.of(new int[] {1, 2}, new int[] {3}), Arguments.of(new int[] {1, 1}, new int[] {2, 3}), Arguments.of(new int[] {1, 2}, new int[] {3, -1}));
43+
}
44+
45+
private static long bruteForceOptimalCost(int[] keys, int[] frequencies) {
46+
int[][] sortedNodes = new int[keys.length][2];
47+
for (int index = 0; index < keys.length; index++) {
48+
sortedNodes[index][0] = keys[index];
49+
sortedNodes[index][1] = frequencies[index];
50+
}
51+
Arrays.sort(sortedNodes, java.util.Comparator.comparingInt(node -> node[0]));
52+
53+
int[] sortedFrequencies = new int[sortedNodes.length];
54+
for (int index = 0; index < sortedNodes.length; index++) {
55+
sortedFrequencies[index] = sortedNodes[index][1];
56+
}
57+
58+
return bruteForceOptimalCost(sortedFrequencies, 0, sortedFrequencies.length - 1, 1);
59+
}
60+
61+
private static long bruteForceOptimalCost(int[] frequencies, int start, int end, int depth) {
62+
if (start > end) {
63+
return 0L;
64+
}
65+
66+
long minimumCost = Long.MAX_VALUE;
67+
for (int root = start; root <= end; root++) {
68+
long currentCost = (long) depth * frequencies[root] + bruteForceOptimalCost(frequencies, start, root - 1, depth + 1) + bruteForceOptimalCost(frequencies, root + 1, end, depth + 1);
69+
minimumCost = Math.min(minimumCost, currentCost);
70+
}
71+
return minimumCost;
72+
}
73+
}

0 commit comments

Comments
 (0)