Skip to content

Commit 006b2e7

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent f92aa8a commit 006b2e7

File tree

2 files changed

+30
-24
lines changed

2 files changed

+30
-24
lines changed

deepmd/entrypoints/eval_desc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,12 @@ def eval_desc(
126126
)
127127

128128
# save descriptors
129-
system_name = os.path.basename(system_path.rstrip('/'))
129+
system_name = os.path.basename(system_path.rstrip("/"))
130130
desc_file = output_dir / f"{system_name}.npy"
131131
np.save(desc_file, descriptors)
132-
132+
133133
log.info(f"# descriptors saved to {desc_file}")
134134
log.info(f"# descriptor shape: {descriptors.shape}")
135135
log.info("# ----------------------------------- ")
136136

137-
log.info("# eval_desc completed successfully")
137+
log.info("# eval_desc completed successfully")
Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,42 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
import unittest
3-
import tempfile
4-
import os
52
import inspect
3+
import os
64
import shutil
7-
from pathlib import Path
5+
import tempfile
6+
import unittest
87

9-
from deepmd.entrypoints.eval_desc import eval_desc
8+
from deepmd.common import (
9+
expand_sys_str,
10+
)
1011
from deepmd.entrypoints import eval_desc as eval_desc_module
11-
from deepmd.common import expand_sys_str
12+
from deepmd.entrypoints.eval_desc import (
13+
eval_desc,
14+
)
1215

1316

1417
class TestEvalDesc(unittest.TestCase):
1518
"""Test the eval-desc CLI functionality."""
16-
19+
1720
def test_eval_desc_function_signature(self) -> None:
1821
"""Test that eval_desc function has the expected signature."""
1922
# Check that it's callable
2023
self.assertTrue(callable(eval_desc))
21-
24+
2225
# Check that it accepts the expected parameters
2326
sig = inspect.signature(eval_desc)
24-
expected_params = {'model', 'system', 'datafile', 'output', 'head'}
25-
actual_params = set(sig.parameters.keys()) - {'kwargs'}
26-
self.assertEqual(expected_params, actual_params,
27-
f"Expected parameters {expected_params}, got {actual_params}")
28-
27+
expected_params = {"model", "system", "datafile", "output", "head"}
28+
actual_params = set(sig.parameters.keys()) - {"kwargs"}
29+
self.assertEqual(
30+
expected_params,
31+
actual_params,
32+
f"Expected parameters {expected_params}, got {actual_params}",
33+
)
34+
2935
def test_eval_desc_module_docstring(self) -> None:
3036
"""Test that eval_desc module has proper documentation."""
3137
self.assertIsNotNone(eval_desc_module.__doc__)
3238
self.assertIn("descriptor", eval_desc_module.__doc__.lower())
33-
39+
3440
def test_eval_desc_expansion_logic(self) -> None:
3541
"""Test system expansion logic without requiring full deepmd."""
3642
# Create test directories
@@ -39,23 +45,23 @@ def test_eval_desc_expansion_logic(self) -> None:
3945
# Test that expand_sys_str is available and works
4046
result = expand_sys_str("nonexistent_path")
4147
self.assertIsInstance(result, list)
42-
48+
4349
# Test with existing directory
4450
os.makedirs(os.path.join(test_dir, "system1"))
4551
result = expand_sys_str(os.path.join(test_dir, "system*"))
4652
self.assertIsInstance(result, list)
47-
53+
4854
finally:
4955
shutil.rmtree(test_dir, ignore_errors=True)
50-
56+
5157
def test_eval_desc_parameter_validation(self) -> None:
5258
"""Test parameter validation without requiring model loading."""
5359
# Test with completely invalid inputs - should fail early
5460
test_dir = tempfile.mkdtemp()
5561
try:
5662
nonexistent = os.path.join(test_dir, "nonexistent")
5763
output = os.path.join(test_dir, "output")
58-
64+
5965
# This should raise RuntimeError about not finding valid system
6066
# before trying to load the model
6167
with self.assertRaises(RuntimeError) as context:
@@ -65,13 +71,13 @@ def test_eval_desc_parameter_validation(self) -> None:
6571
datafile=None,
6672
output=output,
6773
)
68-
74+
6975
# Check that it's the expected error message
7076
self.assertIn("Did not find valid system", str(context.exception))
71-
77+
7278
finally:
7379
shutil.rmtree(test_dir, ignore_errors=True)
7480

7581

7682
if __name__ == "__main__":
77-
unittest.main()
83+
unittest.main()

0 commit comments

Comments
 (0)