|
| 1 | +# SPDX-License-Identifier: LGPL-3.0-or-later |
| 2 | +# Copyright (C) 2026 Daniel Strano and the Qrack contributors |
| 3 | +# |
| 4 | +# Initial draft produced by (Anthropic) Claude (Sonnet 4.6). |
| 5 | +# |
| 6 | +# bnb_exact.py — warm-start branch-and-bound exact MAXCUT solver (dense) |
| 7 | +# |
| 8 | +# Objective convention (matches PyQrackIsing throughout): |
| 9 | +# Maximise sum_{(i,j) in cut} w_{ij} |
| 10 | +# i.e. edge (i,j) contributes w_{ij} when bits[i] != bits[j]. |
| 11 | +# Edge weights may be positive or negative; no diagonal / self-loops. |
| 12 | +# |
| 13 | +# Parallelism strategy |
| 14 | +# -------------------- |
| 15 | +# The serial LP relaxation bound used in earlier drafts costs O(n^3) per |
| 16 | +# node, making it useless at scale. We replace it with a parallelised |
| 17 | +# decoupled upper bound that costs O(n^2) parallel work per node: |
| 18 | +# |
| 19 | +# UB(node) = fixed_fixed_cut (exact) |
| 20 | +# + for each free-free edge (i,j): max(w_ij, 0) |
| 21 | +# + for each fixed-free edge (i,j): max(w_ij, 0) if i free, etc. |
| 22 | +# |
| 23 | +# This bound is valid and is computed entirely inside @njit(parallel=True) |
| 24 | +# kernels. A batch of frontier nodes is evaluated simultaneously, and leaf |
| 25 | +# scoring is also parallelised across the batch. |
| 26 | + |
| 27 | +import time |
| 28 | +import os |
| 29 | +import networkx as nx |
| 30 | +import numpy as np |
| 31 | +from numba import njit, prange |
| 32 | +from .maxcut_tfim_util import ( |
| 33 | + compute_cut, |
| 34 | + compute_energy, |
| 35 | + get_cut, |
| 36 | + heuristic_threshold, |
| 37 | + int_to_bitstring, |
| 38 | + opencl_context, |
| 39 | +) |
| 40 | + |
| 41 | +dtype = opencl_context.dtype |
| 42 | + |
| 43 | + |
| 44 | +# --------------------------------------------------------------------------- |
| 45 | +# Parallel kernels |
| 46 | +# --------------------------------------------------------------------------- |
| 47 | + |
| 48 | +@njit(parallel=True, cache=True) |
| 49 | +def _upper_bound_batch(G_m, fixed_vars, n, batch_size): |
| 50 | + """ |
| 51 | + Parallel upper bound for a batch of B&B nodes. |
| 52 | + fixed_vars : (batch_size, n) int8, -1=free, 0/1=fixed. |
| 53 | + For each node b, accumulates: |
| 54 | + - exact cut for fixed-fixed pairs |
| 55 | + - max(w, 0) for any edge touching at least one free variable |
| 56 | + Returns ub[batch_size]. |
| 57 | + """ |
| 58 | + ub = np.empty(batch_size) |
| 59 | + for b in prange(batch_size): |
| 60 | + total = 0.0 |
| 61 | + for i in range(n): |
| 62 | + fi = fixed_vars[b, i] |
| 63 | + for j in range(i + 1, n): |
| 64 | + w = G_m[i, j] |
| 65 | + if w == 0.0: |
| 66 | + continue |
| 67 | + fj = fixed_vars[b, j] |
| 68 | + if fi >= 0 and fj >= 0: |
| 69 | + if fi != fj: |
| 70 | + total += w |
| 71 | + else: |
| 72 | + if w > 0.0: |
| 73 | + total += w |
| 74 | + ub[b] = total |
| 75 | + return ub |
| 76 | + |
| 77 | + |
| 78 | +@njit(parallel=True, cache=True) |
| 79 | +def _eval_leaves_cut(G_m, fixed_vars, n, batch_size): |
| 80 | + vals = np.empty(batch_size) |
| 81 | + for b in prange(batch_size): |
| 82 | + cut = 0.0 |
| 83 | + for i in range(n): |
| 84 | + bi = fixed_vars[b, i] |
| 85 | + for j in range(i + 1, n): |
| 86 | + if G_m[i, j] != 0.0 and bi != fixed_vars[b, j]: |
| 87 | + cut += G_m[i, j] |
| 88 | + vals[b] = cut |
| 89 | + return vals |
| 90 | + |
| 91 | + |
| 92 | +@njit(parallel=True, cache=True) |
| 93 | +def _eval_leaves_energy(G_m, fixed_vars, n, batch_size): |
| 94 | + vals = np.empty(batch_size) |
| 95 | + for b in prange(batch_size): |
| 96 | + energy = 0.0 |
| 97 | + for i in range(n): |
| 98 | + bi = fixed_vars[b, i] |
| 99 | + for j in range(i + 1, n): |
| 100 | + val = G_m[i, j] |
| 101 | + energy += -val if bi == fixed_vars[b, j] else val |
| 102 | + vals[b] = energy |
| 103 | + return vals |
| 104 | + |
| 105 | + |
| 106 | +@njit(cache=True) |
| 107 | +def _influence_scores(G_m, fixed_row, n): |
| 108 | + """Sum of absolute free-to-free edge weights for each free variable.""" |
| 109 | + scores = np.full(n, -1.0) |
| 110 | + for i in range(n): |
| 111 | + if fixed_row[i] >= 0: |
| 112 | + continue |
| 113 | + s = 0.0 |
| 114 | + for j in range(n): |
| 115 | + if i == j or fixed_row[j] >= 0: |
| 116 | + continue |
| 117 | + ii = i if i < j else j |
| 118 | + jj = j if i < j else i |
| 119 | + s += abs(G_m[ii, jj]) |
| 120 | + scores[i] = s |
| 121 | + return scores |
| 122 | + |
| 123 | + |
| 124 | +# --------------------------------------------------------------------------- |
| 125 | +# B&B loop |
| 126 | +# --------------------------------------------------------------------------- |
| 127 | + |
| 128 | +def _branch_and_bound(G_m, warm_theta, warm_energy, n, is_spin_glass, |
| 129 | + verbose=True, time_limit=None): |
| 130 | + best_bits = warm_theta.copy() |
| 131 | + best_value = warm_energy |
| 132 | + |
| 133 | + if verbose: |
| 134 | + print(f"Warm-start incumbent: {best_value:.6f}") |
| 135 | + |
| 136 | + root = np.full(n, -1, dtype=np.int8) |
| 137 | + stack = [root] |
| 138 | + |
| 139 | + t_start = time.monotonic() |
| 140 | + nodes_explored = 0 |
| 141 | + nodes_pruned = 0 |
| 142 | + batch_cap = os.cpu_count() * 4 |
| 143 | + |
| 144 | + while stack: |
| 145 | + if time_limit is not None and (time.monotonic() - t_start) > time_limit: |
| 146 | + if verbose: |
| 147 | + print(f"Time limit reached. " |
| 148 | + f"Nodes: {nodes_explored}, Pruned: {nodes_pruned}") |
| 149 | + return best_bits, best_value, False |
| 150 | + |
| 151 | + batch_size = min(batch_cap, len(stack)) |
| 152 | + batch_nodes = [stack.pop() for _ in range(batch_size)] |
| 153 | + batch_arr = np.array(batch_nodes, dtype=np.int8) |
| 154 | + |
| 155 | + nodes_explored += batch_size |
| 156 | + |
| 157 | + ubs = _upper_bound_batch(G_m, batch_arr, n, batch_size) |
| 158 | + |
| 159 | + leaves = [] |
| 160 | + interior = [] |
| 161 | + for k in range(batch_size): |
| 162 | + if ubs[k] <= best_value + 1e-9: |
| 163 | + nodes_pruned += 1 |
| 164 | + continue |
| 165 | + if int(np.sum(batch_arr[k] < 0)) == 0: |
| 166 | + leaves.append(k) |
| 167 | + else: |
| 168 | + interior.append(k) |
| 169 | + |
| 170 | + if leaves: |
| 171 | + leaf_arr = batch_arr[np.array(leaves, dtype=np.int64)] |
| 172 | + leaf_vals = ( |
| 173 | + _eval_leaves_energy(G_m, leaf_arr, n, len(leaves)) |
| 174 | + if is_spin_glass |
| 175 | + else _eval_leaves_cut(G_m, leaf_arr, n, len(leaves)) |
| 176 | + ) |
| 177 | + best_leaf = int(np.argmax(leaf_vals)) |
| 178 | + if leaf_vals[best_leaf] > best_value: |
| 179 | + best_value = float(leaf_vals[best_leaf]) |
| 180 | + best_bits = (leaf_arr[best_leaf] >= 1).copy() |
| 181 | + if verbose: |
| 182 | + print(f" New incumbent: {best_value:.6f}" |
| 183 | + f" (nodes: {nodes_explored})") |
| 184 | + |
| 185 | + for k in interior: |
| 186 | + row = batch_arr[k] |
| 187 | + scores = _influence_scores(G_m, row, n) |
| 188 | + branch_var = int(np.argmax(scores)) |
| 189 | + warm_val = int(best_bits[branch_var]) |
| 190 | + for val in [warm_val, 1 - warm_val]: |
| 191 | + child = row.copy() |
| 192 | + child[branch_var] = np.int8(val) |
| 193 | + stack.append(child) |
| 194 | + |
| 195 | + elapsed = time.monotonic() - t_start |
| 196 | + if verbose: |
| 197 | + print(f"\nExact optimum: {best_value:.6f}") |
| 198 | + print(f"Nodes explored: {nodes_explored} | " |
| 199 | + f"Pruned: {nodes_pruned} | Time: {elapsed:.3f}s") |
| 200 | + |
| 201 | + return best_bits, best_value, True |
| 202 | + |
| 203 | + |
| 204 | +# --------------------------------------------------------------------------- |
| 205 | +# Public API |
| 206 | +# --------------------------------------------------------------------------- |
| 207 | + |
| 208 | +def solve_maxcut_bnb( |
| 209 | + G, |
| 210 | + best_guess=None, |
| 211 | + quality=None, |
| 212 | + shots=None, |
| 213 | + is_spin_glass=False, |
| 214 | + anneal_t=None, |
| 215 | + anneal_h=None, |
| 216 | + repulsion_base=None, |
| 217 | + is_maxcut_gpu=True, |
| 218 | + verbose=True, |
| 219 | + time_limit=None, |
| 220 | + gray_iterations=None, |
| 221 | + gray_seed_multiple=None, |
| 222 | +): |
| 223 | + """ |
| 224 | + Exact MAXCUT/spin-glass solver: warm-start from spin_glass_solver then |
| 225 | + certify via branch and bound with parallel Numba kernels. |
| 226 | +
|
| 227 | + Accepts the same G input as spin_glass_solver (NetworkX graph or dense |
| 228 | + matrix) and the same warm-start formats (str, int, list, or None). |
| 229 | +
|
| 230 | + Parameters |
| 231 | + ---------- |
| 232 | + G : networkx.Graph or ndarray |
| 233 | + best_guess : str | int | list[bool] | None |
| 234 | + quality, shots, anneal_t, anneal_h, repulsion_base, is_maxcut_gpu, |
| 235 | + is_spin_glass, gray_iterations, gray_seed_multiple |
| 236 | + Forwarded to spin_glass_solver when best_guess is None. |
| 237 | + verbose : bool |
| 238 | + time_limit : float or None |
| 239 | +
|
| 240 | + Returns |
| 241 | + ------- |
| 242 | + bitstring : str |
| 243 | + cut_value : float |
| 244 | + partition : tuple(list, list) |
| 245 | + min_energy : float |
| 246 | + certified : bool |
| 247 | + """ |
| 248 | + if isinstance(G, nx.Graph): |
| 249 | + nodes = list(G.nodes()) |
| 250 | + n_qubits = len(nodes) |
| 251 | + G_m = nx.to_numpy_array(G, weight="weight", nonedge=0.0, dtype=dtype) |
| 252 | + else: |
| 253 | + n_qubits = len(G) |
| 254 | + nodes = list(range(n_qubits)) |
| 255 | + G_m = np.asarray(G, dtype=dtype) |
| 256 | + |
| 257 | + if n_qubits < 3: |
| 258 | + if n_qubits == 0: |
| 259 | + return "", 0, ([], []), 0, True |
| 260 | + if n_qubits == 1: |
| 261 | + return "0", 0, (nodes, []), 0, True |
| 262 | + if n_qubits == 2: |
| 263 | + weight = G_m[0, 1] |
| 264 | + if weight < 0.0: |
| 265 | + return "00", 0, (nodes, []), weight, True |
| 266 | + return "01", weight, ([nodes[0]], [nodes[1]]), -weight, True |
| 267 | + |
| 268 | + bitstring = "" |
| 269 | + cut_value = None |
| 270 | + energy_value = None |
| 271 | + if isinstance(best_guess, str): |
| 272 | + bitstring = best_guess |
| 273 | + elif isinstance(best_guess, int): |
| 274 | + bitstring = int_to_bitstring(best_guess, n_qubits) |
| 275 | + elif isinstance(best_guess, list): |
| 276 | + bitstring = "".join(["1" if b else "0" for b in best_guess]) |
| 277 | + else: |
| 278 | + if verbose: |
| 279 | + print("Running PyQrackIsing heuristic...") |
| 280 | + from .spin_glass_solver import spin_glass_solver |
| 281 | + kwargs = {} |
| 282 | + if quality is not None: kwargs["quality"] = quality |
| 283 | + if shots is not None: kwargs["shots"] = shots |
| 284 | + if anneal_t is not None: kwargs["anneal_t"] = anneal_t |
| 285 | + if anneal_h is not None: kwargs["anneal_h"] = anneal_h |
| 286 | + if repulsion_base is not None: kwargs["repulsion_base"] = repulsion_base |
| 287 | + if gray_iterations is not None: kwargs["gray_iterations"] = gray_iterations |
| 288 | + if gray_seed_multiple is not None: kwargs["gray_seed_multiple"] = gray_seed_multiple |
| 289 | + kwargs["is_maxcut_gpu"] = is_maxcut_gpu |
| 290 | + t0 = time.monotonic() |
| 291 | + bitstring, cut_value, _, energy_value = spin_glass_solver( |
| 292 | + G_m, is_spin_glass=is_spin_glass, **kwargs |
| 293 | + ) |
| 294 | + if verbose: |
| 295 | + print(f"Heuristic value: {cut_value:.6f} ({time.monotonic()-t0:.3f}s)") |
| 296 | + |
| 297 | + best_theta = np.array([b == "1" for b in list(bitstring)], dtype=np.bool_) |
| 298 | + if is_spin_glass: |
| 299 | + max_energy = compute_energy(best_theta, G_m, n_qubits) if energy_value is None else energy_value |
| 300 | + elif cut_value is None: |
| 301 | + max_energy = compute_cut(best_theta, G_m, n_qubits) |
| 302 | + else: |
| 303 | + max_energy = cut_value |
| 304 | + |
| 305 | + if verbose: |
| 306 | + print("Starting branch and bound...") |
| 307 | + |
| 308 | + best_theta, max_energy, certified = _branch_and_bound( |
| 309 | + G_m, best_theta, max_energy, n_qubits, is_spin_glass, |
| 310 | + verbose=verbose, time_limit=time_limit, |
| 311 | + ) |
| 312 | + |
| 313 | + bitstring, l, r = get_cut(best_theta, nodes, n_qubits) |
| 314 | + if is_spin_glass: |
| 315 | + cut_value = compute_cut(best_theta, G_m, n_qubits) |
| 316 | + min_energy = -max_energy |
| 317 | + else: |
| 318 | + cut_value = max_energy |
| 319 | + min_energy = compute_energy(best_theta, G_m, n_qubits) |
| 320 | + |
| 321 | + return bitstring, float(cut_value), (l, r), float(min_energy), certified |
0 commit comments