Skip to content

Commit 629845c

Browse files
Galley Branch and Bound - DFS version (#362)
* branch_and_bound port. No tests yet * added simple test to see if bb is working * removed redundant code like in greedy * more tests * small fixes * small fixes * small fixes * small precommit fixes * small fixes * small benchmark tests * robust benchmarks for bnb * precommit changes * add bnb example * add test case with much worse compile time * preliminary benchmark merge * fixed benchmarks with new statsfactory * fixed test_galley_bnb * redo tests in benchmarks * preliminary dfs for bnb * dfs implementation * test fix for dfs * suffix cost + test * comment out suffix * small fix * tests * dfs * precommit fixes * re-add tests * precommit fixes * dfs fix * dfs fixes * tests + benchmarks * precommit fixes * fixes * remove greedy from dfs * benchmarks * dfs fix * dfs fix * precommit fix * fix * fix * fixes * fixes * revert config * merge bnb * fix * fix ffuncs --------- Co-authored-by: kylebd99 <kylebd99@gmail.com>
1 parent 1648393 commit 629845c

4 files changed

Lines changed: 294 additions & 88 deletions

File tree

benchmarks/galley_benchmarks.py

Lines changed: 108 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
2-
Galley benchmarks: compile profile (with vs without connected components)
3-
and exact branch-and-bound vs greedy.
2+
Galley benchmarks: compile profile (with vs without connected components),
3+
exact branch-and-bound vs greedy, and BFS BnB vs DFS BnB (same optimum).
4+
45
56
Run: ``poetry run python benchmarks/galley_benchmarks.py``
67
"""
@@ -54,17 +55,15 @@ def _make_inner_loader():
5455

5556
_GALLEY_WITH = GalleyLogicalOptimizer(_make_inner_loader())
5657
_GALLEY_WITHOUT = GalleyLogicalOptimizer(_make_inner_loader(), use_components=False)
57-
_GALLEY_GREEDY = GalleyLogicalOptimizer(
58-
_make_inner_loader(), use_exact_branch_and_bound=False
59-
)
60-
_GALLEY_EXACT_BNB = GalleyLogicalOptimizer(
61-
_make_inner_loader(), use_exact_branch_and_bound=True
62-
)
58+
_GALLEY_GREEDY = GalleyLogicalOptimizer(_make_inner_loader(), optimizer="greedy")
59+
_GALLEY_EXACT_BNB = GalleyLogicalOptimizer(_make_inner_loader(), optimizer="bfs")
60+
_GALLEY_EXACT_BNB_DFS = GalleyLogicalOptimizer(_make_inner_loader(), optimizer="dfs")
6361

6462
GALLEY_COMPILE_PROFILE_WITH = LogicNormalizer(LogicExecutor(_GALLEY_WITH))
6563
GALLEY_COMPILE_PROFILE_WITHOUT = LogicNormalizer(LogicExecutor(_GALLEY_WITHOUT))
6664
GALLEY_PIPELINE_GREEDY = LogicNormalizer(LogicExecutor(_GALLEY_GREEDY))
6765
GALLEY_PIPELINE_EXACT_BNB = LogicNormalizer(LogicExecutor(_GALLEY_EXACT_BNB))
66+
GALLEY_PIPELINE_EXACT_BNB_DFS = LogicNormalizer(LogicExecutor(_GALLEY_EXACT_BNB_DFS))
6867

6968

7069
def _run_and_time(
@@ -218,6 +217,23 @@ def make_bnb_slow_example() -> LazyTensor:
218217
return make_chain_expr_from_shapes([(5, 5) for _ in range(12)])
219218

220219

220+
# Consecutive dimensions of the eight-matrix chain. Each adjacent pair is one
# matrix shape, so neighbouring matrices share their inner dimension.
_BNB_CORE_DFS_WINS_DIMS: tuple[int, ...] = (8, 3, 28, 13, 6, 11, 52, 61, 7)

# Shapes consumed by ``make_bnb_core_dfs_wins_chain``:
# (8, 3), (3, 28), (28, 13), (13, 6), (6, 11), (11, 52), (52, 61), (61, 7).
_BNB_CORE_DFS_WINS_SHAPES: list[tuple[int, int]] = list(
    zip(_BNB_CORE_DFS_WINS_DIMS, _BNB_CORE_DFS_WINS_DIMS[1:])
)
230+
231+
232+
def make_bnb_core_dfs_wins_chain() -> LazyTensor:
    """Build the eight-matrix matmul chain used by the BFS vs DFS benchmarks.

    The shapes are taken from ``_BNB_CORE_DFS_WINS_SHAPES`` and handed to
    ``make_chain_expr_from_shapes`` unchanged.
    """
    chain_shapes = _BNB_CORE_DFS_WINS_SHAPES
    return make_chain_expr_from_shapes(chain_shapes)
235+
236+
221237
# =============================================================================
222238
# Timing helpers
223239
# =============================================================================
@@ -308,6 +324,37 @@ def run_greedy() -> tuple[Any, dict[str, float]]:
308324
)
309325

310326

327+
def time_galley_bnb_bfs_vs_dfs_compile_profile(
    expr,
    *,
    n: int = DEFAULT_BNB_PROFILE_N,
    recursion_limit: int | None = None,
) -> tuple[dict[str, float], dict[str, float]]:
    """
    Profile the BFS exact-BnB pipeline against the DFS exact-BnB pipeline.

    Both runners rebuild the plan from ``expr`` on every invocation and share
    one (initially empty) bindings dict. The actual repetition and averaging
    of ``optimize_plan_s`` / ``downstream_s`` is delegated to
    ``_time_profile_pair``; its result (BFS profile first, DFS profile second)
    is returned as-is.

    Args:
        expr: Lazy expression to plan; passed to ``plan_from_expr``.
        n: Forwarded to ``_time_profile_pair`` as ``n``.
        recursion_limit: Forwarded to ``_time_profile_pair`` unchanged.
    """
    shared_bindings: dict = {}

    def _timed(pipeline, optimizer) -> tuple[Any, dict[str, float]]:
        # A fresh plan per call so neither variant reuses the other's plan.
        return _run_and_time(
            pipeline, optimizer, plan_from_expr(expr), shared_bindings
        )

    def run_bfs() -> tuple[Any, dict[str, float]]:
        return _timed(GALLEY_PIPELINE_EXACT_BNB, _GALLEY_EXACT_BNB)

    def run_dfs() -> tuple[Any, dict[str, float]]:
        return _timed(GALLEY_PIPELINE_EXACT_BNB_DFS, _GALLEY_EXACT_BNB_DFS)

    return _time_profile_pair(
        run_bfs,
        run_dfs,
        n=n,
        recursion_limit=recursion_limit,
    )
356+
357+
311358
def _print_profile_comparison(
312359
title: str,
313360
rows: list[tuple[str, dict[str, float], tuple[float, int] | None]],
@@ -346,11 +393,22 @@ def _exact_greedy_plan_stats(
346393
) -> tuple[float, float, int, int]:
347394
aq_e = aq_factory()
348395
aq_g = aq_factory()
349-
queries_exact, cost_exact = pruned_query_to_plan(aq_e, use_greedy=False)
350-
queries_greedy, cost_greedy = pruned_query_to_plan(aq_g, use_greedy=True)
396+
queries_exact, cost_exact = pruned_query_to_plan(aq_e, optimizer="bfs")
397+
queries_greedy, cost_greedy = pruned_query_to_plan(aq_g, optimizer="greedy")
351398
return cost_exact, cost_greedy, len(queries_exact), len(queries_greedy)
352399

353400

401+
def _bfs_dfs_exact_plan_stats(
    aq_factory: Callable[[], AnnotatedQuery],
) -> tuple[float, float, int, int]:
    """Plan one fresh query with BFS exact BnB and another with DFS exact BnB.

    Args:
        aq_factory: Zero-argument callable; invoked twice so each optimizer
            gets its own ``AnnotatedQuery``.

    Returns:
        ``(bfs_cost, dfs_cost, bfs_query_count, dfs_query_count)``. The two
        costs are expected to match; this function does not check that.
    """
    # Both queries are created before any planning happens.
    bfs_query, dfs_query = aq_factory(), aq_factory()
    bfs_plan, bfs_cost = pruned_query_to_plan(bfs_query, optimizer="bfs")
    dfs_plan, dfs_cost = pruned_query_to_plan(dfs_query, optimizer="dfs")
    return bfs_cost, dfs_cost, len(bfs_plan), len(dfs_plan)
410+
411+
354412
def _run_compile_case(case: BenchmarkCaseSpec) -> None:
355413
print(f"Compile benchmark: {case.title}...", flush=True)
356414
expr = case.build_expr()
@@ -382,6 +440,30 @@ def _run_bnb_case(case: BenchmarkCaseSpec) -> None:
382440
)
383441

384442

443+
def _run_bnb_bfs_vs_dfs_case(case: BenchmarkCaseSpec) -> None:
    """Run one BFS vs DFS exact-BnB benchmark case and print the comparison.

    Steps: build the expression, time both pipelines, compute plan cost and
    subquery count for each optimizer, verify the costs agree, then print.

    Raises:
        ValueError: If ``case.recursion_limit`` is set; these cases must
            leave it as ``None``.
        AssertionError: If the BFS and DFS costs differ by more than an
            absolute tolerance of ``1e-6``.
    """
    if case.recursion_limit is not None:
        raise ValueError("BFS vs DFS BnB case must use recursion_limit=None")

    print(f"BnB BFS vs DFS benchmark: {case.title}...", flush=True)
    expr = case.build_expr()

    # Always take at least three timing samples, whatever the default is.
    repeats = max(3, DEFAULT_BNB_PROFILE_N)
    bfs_profile, dfs_profile = time_galley_bnb_bfs_vs_dfs_compile_profile(
        expr, n=repeats
    )

    def fresh_query():
        return _annotated_query_from_lazy_expr(expr)

    bfs_cost, dfs_cost, bfs_count, dfs_count = _bfs_dfs_exact_plan_stats(
        fresh_query
    )

    # Purely absolute tolerance: rtol is disabled on purpose.
    costs_agree = np.isclose(bfs_cost, dfs_cost, rtol=0.0, atol=1e-6)
    if not costs_agree:
        raise AssertionError(
            f"BFS vs DFS optimal cost mismatch: bfs={bfs_cost!r} dfs={dfs_cost!r}"
        )

    bfs_row = ("Exact BnB (BFS)", bfs_profile, (bfs_cost, bfs_count))
    dfs_row = ("Exact BnB (DFS)", dfs_profile, (dfs_cost, dfs_count))
    _print_profile_comparison(case.title, [bfs_row, dfs_row])
465+
466+
385467
def _compile_benchmark_cases(rng: np.random.Generator) -> tuple[BenchmarkCaseSpec, ...]:
386468
return (
387469
BenchmarkCaseSpec(
@@ -411,6 +493,13 @@ def _compile_benchmark_cases(rng: np.random.Generator) -> tuple[BenchmarkCaseSpe
411493
),
412494
)
413495

496+
# Benchmark cases iterated by ``main_bnb_bfs_vs_dfs_benchmarks``; currently a
# single eight-matrix chain.
BNB_BFS_VS_DFS_CASES: tuple[BenchmarkCaseSpec, ...] = (
    BenchmarkCaseSpec(
        build_expr=make_bnb_core_dfs_wins_chain,
        title="BFS vs DFS — eight-matrix chain",
    ),
)
502+
414503

415504
# =============================================================================
416505
# main sections (all tests from both original files)
@@ -430,6 +519,12 @@ def main_bnb_benchmarks() -> None:
430519
_run_bnb_case(case)
431520

432521

522+
def main_bnb_bfs_vs_dfs_benchmarks() -> None:
    """Run every case in ``BNB_BFS_VS_DFS_CASES``, in declaration order.

    Each case compares BFS exact BnB with DFS exact BnB: the plan cost is
    expected to be the same, so the figure to compare is ``optimize_plan_s``.
    """
    for bench_case in BNB_BFS_VS_DFS_CASES:
        _run_bnb_bfs_vs_dfs_case(bench_case)
526+
527+
433528
def main() -> None:
434529
print("", flush=True)
435530
print("### Galley compile benchmarks (with vs without components) ###", flush=True)
@@ -440,6 +535,9 @@ def main() -> None:
440535
print("### Galley BnB vs greedy benchmarks ###", flush=True)
441536
main_bnb_benchmarks()
442537
print("", flush=True)
538+
print("### Galley BFS BnB vs DFS BnB benchmarks ###", flush=True)
539+
main_bnb_bfs_vs_dfs_benchmarks()
540+
print("", flush=True)
443541
print("Done.", flush=True)
444542

445543

0 commit comments

Comments
 (0)