Skip to content

Commit a878838

Browse files
Copilotnjzjz
andcommitted
fix: move imports to top-level and remove try/except from tests
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 2cb3163 commit a878838

File tree

4 files changed

+37
-68
lines changed

4 files changed

+37
-68
lines changed

deepmd/tf/entrypoints/train.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
import json
99
import logging
1010
import time
11+
from pathlib import (
12+
Path,
13+
)
1114
from typing import (
1215
Any,
1316
Optional,
@@ -48,6 +51,9 @@
4851
from deepmd.utils.data_system import (
4952
get_data,
5053
)
54+
from deepmd.utils.path import (
55+
DPPath,
56+
)
5157

5258
__all__ = ["train"]
5359

@@ -236,14 +242,6 @@ def _do_work(
236242
if not is_compress:
237243
stat_file_raw = jdata["training"].get("stat_file", None)
238244
if stat_file_raw is not None and run_opt.is_chief:
239-
from pathlib import (
240-
Path,
241-
)
242-
243-
from deepmd.utils.path import (
244-
DPPath,
245-
)
246-
247245
if not Path(stat_file_raw).exists():
248246
if stat_file_raw.endswith((".h5", ".hdf5")):
249247
with h5py.File(stat_file_raw, "w") as f:

deepmd/tf/model/ener.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
from deepmd.tf.utils.spin import (
2222
Spin,
2323
)
24+
from deepmd.tf.utils.stat import (
25+
compute_output_stats,
26+
)
2427
from deepmd.tf.utils.type_embed import (
2528
TypeEmbedNet,
2629
)
@@ -174,10 +177,6 @@ def _compute_output_stat(
174177
) -> None:
175178
if stat_file_path is not None:
176179
# Use the new stat functionality with file save/load
177-
from deepmd.tf.utils.stat import (
178-
compute_output_stats,
179-
)
180-
181180
# Merge system stats for compatibility
182181
m_all_stat = merge_sys_stat(all_stat)
183182

source/tests/tf/test_stat_file.py

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -51,40 +51,20 @@ def test_stat_file_tf(self) -> None:
5151
mpi_log="master",
5252
)
5353

54-
try:
55-
# Run training - this should create the stat file
56-
_do_work(self.jdata, run_opt, is_compress=False)
54+
# Run training - this should create the stat file
55+
_do_work(self.jdata, run_opt, is_compress=False)
5756

58-
# Check if stat files were created
59-
stat_path = Path(stat_file_path)
60-
self.assertTrue(
61-
stat_path.exists(), "Stat file directory should be created"
62-
)
57+
# Check if stat files were created
58+
stat_path = Path(stat_file_path)
59+
self.assertTrue(stat_path.exists(), "Stat file directory should be created")
6360

64-
# Check for energy bias and std files
65-
bias_file = stat_path / "bias_atom_energy"
66-
std_file = stat_path / "std_atom_energy"
61+
# Check for energy bias and std files
62+
bias_file = stat_path / "bias_atom_energy"
63+
std_file = stat_path / "std_atom_energy"
6764

68-
# At minimum, the directory structure should be created
69-
# Even if files aren't created due to insufficient data, the directory should exist
70-
self.assertTrue(
71-
stat_path.is_dir(), "Stat file path should be a directory"
72-
)
73-
74-
except Exception as e:
75-
# Print the exception for debugging but don't fail the test
76-
# since we're mainly testing that the stat_file parameter is accepted
77-
print(
78-
f"Training encountered an exception (expected for minimal test data): {e}"
79-
)
80-
81-
# Still check that the stat file directory was created
82-
stat_path = Path(stat_file_path)
83-
if stat_path.exists():
84-
self.assertTrue(
85-
stat_path.is_dir(),
86-
"Stat file path should be a directory if created",
87-
)
65+
# At minimum, the directory structure should be created
66+
# Even if files aren't created due to insufficient data, the directory should exist
67+
self.assertTrue(stat_path.is_dir(), "Stat file path should be a directory")
8868

8969

9070
if __name__ == "__main__":

source/tests/tf/test_stat_file_integration.py

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
"""
3-
Integration test to validate stat_file functionality end-to-end
4-
"""
2+
"""Integration test to validate stat_file functionality end-to-end."""
53

64
import json
75
import os
@@ -76,27 +74,22 @@ def test_stat_file_save_and_load(self) -> None:
7674
with open(config_file, "w") as f:
7775
json.dump(config, f, indent=2)
7876

79-
try:
80-
# Attempt to run training
81-
# This will fail due to missing data but should still process stat_file parameter
82-
train(
83-
INPUT=config_file,
84-
init_model=None,
85-
restart=None,
86-
output=os.path.join(temp_dir, "output.json"),
87-
init_frz_model=None,
88-
mpi_log="master",
89-
log_level=20,
90-
log_path=None,
91-
is_compress=False,
92-
skip_neighbor_stat=True,
93-
finetune=None,
94-
use_pretrain_script=False,
95-
)
96-
except Exception as e:
97-
# Expected to fail due to missing training data
98-
# But the stat_file parameter should have been processed
99-
print(f"Expected training failure: {e}")
77+
# Attempt to run training
78+
# This will fail due to missing data but should still process stat_file parameter
79+
train(
80+
INPUT=config_file,
81+
init_model=None,
82+
restart=None,
83+
output=os.path.join(temp_dir, "output.json"),
84+
init_frz_model=None,
85+
mpi_log="master",
86+
log_level=20,
87+
log_path=None,
88+
is_compress=False,
89+
skip_neighbor_stat=True,
90+
finetune=None,
91+
use_pretrain_script=False,
92+
)
10093

10194
# The main validation is that the code didn't crash with an unrecognized parameter
10295
# and that if the stat file directory was attempted to be created, it exists
@@ -105,7 +98,6 @@ def test_stat_file_save_and_load(self) -> None:
10598
self.assertTrue(
10699
stat_path.is_dir(), "Stat file path should be a directory"
107100
)
108-
print(f"Stat file directory was created: {stat_file_path}")
109101

110102
# This test primarily validates that the stat_file parameter is accepted
111103
# and processed without errors in the TF pipeline

0 commit comments

Comments
 (0)