Skip to content

Commit be6c622

Browse files
committed
fix bug in setting sel_type of pt model
1 parent dfc5f03 commit be6c622

2 files changed

Lines changed: 4 additions & 0 deletions

File tree

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def __init__(
6464
self.rcut = self.descriptor.get_rcut()
6565
self.sel = self.descriptor.get_sel()
6666
self.fitting_net = fitting
67+
self.fitting_net.reinit_exclude(self.atom_exclude_types)
6768
super().init_out_stat()
6869
self.enable_eval_descriptor_hook = False
6970
self.enable_eval_fitting_last_layer_hook = False

source/tests/pt/model/test_get_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ def test_model_attr(self) -> None:
6060
]
6161
},
6262
)
63+
full_type_list = np.arange(len(atomic_model.type_map), dtype=int)
64+
atom_exclude_types = np.setdiff1d(full_type_list, self.model.get_sel_type())
65+
self.assertEqual(atom_exclude_types, [1])
6366
self.assertEqual(atomic_model.atom_exclude_types, [1])
6467
self.assertEqual(atomic_model.pair_exclude_types, [[1, 2]])
6568

0 commit comments

Comments
 (0)