Skip to content

Commit a3f0229

Browse files
committed
test for grid search
1 parent 8f92ef7 commit a3f0229

2 files changed

Lines changed: 55 additions & 2 deletions

File tree

CodeEntropy/levels/search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def get_distance(self, j_position, i_position, dimensions):
186186

187187
return distance
188188

189-
def get_grid_neighbors(self, universe, search_object, mol_id, highest_level):
189+
def get_grid_neighbors(self, universe, mol_id, highest_level):
190190
"""
191191
Use MDAnalysis neighbor search to find neighbors.
192192
@@ -237,4 +237,4 @@ def get_grid_neighbors(self, universe, search_object, mol_id, highest_level):
237237
# residues from the central molecule
238238
neighbors = search - fragment.residues
239239

240-
return neighbors
240+
return neighbors.fragindices

tests/unit/CodeEntropy/levels/test_search.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,59 @@ def test_get_RAD_neighbors(tmp_path: Path):
7171
assert neighbors == [151, 3, 75, 219, 229, 488, 460, 118, 230, 326]
7272

7373

74+
def test_get_grid_neighbors(tmp_path: Path):
75+
"""
76+
Args:
77+
tmp_path: Pytest provided temporatry directory
78+
"""
79+
args = {}
80+
search = Search()
81+
system = "methane"
82+
repo_root = Path(__file__).resolve().parents[4]
83+
config_path = (
84+
repo_root / "tests" / "regression" / "configs" / system / "config.yaml"
85+
)
86+
87+
tmp_path.mkdir(parents=True, exist_ok=True)
88+
89+
raw = yaml.safe_load(config_path.read_text())
90+
if not isinstance(raw, dict):
91+
raise ValueError(
92+
f"Config must parse to a dict. Got {type(raw)} from {config_path}"
93+
)
94+
95+
cooked = Helpers._abspathify_config_paths(raw, base_dir=config_path.parent)
96+
required: list[Path] = []
97+
run1 = cooked.get("run1")
98+
if isinstance(run1, dict):
99+
ff = run1.get("force_file")
100+
if isinstance(ff, str) and ff:
101+
required.append(Path(ff))
102+
for p in run1.get("top_traj_file") or []:
103+
if isinstance(p, str) and p:
104+
required.append(Path(p))
105+
106+
if required:
107+
Helpers.ensure_testdata_for_system(system, required_paths=required)
108+
109+
runner = CodeEntropyRunner(tmp_path)
110+
parser = runner._config_manager.build_parser()
111+
args, _ = parser.parse_known_args()
112+
args.end = run1.get("end")
113+
args.top_traj_file = run1.get("top_traj_file")
114+
args.file_format = run1.get("file_format")
115+
assert args.end == 1
116+
117+
universe_operations = UniverseOperations()
118+
universe = CodeEntropyRunner._build_universe(args, universe_operations)
119+
120+
neighbors = search.get_grid_neighbors(
121+
universe=universe, mol_id=0, highest_level="united_atom"
122+
)
123+
124+
assert (neighbors == [151, 3, 75, 219]).all
125+
126+
74127
def test_get_angle():
75128
search = Search()
76129
result1 = search.get_angle(a, b, c, dimensions)

0 commit comments

Comments
 (0)