Skip to content

Commit da89675

Browse files
committed
test(unit): add branch coverage tests for Search neighbor methods
1 parent 281fa2d commit da89675

1 file changed

Lines changed: 117 additions & 0 deletions

File tree

tests/unit/CodeEntropy/levels/test_search.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from pathlib import Path
2+
from unittest.mock import MagicMock, patch
23

34
import numpy as np
45
import pytest
@@ -158,3 +159,119 @@ def test_distance_boundary_conditions():
158159
distance4 = search.get_distance(c, e, dimensions)
159160

160161
assert distance4 == pytest.approx(1.7320508075688772)
162+
163+
164+
def test_get_RAD_indices_breaks_when_angle_is_nan():
165+
search = Search()
166+
167+
i_coords = np.array([0.0, 0.0, 0.0])
168+
sorted_distances = [(1, 1.0), (2, 2.0)]
169+
number_molecules = 3
170+
171+
frag_1 = MagicMock()
172+
frag_2 = MagicMock()
173+
frag_1.center_of_mass.return_value = np.array([1.0, 0.0, 0.0])
174+
frag_2.center_of_mass.return_value = np.array([2.0, 0.0, 0.0])
175+
176+
system = MagicMock()
177+
system.atoms.fragments = [MagicMock(), frag_1, frag_2]
178+
system.dimensions = np.array([10.0, 10.0, 10.0, 90.0, 90.0, 90.0])
179+
180+
search.get_angle = MagicMock(side_effect=[np.nan])
181+
182+
result = search._get_RAD_indices(
183+
i_coords=i_coords,
184+
sorted_distances=sorted_distances,
185+
system=system,
186+
number_molecules=number_molecules,
187+
)
188+
189+
assert result == [1, 2]
190+
search.get_angle.assert_called_once()
191+
192+
193+
def test_get_grid_neighbors_uses_residue_search_for_non_united_atom():
194+
search = Search()
195+
196+
universe = MagicMock()
197+
fragment = MagicMock()
198+
fragment.indices = [4, 5, 6]
199+
fragment.residues = MagicMock()
200+
201+
universe.atoms.fragments = [fragment]
202+
203+
molecule_atom_group = MagicMock()
204+
universe.select_atoms.return_value = molecule_atom_group
205+
206+
search_result = MagicMock()
207+
final_neighbors = MagicMock()
208+
final_neighbors.fragindices = np.array([7, 8, 9])
209+
210+
search_result.__sub__.return_value = final_neighbors
211+
212+
search_object = MagicMock()
213+
214+
with patch(
215+
"CodeEntropy.levels.search.mda.lib.NeighborSearch.AtomNeighborSearch"
216+
) as mock_ans:
217+
mock_ans.return_value = search_object
218+
mock_ans.search.return_value = search_result
219+
220+
result = search.get_grid_neighbors(
221+
universe=universe,
222+
mol_id=0,
223+
highest_level="residue",
224+
)
225+
226+
universe.select_atoms.assert_called_once_with("index 4:6")
227+
mock_ans.assert_called_once_with(universe.atoms)
228+
mock_ans.search.assert_called_once_with(
229+
search_object,
230+
molecule_atom_group,
231+
radius=3.5,
232+
level="R",
233+
)
234+
search_result.__sub__.assert_called_once_with(fragment.residues)
235+
assert (result == np.array([7, 8, 9])).all()
236+
237+
238+
def test_get_grid_neighbors_uses_atom_search_for_united_atom():
239+
search = Search()
240+
241+
universe = MagicMock()
242+
fragment = MagicMock()
243+
fragment.indices = [10, 11]
244+
universe.atoms.fragments = [fragment]
245+
246+
molecule_atom_group = MagicMock()
247+
universe.select_atoms.return_value = molecule_atom_group
248+
249+
search_result = MagicMock()
250+
final_neighbors = MagicMock()
251+
final_neighbors.fragindices = np.array([2, 3])
252+
253+
search_result.__sub__.return_value = final_neighbors
254+
255+
search_object = MagicMock()
256+
257+
with patch(
258+
"CodeEntropy.levels.search.mda.lib.NeighborSearch.AtomNeighborSearch"
259+
) as mock_ans:
260+
mock_ans.return_value = search_object
261+
mock_ans.search.return_value = search_result
262+
263+
result = search.get_grid_neighbors(
264+
universe=universe,
265+
mol_id=0,
266+
highest_level="united_atom",
267+
)
268+
269+
universe.select_atoms.assert_called_once_with("index 10:11")
270+
mock_ans.search.assert_called_once_with(
271+
search_object,
272+
molecule_atom_group,
273+
radius=3.0,
274+
level="A",
275+
)
276+
search_result.__sub__.assert_called_once_with(molecule_atom_group)
277+
assert (result == np.array([2, 3])).all()

0 commit comments

Comments
 (0)