Skip to content

Commit fffeb83

Browse files
committed
add unit tests for lammps conversion
1 parent b3ed217 commit fffeb83

1 file changed

Lines changed: 53 additions & 0 deletions

File tree

tests/unit/CodeEntropy/levels/test_mda_universe_operations.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,59 @@ def test_merge_forces_scales_kcal(monkeypatch):
118118
assert np.allclose(forces, np.ones((2, 2, 3)) * 4.184)
119119

120120

121+
def test_convert_lammps_transforms_forces_and_energies(monkeypatch):
122+
ops = UniverseOperations()
123+
124+
mock_universe = MagicMock()
125+
transformations_captured = []
126+
127+
def capture_universe(*args, **kwargs):
128+
if "transformations" in kwargs:
129+
transformations_captured.extend(kwargs["transformations"])
130+
return mock_universe
131+
132+
monkeypatch.setattr("CodeEntropy.levels.mda.mda.Universe", capture_universe)
133+
134+
ops.convert_lammps("tpr", "trr", "LAMMPSDUMP")
135+
136+
ts = MagicMock()
137+
ts.forces = np.array([[1.0, 2.0, 3.0]], dtype=float)
138+
ts.data = {"c_5": np.array([1.0]), "c_7": np.array([2.0])}
139+
140+
transformations_captured[0](ts)
141+
142+
assert np.allclose(ts.forces, np.array([[1.0, 2.0, 3.0]], dtype=float) * 4.184)
143+
assert np.allclose(ts.data["c_5"], np.array([1.0], dtype=float) * 4.184)
144+
assert np.allclose(ts.data["c_7"], np.array([[2.0]], dtype=float) * 4.184)
145+
146+
147+
def test_convert_lammps_fallback_on_keyerror(monkeypatch):
148+
ops = UniverseOperations()
149+
150+
transformations_captured = []
151+
call_count = [0]
152+
153+
def mock_universe(*args, **kwargs):
154+
call_count[0] += 1
155+
if call_count[0] == 1:
156+
raise KeyError("c_5")
157+
if "transformations" in kwargs:
158+
transformations_captured.extend(kwargs["transformations"])
159+
return MagicMock()
160+
161+
monkeypatch.setattr("CodeEntropy.levels.mda.mda.Universe", mock_universe)
162+
163+
ops.convert_lammps("tpr", "trr", "LAMMPSDUMP")
164+
165+
ts = MagicMock()
166+
ts.forces = np.array([[1.0, 2.0, 3.0]], dtype=float)
167+
168+
transformations_captured[0](ts)
169+
170+
assert np.allclose(ts.forces, np.array([[1.0, 2.0, 3.0]], dtype=float) * 4.184)
171+
assert call_count[0] == 2
172+
173+
121174
def test_select_atoms_builds_merged_universe_and_loads_timeseries(monkeypatch):
122175
ops = UniverseOperations()
123176

0 commit comments

Comments
 (0)