Skip to content

Commit a4d1e6d

Browse files
committed
minor fallback plus updated speedtest
1 parent d1bfde2 commit a4d1e6d

2 files changed

Lines changed: 4 additions & 3 deletions

File tree

TPTBox/core/np_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def np_unique(arr: np.ndarray) -> list[int]:
201201
counts = np.bincount(arr.ravel())
202202
return list(np.where(counts > 0)[0])
203203
# For sparse label spaces fall back to np.unique
204-
return list(np.unique(arr))
204+
return old_np_unique(arr)
205205

206206

207207
def np_unique_withoutzero(arr: UINTARRAY) -> list[int]:

TPTBox/tests/speedtests/speedtest_npunique.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@
1818
np_unique,
1919
np_unique_withoutzero,
2020
np_volume,
21+
old_np_unique,
2122
)
2223
from TPTBox.tests.speedtests.speedtest import speed_test
2324
from TPTBox.tests.test_utils import get_nii
2425

2526
def get_nii_array():
2627
num_points = random.randint(1, 30)
27-
nii, points, orientation, sizes = get_nii(x=(140, 140, 150), num_point=num_points)
28+
nii, points, orientation, sizes = get_nii(x=(400, 400, 400), num_point=num_points)
2829
# nii.map_labels_({1: -1}, verbose=False)
2930
arr = nii.get_seg_array().astype(np.uint8)
3031
# arr[arr == 1] = -1
@@ -34,7 +35,7 @@ def get_nii_array():
3435
speed_test(
3536
repeats=50,
3637
get_input_func=get_nii_array,
37-
functions=[np_unique, np.unique, np_is_empty, np.max],
38+
functions=[np_unique, old_np_unique, np.unique, np_is_empty, np.max],
3839
assert_equal_function=lambda x, y: True, # np.all([x[i] == y[i] for i in range(len(x))]), # noqa: ARG005
3940
# np.all([x[i] == y[i] for i in range(len(x))])
4041
)

0 commit comments

Comments
 (0)