@@ -321,7 +321,7 @@ def compute_energy(sample, G_m, n_qubits):
321321 orig = np .repeat (sample , n_qubits )
322322 result = orig .copy ()
323323 result = result .reshape (- 1 , n_qubits ) ^ orig .reshape (- 2 , n_qubits )
324- result = (result * 2 ) - 1
324+ result = (result << 1 ) - 1
325325
326326 return (result * G_m ).sum () / 2.0
327327
@@ -337,47 +337,12 @@ def compute_cut(sample, G_m, n_qubits):
337337
338338@njit (cache = True )
339339def compute_cut_diff (u , sample , G_m , n_qubits ):
340- energy = 0.0
341- u_bit = sample [u ]
342- G_u = G_m [u ]
343- for v in range (u ):
344- val = G_u [v ]
345- energy += - val if u_bit == sample [v ] else val
346- for v in range (u + 1 , n_qubits ):
347- val = G_u [v ]
348- energy += - val if u_bit == sample [v ] else val
349-
350- return energy
340+ return G_m [u ] * (((sample [u ] ^ sample ) << 1 ) - 1 )
351341
352342
353343@njit (cache = True )
354344def compute_cut_diff_2 (k , l , sample , G_m , n_qubits ):
355- if l < k :
356- t = k
357- k = l
358- l = t
359- energy = 0.0
360- k_bit = sample [k ]
361- l_bit = sample [l ]
362- G_k = G_m [k ]
363- G_l = G_m [l ]
364- for v in range (k ):
365- val = G_k [v ]
366- energy += - val if k_bit == sample [v ] else val
367- val = G_l [v ]
368- energy += - val if l_bit == sample [v ] else val
369- for v in range (k + 1 , l ):
370- val = G_k [v ]
371- energy += - val if k_bit == sample [v ] else val
372- val = G_l [v ]
373- energy += - val if l_bit == sample [v ] else val
374- for v in range (l + 1 , n_qubits ):
375- val = G_k [v ]
376- energy += - val if k_bit == sample [v ] else val
377- val = G_l [v ]
378- energy += - val if l_bit == sample [v ] else val
379-
380- return energy
345+ return compute_cut_diff (k , sample , G_m , n_qubits ) + compute_cut_diff (l , sample , G_m , n_qubits )
381346
382347
383348@njit (cache = True )
0 commit comments