Skip to content

Commit 6fb0703

Browse files
authored
test: add TensorFlow graph reset in teardown method for entrypoint tests and bias standard tests (#5049)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Tests** * Improved test isolation by resetting the TensorFlow default graph before and after affected tests to ensure a clean graph state between runs. * Adjusted test modules to import the TensorFlow environment wrapper used in tests. * No changes to test assertions or runtime behavior beyond lifecycle improvements. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 87cb6ef commit 6fb0703

2 files changed

Lines changed: 15 additions & 0 deletions

File tree

source/tests/tf/test_nvnmd_entrypoints.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,7 @@ def test_model_qnn_v0(self) -> None:
515515
def tearDown(self) -> None:
516516
# close
517517
nvnmd_cfg.enable = False
518+
tf.reset_default_graph()
518519

519520

520521
class TestNvnmdEntrypointsV1(tf.test.TestCase):
@@ -878,6 +879,7 @@ def test_wrap_qnn_v1(self) -> None:
878879
def tearDown(self) -> None:
879880
# close
880881
nvnmd_cfg.enable = False
882+
tf.reset_default_graph()
881883

882884

883885
if __name__ == "__main__":

source/tests/tf/test_out_bias_std.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
from deepmd.tf.descriptor.se_a import (
88
DescrptSeA,
99
)
10+
from deepmd.tf.env import (
11+
tf,
12+
)
1013
from deepmd.tf.fit.dipole import (
1114
DipoleFittingSeA,
1215
)
@@ -21,6 +24,16 @@
2124
class TestOutBiasStd(unittest.TestCase):
2225
"""Test out_bias and out_std functionality in TensorFlow backend."""
2326

27+
def setUp(self):
28+
"""Resets the default graph before each test."""
29+
super().setUp()
30+
tf.reset_default_graph()
31+
32+
def tearDown(self):
33+
"""Resets the default graph after each test."""
34+
tf.reset_default_graph()
35+
super().tearDown()
36+
2437
def test_init_out_stat_basic(self):
2538
"""Test basic init_out_stat functionality."""
2639
descriptor = DescrptSeA(

0 commit comments

Comments
 (0)