Skip to content

Commit cbb5443

Browse files
committed
Refactored egress tests using NumPy functions
1 parent df20af8 commit cbb5443

3 files changed

Lines changed: 30 additions & 38 deletions

File tree

tests/naive.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ def z_norm(a, axis=0, threshold=1e-7):
1010
return (a - np.mean(a, axis, keepdims=True)) / std
1111

1212

13-
def distance(a, b):
14-
return np.linalg.norm(a - b)
13+
def distance(a, b, axis=0):
14+
return np.linalg.norm(a - b, axis=axis)
1515

1616

1717
def apply_exclusion_zone(D, trivial_idx, excl_zone):

tests/test_aampi.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,7 @@ def test_aampi_identical_subsequence_self_join_egress():
848848
def test_aampi_profile_index_match():
849849
T_full = np.random.rand(64)
850850
m = 3
851+
T_full_subseq = core.rolling_window(T_full, m)
851852
warm_start = 8
852853

853854
T_stream = T_full[:warm_start].copy()
@@ -860,23 +861,17 @@ def test_aampi_profile_index_match():
860861
t = T_full[i]
861862
stream.update(t)
862863

863-
for j in range(stream.I_.shape[0]):
864-
I = stream.I_[j]
865-
left_I = stream.left_I_[j]
866-
867-
if I < 0:
868-
P[j] = np.inf
869-
else:
870-
P[j] = naive.distance(
871-
T_full[j + n + 1 : j + n + 1 + m], T_full[I : I + m]
872-
)
873-
874-
if left_I < 0:
875-
left_P[j] = np.inf
876-
else:
877-
left_P[j] = naive.distance(
878-
T_full[j + n + 1 : j + n + 1 + m], T_full[left_I : left_I + m]
879-
)
864+
P[:] = np.inf
865+
idx = np.argwhere(stream.I_ >= 0).flatten()
866+
P[idx] = naive.distance(
867+
T_full_subseq[idx + n + 1], T_full_subseq[stream.I_[idx]], axis=1
868+
)
869+
870+
left_P[:] = np.inf
871+
idx = np.argwhere(stream.left_I_ >= 0).flatten()
872+
left_P[idx] = naive.distance(
873+
T_full_subseq[idx + n + 1], T_full_subseq[stream.left_I_[idx]], axis=1
874+
)
880875

881876
npt.assert_almost_equal(stream.P_, P)
882877
npt.assert_almost_equal(stream.left_P_, left_P)

tests/test_stumpi.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -856,6 +856,7 @@ def test_stumpi_identical_subsequence_self_join_egress():
856856
def test_stumpi_profile_index_match():
857857
T_full = np.random.rand(64)
858858
m = 3
859+
T_full_subseq = core.rolling_window(T_full, m)
859860
warm_start = 8
860861

861862
T_stream = T_full[:warm_start].copy()
@@ -868,25 +869,21 @@ def test_stumpi_profile_index_match():
868869
t = T_full[i]
869870
stream.update(t)
870871

871-
for j in range(stream.I_.shape[0]):
872-
I = stream.I_[j]
873-
left_I = stream.left_I_[j]
874-
875-
if I < 0:
876-
P[j] = np.inf
877-
else:
878-
P[j] = naive.distance(
879-
naive.z_norm(T_full[j + n + 1 : j + n + 1 + m]),
880-
naive.z_norm(T_full[I : I + m]),
881-
)
882-
883-
if left_I < 0:
884-
left_P[j] = np.inf
885-
else:
886-
left_P[j] = naive.distance(
887-
naive.z_norm(T_full[j + n + 1 : j + n + 1 + m]),
888-
naive.z_norm(T_full[left_I : left_I + m]),
889-
)
872+
P[:] = np.inf
873+
idx = np.argwhere(stream.I_ >= 0).flatten()
874+
P[idx] = naive.distance(
875+
naive.z_norm(T_full_subseq[idx + n + 1], axis=1),
876+
naive.z_norm(T_full_subseq[stream.I_[idx]], axis=1),
877+
axis=1,
878+
)
879+
880+
left_P[:] = np.inf
881+
idx = np.argwhere(stream.left_I_ >= 0).flatten()
882+
left_P[idx] = naive.distance(
883+
naive.z_norm(T_full_subseq[idx + n + 1], axis=1),
884+
naive.z_norm(T_full_subseq[stream.left_I_[idx]], axis=1),
885+
axis=1,
886+
)
890887

891888
npt.assert_almost_equal(stream.P_, P)
892889
npt.assert_almost_equal(stream.left_P_, left_P)

0 commit comments

Comments
 (0)