Skip to content

Commit 8c8313f

Browse files
committed
Added flags for compatibility of numba with python 3.8
1 parent 262b75b commit 8c8313f

1 file changed

Lines changed: 70 additions & 9 deletions

File tree

src/tdamapper/utils/quickselect.py

Lines changed: 70 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,103 @@
1+
import numpy as np
12
from numba import njit
23

4+
_ARR = np.zeros(1)
5+
36

47
@njit
58
def swap(arr, i, j):
69
arr[i], arr[j] = arr[j], arr[i]
710

811

912
@njit
10-
def swap_all(arr, i, j, extra1=None, extra2=None):
13+
def _swap_all(arr, i, j, extra1, use_extra1, extra2, use_extra2):
1114
swap(arr, i, j)
12-
if extra1 is not None:
15+
if use_extra1:
1316
swap(extra1, i, j)
14-
if extra2 is not None:
17+
if use_extra2:
1518
swap(extra2, i, j)
1619

1720

1821
@njit
19-
def partition(data, start, end, p_ord, extra1=None, extra2=None):
22+
def _partition(data, start, end, p_ord, extra1, use_extra1, extra2, use_extra2):
2023
higher = start
2124
for j in range(start, end):
2225
j_ord = data[j]
2326
if j_ord < p_ord:
24-
swap_all(data, higher, j, extra1, extra2)
27+
_swap_all(data, higher, j, extra1, use_extra1, extra2, use_extra2)
2528
higher += 1
2629
return higher
2730

2831

2932
@njit
30-
def quickselect(data, start, end, k, extra1=None, extra2=None):
33+
def _quickselect(data, start, end, k, extra1, use_extra1, extra2, use_extra2):
3134
if (k < start) or (k >= end):
3235
return
3336
start_, end_, higher = start, end, None
3437
while higher != k + 1:
3538
p = data[k]
36-
swap_all(data, start_, k, extra1, extra2)
37-
higher = partition(data, start_ + 1, end_, p, extra1, extra2)
38-
swap_all(data, start_, higher - 1, extra1, extra2)
39+
_swap_all(data, start_, k, extra1, use_extra1, extra2, use_extra2)
40+
higher = _partition(
41+
data, start_ + 1, end_, p, extra1, use_extra1, extra2, use_extra2
42+
)
43+
_swap_all(data, start_, higher - 1, extra1, use_extra1, extra2, use_extra2)
3944
if k <= higher - 1:
4045
end_ = higher
4146
else:
4247
start_ = higher
48+
49+
50+
def _to_array(extra1=None, extra2=None):
51+
extra1_arr = _ARR if extra1 is None else extra1
52+
extra2_arr = _ARR if extra2 is None else extra2
53+
return extra1_arr, extra2_arr
54+
55+
56+
def _use_array(extra1=None, extra2=None):
57+
use_extra1 = extra1 is not None
58+
use_extra2 = extra2 is not None
59+
return use_extra1, use_extra2
60+
61+
62+
def swap_all(arr, i, j, extra1=None, extra2=None):
63+
extra1_arr, extra2_arr = _to_array(extra1, extra2)
64+
use_extra1, use_extra2 = _use_array(extra1, extra2)
65+
_swap_all(
66+
arr,
67+
i,
68+
j,
69+
extra1=extra1_arr,
70+
use_extra1=use_extra1,
71+
extra2=extra2_arr,
72+
use_extra2=use_extra2,
73+
)
74+
75+
76+
def partition(data, start, end, p_ord, extra1=None, extra2=None):
77+
extra1_arr, extra2_arr = _to_array(extra1, extra2)
78+
use_extra1, use_extra2 = _use_array(extra1, extra2)
79+
return _partition(
80+
data,
81+
start,
82+
end,
83+
p_ord,
84+
extra1=extra1_arr,
85+
use_extra1=use_extra1,
86+
extra2=extra2_arr,
87+
use_extra2=use_extra2,
88+
)
89+
90+
91+
def quickselect(data, start, end, k, extra1=None, extra2=None):
92+
extra1_arr, extra2_arr = _to_array(extra1, extra2)
93+
use_extra1, use_extra2 = _use_array(extra1, extra2)
94+
_quickselect(
95+
data,
96+
start,
97+
end,
98+
k,
99+
extra1=extra1_arr,
100+
use_extra1=use_extra1,
101+
extra2=extra2_arr,
102+
use_extra2=use_extra2,
103+
)

0 commit comments

Comments
 (0)