Skip to content

Commit b6c525a

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 7956dd9 commit b6c525a

1 file changed

Lines changed: 7 additions & 6 deletions

File tree

source/tests/pt/test_init_frz_model.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313

1414
import numpy as np
1515

16+
from deepmd.entrypoints.convert_backend import (
17+
convert_backend,
18+
)
1619
from deepmd.pt.entrypoints.main import (
1720
freeze,
1821
get_trainer,
@@ -23,9 +26,7 @@
2326
from deepmd.tf.utils.convert import (
2427
convert_pbtxt_to_pb,
2528
)
26-
from deepmd.entrypoints.convert_backend import (
27-
convert_backend,
28-
)
29+
2930
from .common import (
3031
run_dp,
3132
)
@@ -34,7 +35,7 @@
3435
class TestInitFrzModel(unittest.TestCase):
3536
def setUp(self) -> None:
3637
input_json = str(Path(__file__).parent / "water/se_atten.json")
37-
with open(input_json, "r", encoding="utf-8") as f:
38+
with open(input_json, encoding="utf-8") as f:
3839
config = json.load(f)
3940
config["model"]["descriptor"]["smooth_type_embedding"] = True
4041
config["training"]["numb_steps"] = 1
@@ -145,7 +146,7 @@ def test_init_frz_model_pb2pth(self) -> None:
145146
"""Test initialization from frozen model with mismatched keys to test strict=False fallback."""
146147
# Create a base model
147148
input_json = str(Path(__file__).parent / "water/se_e2_a.json")
148-
with open(input_json, "r", encoding="utf-8") as f:
149+
with open(input_json, encoding="utf-8") as f:
149150
config = json.load(f)
150151
config["training"]["numb_steps"] = 1
151152
config["training"]["save_freq"] = 1
@@ -160,7 +161,7 @@ def test_init_frz_model_pb2pth(self) -> None:
160161

161162
trainer = get_trainer(config, init_frz_model=frozen_model)
162163
trainer.run()
163-
164+
164165
def tearDown(self) -> None:
165166
for f in os.listdir("."):
166167
if f.startswith("frozen_model") and f.endswith(".pth"):

0 commit comments

Comments
 (0)