Skip to content

Commit 0294a63

Browse files
Restore B&B solver
1 parent b6361b8 commit 0294a63

5 files changed

Lines changed: 943 additions & 0 deletions

File tree

pyqrackising/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,6 @@
1515
from .solve_maxcut_exact import solve_maxcut_exact
1616
from .solve_maxcut_exact_sparse import solve_maxcut_exact_sparse
1717
from .solve_maxcut_exact_streaming import solve_maxcut_exact_streaming
18+
from .solve_maxcut_bnb import solve_maxcut_bnb
19+
from .solve_maxcut_bnb_sparse import solve_maxcut_bnb_sparse
20+
from .solve_maxcut_bnb_streaming import solve_maxcut_bnb_streaming

pyqrackising/solve_maxcut_bnb.py

Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
# Copyright (C) 2026 Daniel Strano and the Qrack contributors
3+
#
4+
# Initial draft produced by (Anthropic) Claude (Sonnet 4.6).
5+
#
6+
# bnb_exact.py — warm-start branch-and-bound exact MAXCUT solver (dense)
7+
#
8+
# Objective convention (matches PyQrackIsing throughout):
9+
# Maximise sum_{(i,j) in cut} w_{ij}
10+
# i.e. edge (i,j) contributes w_{ij} when bits[i] != bits[j].
11+
# Edge weights may be positive or negative; no diagonal / self-loops.
12+
#
13+
# Parallelism strategy
14+
# --------------------
15+
# The serial LP relaxation bound used in earlier drafts costs O(n^3) per
16+
# node, making it useless at scale. We replace it with a parallelised
17+
# decoupled upper bound that costs O(n^2) parallel work per node:
18+
#
19+
# UB(node) = fixed_fixed_cut (exact)
20+
# + for each free-free edge (i,j): max(w_ij, 0)
21+
# + for each fixed-free edge (i,j): max(w_ij, 0) if i free, etc.
22+
#
23+
# This bound is valid and is computed entirely inside @njit(parallel=True)
24+
# kernels. A batch of frontier nodes is evaluated simultaneously, and leaf
25+
# scoring is also parallelised across the batch.
26+
27+
import time
28+
import os
29+
import networkx as nx
30+
import numpy as np
31+
from numba import njit, prange
32+
from .maxcut_tfim_util import (
33+
compute_cut,
34+
compute_energy,
35+
get_cut,
36+
heuristic_threshold,
37+
int_to_bitstring,
38+
opencl_context,
39+
)
40+
41+
dtype = opencl_context.dtype
42+
43+
44+
# ---------------------------------------------------------------------------
45+
# Parallel kernels
46+
# ---------------------------------------------------------------------------
47+
48+
@njit(parallel=True, cache=True)
49+
def _upper_bound_batch(G_m, fixed_vars, n, batch_size):
50+
"""
51+
Parallel upper bound for a batch of B&B nodes.
52+
fixed_vars : (batch_size, n) int8, -1=free, 0/1=fixed.
53+
For each node b, accumulates:
54+
- exact cut for fixed-fixed pairs
55+
- max(w, 0) for any edge touching at least one free variable
56+
Returns ub[batch_size].
57+
"""
58+
ub = np.empty(batch_size)
59+
for b in prange(batch_size):
60+
total = 0.0
61+
for i in range(n):
62+
fi = fixed_vars[b, i]
63+
for j in range(i + 1, n):
64+
w = G_m[i, j]
65+
if w == 0.0:
66+
continue
67+
fj = fixed_vars[b, j]
68+
if fi >= 0 and fj >= 0:
69+
if fi != fj:
70+
total += w
71+
else:
72+
if w > 0.0:
73+
total += w
74+
ub[b] = total
75+
return ub
76+
77+
78+
@njit(parallel=True, cache=True)
79+
def _eval_leaves_cut(G_m, fixed_vars, n, batch_size):
80+
vals = np.empty(batch_size)
81+
for b in prange(batch_size):
82+
cut = 0.0
83+
for i in range(n):
84+
bi = fixed_vars[b, i]
85+
for j in range(i + 1, n):
86+
if G_m[i, j] != 0.0 and bi != fixed_vars[b, j]:
87+
cut += G_m[i, j]
88+
vals[b] = cut
89+
return vals
90+
91+
92+
@njit(parallel=True, cache=True)
93+
def _eval_leaves_energy(G_m, fixed_vars, n, batch_size):
94+
vals = np.empty(batch_size)
95+
for b in prange(batch_size):
96+
energy = 0.0
97+
for i in range(n):
98+
bi = fixed_vars[b, i]
99+
for j in range(i + 1, n):
100+
val = G_m[i, j]
101+
energy += -val if bi == fixed_vars[b, j] else val
102+
vals[b] = energy
103+
return vals
104+
105+
106+
@njit(cache=True)
107+
def _influence_scores(G_m, fixed_row, n):
108+
"""Sum of absolute free-to-free edge weights for each free variable."""
109+
scores = np.full(n, -1.0)
110+
for i in range(n):
111+
if fixed_row[i] >= 0:
112+
continue
113+
s = 0.0
114+
for j in range(n):
115+
if i == j or fixed_row[j] >= 0:
116+
continue
117+
ii = i if i < j else j
118+
jj = j if i < j else i
119+
s += abs(G_m[ii, jj])
120+
scores[i] = s
121+
return scores
122+
123+
124+
# ---------------------------------------------------------------------------
125+
# B&B loop
126+
# ---------------------------------------------------------------------------
127+
128+
def _branch_and_bound(G_m, warm_theta, warm_energy, n, is_spin_glass,
129+
verbose=True, time_limit=None):
130+
best_bits = warm_theta.copy()
131+
best_value = warm_energy
132+
133+
if verbose:
134+
print(f"Warm-start incumbent: {best_value:.6f}")
135+
136+
root = np.full(n, -1, dtype=np.int8)
137+
stack = [root]
138+
139+
t_start = time.monotonic()
140+
nodes_explored = 0
141+
nodes_pruned = 0
142+
batch_cap = os.cpu_count() * 4
143+
144+
while stack:
145+
if time_limit is not None and (time.monotonic() - t_start) > time_limit:
146+
if verbose:
147+
print(f"Time limit reached. "
148+
f"Nodes: {nodes_explored}, Pruned: {nodes_pruned}")
149+
return best_bits, best_value, False
150+
151+
batch_size = min(batch_cap, len(stack))
152+
batch_nodes = [stack.pop() for _ in range(batch_size)]
153+
batch_arr = np.array(batch_nodes, dtype=np.int8)
154+
155+
nodes_explored += batch_size
156+
157+
ubs = _upper_bound_batch(G_m, batch_arr, n, batch_size)
158+
159+
leaves = []
160+
interior = []
161+
for k in range(batch_size):
162+
if ubs[k] <= best_value + 1e-9:
163+
nodes_pruned += 1
164+
continue
165+
if int(np.sum(batch_arr[k] < 0)) == 0:
166+
leaves.append(k)
167+
else:
168+
interior.append(k)
169+
170+
if leaves:
171+
leaf_arr = batch_arr[np.array(leaves, dtype=np.int64)]
172+
leaf_vals = (
173+
_eval_leaves_energy(G_m, leaf_arr, n, len(leaves))
174+
if is_spin_glass
175+
else _eval_leaves_cut(G_m, leaf_arr, n, len(leaves))
176+
)
177+
best_leaf = int(np.argmax(leaf_vals))
178+
if leaf_vals[best_leaf] > best_value:
179+
best_value = float(leaf_vals[best_leaf])
180+
best_bits = (leaf_arr[best_leaf] >= 1).copy()
181+
if verbose:
182+
print(f" New incumbent: {best_value:.6f}"
183+
f" (nodes: {nodes_explored})")
184+
185+
for k in interior:
186+
row = batch_arr[k]
187+
scores = _influence_scores(G_m, row, n)
188+
branch_var = int(np.argmax(scores))
189+
warm_val = int(best_bits[branch_var])
190+
for val in [warm_val, 1 - warm_val]:
191+
child = row.copy()
192+
child[branch_var] = np.int8(val)
193+
stack.append(child)
194+
195+
elapsed = time.monotonic() - t_start
196+
if verbose:
197+
print(f"\nExact optimum: {best_value:.6f}")
198+
print(f"Nodes explored: {nodes_explored} | "
199+
f"Pruned: {nodes_pruned} | Time: {elapsed:.3f}s")
200+
201+
return best_bits, best_value, True
202+
203+
204+
# ---------------------------------------------------------------------------
205+
# Public API
206+
# ---------------------------------------------------------------------------
207+
208+
def solve_maxcut_bnb(
209+
G,
210+
best_guess=None,
211+
quality=None,
212+
shots=None,
213+
is_spin_glass=False,
214+
anneal_t=None,
215+
anneal_h=None,
216+
repulsion_base=None,
217+
is_maxcut_gpu=True,
218+
verbose=True,
219+
time_limit=None,
220+
gray_iterations=None,
221+
gray_seed_multiple=None,
222+
):
223+
"""
224+
Exact MAXCUT/spin-glass solver: warm-start from spin_glass_solver then
225+
certify via branch and bound with parallel Numba kernels.
226+
227+
Accepts the same G input as spin_glass_solver (NetworkX graph or dense
228+
matrix) and the same warm-start formats (str, int, list, or None).
229+
230+
Parameters
231+
----------
232+
G : networkx.Graph or ndarray
233+
best_guess : str | int | list[bool] | None
234+
quality, shots, anneal_t, anneal_h, repulsion_base, is_maxcut_gpu,
235+
is_spin_glass, gray_iterations, gray_seed_multiple
236+
Forwarded to spin_glass_solver when best_guess is None.
237+
verbose : bool
238+
time_limit : float or None
239+
240+
Returns
241+
-------
242+
bitstring : str
243+
cut_value : float
244+
partition : tuple(list, list)
245+
min_energy : float
246+
certified : bool
247+
"""
248+
if isinstance(G, nx.Graph):
249+
nodes = list(G.nodes())
250+
n_qubits = len(nodes)
251+
G_m = nx.to_numpy_array(G, weight="weight", nonedge=0.0, dtype=dtype)
252+
else:
253+
n_qubits = len(G)
254+
nodes = list(range(n_qubits))
255+
G_m = np.asarray(G, dtype=dtype)
256+
257+
if n_qubits < 3:
258+
if n_qubits == 0:
259+
return "", 0, ([], []), 0, True
260+
if n_qubits == 1:
261+
return "0", 0, (nodes, []), 0, True
262+
if n_qubits == 2:
263+
weight = G_m[0, 1]
264+
if weight < 0.0:
265+
return "00", 0, (nodes, []), weight, True
266+
return "01", weight, ([nodes[0]], [nodes[1]]), -weight, True
267+
268+
bitstring = ""
269+
cut_value = None
270+
energy_value = None
271+
if isinstance(best_guess, str):
272+
bitstring = best_guess
273+
elif isinstance(best_guess, int):
274+
bitstring = int_to_bitstring(best_guess, n_qubits)
275+
elif isinstance(best_guess, list):
276+
bitstring = "".join(["1" if b else "0" for b in best_guess])
277+
else:
278+
if verbose:
279+
print("Running PyQrackIsing heuristic...")
280+
from .spin_glass_solver import spin_glass_solver
281+
kwargs = {}
282+
if quality is not None: kwargs["quality"] = quality
283+
if shots is not None: kwargs["shots"] = shots
284+
if anneal_t is not None: kwargs["anneal_t"] = anneal_t
285+
if anneal_h is not None: kwargs["anneal_h"] = anneal_h
286+
if repulsion_base is not None: kwargs["repulsion_base"] = repulsion_base
287+
if gray_iterations is not None: kwargs["gray_iterations"] = gray_iterations
288+
if gray_seed_multiple is not None: kwargs["gray_seed_multiple"] = gray_seed_multiple
289+
kwargs["is_maxcut_gpu"] = is_maxcut_gpu
290+
t0 = time.monotonic()
291+
bitstring, cut_value, _, energy_value = spin_glass_solver(
292+
G_m, is_spin_glass=is_spin_glass, **kwargs
293+
)
294+
if verbose:
295+
print(f"Heuristic value: {cut_value:.6f} ({time.monotonic()-t0:.3f}s)")
296+
297+
best_theta = np.array([b == "1" for b in list(bitstring)], dtype=np.bool_)
298+
if is_spin_glass:
299+
max_energy = compute_energy(best_theta, G_m, n_qubits) if energy_value is None else energy_value
300+
elif cut_value is None:
301+
max_energy = compute_cut(best_theta, G_m, n_qubits)
302+
else:
303+
max_energy = cut_value
304+
305+
if verbose:
306+
print("Starting branch and bound...")
307+
308+
best_theta, max_energy, certified = _branch_and_bound(
309+
G_m, best_theta, max_energy, n_qubits, is_spin_glass,
310+
verbose=verbose, time_limit=time_limit,
311+
)
312+
313+
bitstring, l, r = get_cut(best_theta, nodes, n_qubits)
314+
if is_spin_glass:
315+
cut_value = compute_cut(best_theta, G_m, n_qubits)
316+
min_energy = -max_energy
317+
else:
318+
cut_value = max_energy
319+
min_energy = compute_energy(best_theta, G_m, n_qubits)
320+
321+
return bitstring, float(cut_value), (l, r), float(min_energy), certified

0 commit comments

Comments
 (0)