Skip to content

Commit f9cbbf7

Browse files
Cache all numba that we can
1 parent 6bef63a commit f9cbbf7

1 file changed

Lines changed: 23 additions & 23 deletions

File tree

pyqrackising/maxcut_tfim_util.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def make_best_theta_buf_64(theta):
301301
return theta_buf
302302

303303

304-
@njit(parallel=True)
304+
@njit(parallel=True, cache=True)
305305
def convert_bool_to_uint(samples):
306306
shots = samples.shape[0]
307307
n = samples.shape[1]
@@ -316,7 +316,7 @@ def convert_bool_to_uint(samples):
316316
return theta
317317

318318

319-
@njit
319+
@njit(cache=True)
320320
def compute_energy(sample, G_m, n_qubits):
321321
energy = 0.0
322322
for u in range(n_qubits):
@@ -328,7 +328,7 @@ def compute_energy(sample, G_m, n_qubits):
328328
return energy
329329

330330

331-
@njit
331+
@njit(cache=True)
332332
def compute_cut(sample, G_m, n_qubits):
333333
l, r = get_cut_base(sample, n_qubits)
334334
cut = 0
@@ -339,7 +339,7 @@ def compute_cut(sample, G_m, n_qubits):
339339
return cut
340340

341341

342-
@njit
342+
@njit(cache=True)
343343
def compute_cut_diff(u, sample, G_m, n_qubits):
344344
energy = 0.0
345345
u_bit = sample[u]
@@ -354,7 +354,7 @@ def compute_cut_diff(u, sample, G_m, n_qubits):
354354
return energy
355355

356356

357-
@njit
357+
@njit(cache=True)
358358
def compute_cut_diff_2(k, l, sample, G_m, n_qubits):
359359
if l < k:
360360
t = k
@@ -384,7 +384,7 @@ def compute_cut_diff_2(k, l, sample, G_m, n_qubits):
384384
return energy
385385

386386

387-
@njit
387+
@njit(cache=True)
388388
def compute_energy_sparse(sample, G_data, G_rows, G_cols, n_qubits):
389389
energy = 0.0
390390
for u in range(n_qubits):
@@ -396,7 +396,7 @@ def compute_energy_sparse(sample, G_data, G_rows, G_cols, n_qubits):
396396
return energy
397397

398398

399-
@njit
399+
@njit(cache=True)
400400
def compute_cut_sparse(sample, G_data, G_rows, G_cols, n_qubits):
401401
l, r = get_cut_base(sample, n_qubits)
402402
s = l if len(l) < len(r) else r
@@ -410,7 +410,7 @@ def compute_cut_sparse(sample, G_data, G_rows, G_cols, n_qubits):
410410
return cut
411411

412412

413-
@njit
413+
@njit(cache=True)
414414
def compute_energy_streaming(sample, G_func, nodes, n_qubits):
415415
energy = 0.0
416416
for u in range(n_qubits):
@@ -422,7 +422,7 @@ def compute_energy_streaming(sample, G_func, nodes, n_qubits):
422422
return energy
423423

424424

425-
@njit
425+
@njit(cache=True)
426426
def compute_cut_streaming(sample, G_func, nodes, n_qubits):
427427
l, r = get_cut_base(sample, n_qubits)
428428
cut = 0
@@ -433,7 +433,7 @@ def compute_cut_streaming(sample, G_func, nodes, n_qubits):
433433
return cut
434434

435435

436-
@njit
436+
@njit(cache=True)
437437
def compute_cut_diff_streaming(u, sample, G_func, nodes, n_qubits):
438438
energy = 0.0
439439
u_bit = sample[u]
@@ -447,7 +447,7 @@ def compute_cut_diff_streaming(u, sample, G_func, nodes, n_qubits):
447447
return energy
448448

449449

450-
@njit
450+
@njit(cache=True)
451451
def compute_cut_diff_2_streaming(k, l, sample, G_func, nodes, n_qubits):
452452
if l < k:
453453
t = k
@@ -475,7 +475,7 @@ def compute_cut_diff_2_streaming(k, l, sample, G_func, nodes, n_qubits):
475475
return energy
476476

477477

478-
@njit
478+
@njit(cache=True)
479479
def compute_cut_diff_between(o_theta, n_theta, G_m, n_qubits):
480480
energy = 0.0
481481

@@ -494,7 +494,7 @@ def compute_cut_diff_between(o_theta, n_theta, G_m, n_qubits):
494494
return energy
495495

496496

497-
@njit
497+
@njit(cache=True)
498498
def compute_cut_diff_between_streaming(o_theta, n_theta, G_func, nodes, n_qubits):
499499
energy = 0.0
500500

@@ -512,7 +512,7 @@ def compute_cut_diff_between_streaming(o_theta, n_theta, G_func, nodes, n_qubits
512512
return energy
513513

514514

515-
@njit
515+
@njit(cache=True)
516516
def get_cut(solution, nodes, n):
517517
bit_string = ""
518518
l, r = [], []
@@ -527,7 +527,7 @@ def get_cut(solution, nodes, n):
527527
return bit_string, l, r
528528

529529

530-
@njit
530+
@njit(cache=True)
531531
def get_cut_base(solution, n):
532532
l, r = [], []
533533
for i in range(n):
@@ -544,7 +544,7 @@ def int_to_bitstring(integer, length):
544544
return (bin(integer)[2:].zfill(length))[::-1]
545545

546546

547-
@njit
547+
@njit(cache=True)
548548
def binary_search(l, t):
549549
left = 0
550550
right = len(l) - 1
@@ -575,7 +575,7 @@ def to_scipy_sparse_upper_triangular(G, nodes, n_nodes):
575575
return lil.tocsr()
576576

577577

578-
@njit(parallel=True)
578+
@njit(parallel=True, cache=True)
579579
def init_theta(h_mult, n_qubits, J_eff, degrees):
580580
theta = np.empty(n_qubits, dtype=np.float64)
581581
h_mult = abs(h_mult)
@@ -608,7 +608,7 @@ def init_thresholds(n_qubits):
608608
return thresholds
609609

610610

611-
@njit
611+
@njit(cache=True)
612612
def probability_by_hamming_weight(J, h, z, theta, t, n_bias, normalized=True, omega=1.5 * np.pi):
613613
zJ = z * J
614614
theta_c = ((np.pi if J > 0 else -np.pi) / 2) if abs(zJ) < epsilon else np.arcsin(max(-1.0, min(1.0, h / zJ)))
@@ -634,7 +634,7 @@ def probability_by_hamming_weight(J, h, z, theta, t, n_bias, normalized=True, om
634634
return bias
635635

636636

637-
@njit
637+
@njit(cache=True)
638638
def maxcut_hamming_cdf(hamming_prob, n_qubits, J_func, degrees, quality, tot_t, h_mult, omega=1.5 * np.pi):
639639
n_steps = 1 << quality
640640
delta_t = tot_t / n_steps
@@ -670,7 +670,7 @@ def maxcut_hamming_cdf(hamming_prob, n_qubits, J_func, degrees, quality, tot_t,
670670
return cum_prob
671671

672672

673-
@njit
673+
@njit(cache=True)
674674
def sample_mag(cum_prob):
675675
p = np.random.random()
676676
m = 0
@@ -690,7 +690,7 @@ def sample_mag(cum_prob):
690690
return m
691691

692692

693-
@njit
693+
@njit(cache=True)
694694
def bit_pick(weights, used, n):
695695
# Count available
696696
p = 0.0
@@ -714,7 +714,7 @@ def bit_pick(weights, used, n):
714714
return node
715715

716716

717-
@njit
717+
@njit(cache=True)
718718
def gray_code_next(state, curr_idx, offset):
719719
prev = curr_idx
720720
curr = curr_idx + 1
@@ -727,7 +727,7 @@ def gray_code_next(state, curr_idx, offset):
727727
return flip_bit
728728

729729

730-
@njit
730+
@njit(cache=True)
731731
def gray_mutation(index, seed_bits, offset):
732732
"""Apply Gray-code-indexed bit flips to a seed bitstring."""
733733
n = seed_bits.shape[0]

0 commit comments

Comments
 (0)