Skip to content

Commit a8fb675

Browse files
Optimize repulsion sampling
1 parent 375e231 commit a8fb675

4 files changed

Lines changed: 31 additions & 42 deletions

File tree

pyqrackising/maxcut_tfim.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,8 @@
3434

3535

3636
@njit(cache=True)
37-
def update_repulsion_choice(G_m, weights, n, used, node, repulsion_base):
37+
def update_repulsion_choice(G_m, weights, n, node, repulsion_base):
3838
# Select node
39-
used[node] = True
4039
weights[node] = 0.0
4140

4241
if abs(1.0 - repulsion_base) <= epsilon:
@@ -55,25 +54,24 @@ def local_repulsion_choice(G_m, repulsion_base, n, m, s):
5554
- After choosing a node, its neighbors' probabilities are further reduced
5655
"""
5756

58-
used = np.zeros(n, dtype=np.bool_) # False = available, True = used
59-
6057
# First bit:
6158
node = s % n
6259
if m == 1:
60+
used = np.zeros(n, dtype=np.bool_)
6361
used[node] = True
6462
return used
6563

6664
weights = np.ones(n, dtype=np.float64)
67-
update_repulsion_choice(G_m, weights, n, used, node, repulsion_base)
65+
update_repulsion_choice(G_m, weights, n, node, repulsion_base)
6866

6967
for _ in range(1, m - 1):
70-
node = bit_pick(weights, used, n)
71-
update_repulsion_choice(G_m, weights, n, used, node, repulsion_base)
68+
node = bit_pick(weights, n)
69+
update_repulsion_choice(G_m, weights, n, node, repulsion_base)
7270

73-
node = bit_pick(weights, used, n)
74-
used[node] = True
71+
node = bit_pick(weights, n)
72+
weights[node] = 0.0
7573

76-
return used
74+
return weights == 0.0
7775

7876

7977
@njit(parallel=True, cache=True)

pyqrackising/maxcut_tfim_sparse.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,22 +35,22 @@
3535

3636

3737
@njit(cache=True)
38-
def update_repulsion_choice(G_data, G_rows, G_cols, weights, n, used, node, repulsion_base):
38+
def update_repulsion_choice(G_data, G_rows, G_cols, weights, n, node, repulsion_base):
3939
# Select node
40-
used[node] = True
40+
weights[node] = 0.0
4141

4242
if abs(1.0 - repulsion_base) <= epsilon:
4343
return
4444

4545
# Repulsion: penalize neighbors
4646
for j in range(G_rows[node], G_rows[node + 1]):
4747
nbr = G_cols[j]
48-
if used[nbr]:
48+
if weights[nbr] == 0.0:
4949
continue
5050
weights[nbr] *= repulsion_base ** (-G_data[j])
5151

5252
for nbr in range(node):
53-
if used[nbr]:
53+
if weights[nbr] == 0.0:
5454
continue
5555
start = G_rows[nbr]
5656
end = G_rows[nbr + 1]
@@ -68,25 +68,24 @@ def local_repulsion_choice(G_data, G_rows, G_cols, repulsion_base, n, m, s):
6868
- After choosing a node, its neighbors' probabilities are further reduced
6969
"""
7070

71-
used = np.zeros(n, dtype=np.bool_) # False = available, True = used
72-
7371
# First bit:
7472
node = s % n
7573
if m == 1:
74+
used = np.zeros(n, dtype=np.bool_)
7675
used[node] = True
7776
return used
7877

7978
weights = np.ones(n, dtype=np.float64)
80-
update_repulsion_choice(G_data, G_rows, G_cols, weights, n, used, node, repulsion_base)
79+
update_repulsion_choice(G_data, G_rows, G_cols, weights, n, node, repulsion_base)
8180

8281
for _ in range(1, m - 1):
83-
node = bit_pick(weights, used, n)
84-
update_repulsion_choice(G_data, G_rows, G_cols, weights, n, used, node, repulsion_base)
82+
node = bit_pick(weights, n)
83+
update_repulsion_choice(G_data, G_rows, G_cols, weights, n, node, repulsion_base)
8584

86-
node = bit_pick(weights, used, n)
87-
used[node] = True
85+
node = bit_pick(weights, n)
86+
weights[node] = 0.0
8887

89-
return used
88+
return weights == 0.0
9089

9190

9291
@njit(parallel=True, cache=True)

pyqrackising/maxcut_tfim_streaming.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,16 @@
2323

2424

2525
@njit(cache=True)
26-
def update_repulsion_choice(G_func, nodes, weights, n, used, node, repulsion_base):
26+
def update_repulsion_choice(G_func, nodes, weights, n, node, repulsion_base):
2727
# Select node
28-
used[node] = True
28+
weights[node] = 0.0
2929

3030
if abs(1.0 - repulsion_base) <= epsilon:
3131
return
3232

3333
# Repulsion: penalize neighbors
3434
for nbr in range(n):
35-
if used[nbr]:
35+
if weights[nbr] == 0.0:
3636
continue
3737
weights[nbr] *= repulsion_base ** (-G_func(nodes[node], nodes[nbr]))
3838

@@ -46,25 +46,24 @@ def local_repulsion_choice(G_func, nodes, repulsion_base, n, m, s):
4646
- After choosing a node, its neighbors' probabilities are further reduced
4747
"""
4848

49-
used = np.zeros(n, dtype=np.bool_) # False = available, True = used
50-
5149
# First bit:
5250
node = s % n
5351
if m == 1:
52+
used = np.zeros(n, dtype=np.bool_)
5453
used[node] = True
5554
return used
5655

5756
weights = np.ones(n, dtype=np.float64)
58-
update_repulsion_choice(G_func, nodes, weights, n, used, node, repulsion_base)
57+
update_repulsion_choice(G_func, nodes, weights, n, node, repulsion_base)
5958

6059
for _ in range(1, m - 1):
61-
node = bit_pick(weights, used, n)
62-
update_repulsion_choice(G_func, nodes, weights, n, used, node, repulsion_base)
60+
node = bit_pick(weights, n)
61+
update_repulsion_choice(G_func, nodes, weights, n, node, repulsion_base)
6362

64-
node = bit_pick(weights, used, n)
65-
used[node] = True
63+
node = bit_pick(weights, n)
64+
weights[node] = 0.0
6665

67-
return used
66+
return weights == 0.0
6867

6968

7069
@njit(parallel=True, cache=True)

pyqrackising/maxcut_tfim_util.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -708,21 +708,14 @@ def sample_mag(cum_prob):
708708

709709

710710
@njit(cache=True)
711-
def bit_pick(weights, used, n):
711+
def bit_pick(weights, n):
712712
# Count available
713-
p = 0.0
714-
for i in range(n):
715-
if used[i]:
716-
continue
717-
p += weights[i]
718-
713+
p = weights.sum()
719714
# Normalize & sample
720715
p *= np.random.rand()
721716
cum = 0.0
722717
node = 0
723718
for i in range(n):
724-
if used[i]:
725-
continue
726719
cum += weights[i]
727720
if p < cum:
728721
node = i

0 commit comments

Comments
 (0)