Skip to content

Commit e58b5fc

Browse files
committed
Fix bug and clean up codes based on coderabbit
1 parent 9c9f463 commit e58b5fc

3 files changed

Lines changed: 53 additions & 7 deletions

File tree

deepmd/tf/infer/deep_eval.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def __init__(
144144
self.dm = None
145145
if self.modifier_type is not None:
146146
modifier = BaseModifier.get_class_by_type(self.modifier_type)
147-
modifier_params = modifier.get_params(self)
147+
modifier_params = modifier.get_params_from_frozen_model(self)
148148
self.dm = modifier.get_modifier(modifier_params)
149149

150150
def _init_tensors(self) -> None:
@@ -662,8 +662,8 @@ def _get_natoms_and_nframes(
662662
coords: np.ndarray,
663663
atom_types: list[int] | np.ndarray,
664664
) -> tuple[int, int]:
665-
atom_types = np.reshape(atom_types, (-1))
666-
natoms = len(atom_types)
665+
# (natoms,) or (nframes, natoms,)
666+
natoms = np.shape(atom_types)[-1]
667667
if natoms == 0:
668668
assert coords.size == 0
669669
else:

deepmd/tf/modifier/base_modifier.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from abc import (
3+
abstractmethod,
4+
)
5+
26
from deepmd.dpmodel.modifier.base_modifier import (
37
make_base_modifier,
48
)
@@ -11,3 +15,23 @@ class BaseModifier(DeepPot, make_base_modifier()):
1115
def __init__(self, *args, **kwargs) -> None:
1216
"""Construct a basic model for different tasks."""
1317
DeepPot.__init__(self, *args, **kwargs)
18+
19+
@staticmethod
20+
@abstractmethod
21+
def get_params_from_frozen_model(model) -> dict:
22+
"""Extract the modifier parameters from a model.
23+
24+
This method should extract the necessary parameters from a model
25+
to create an instance of this modifier.
26+
27+
Parameters
28+
----------
29+
model
30+
The model from which to extract parameters
31+
32+
Returns
33+
-------
34+
dict
35+
The modifier parameters
36+
"""
37+
pass

deepmd/tf/modifier/dipole_charge.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import os
3+
from typing import (
4+
TYPE_CHECKING,
5+
)
36

47
import numpy as np
58

@@ -13,9 +16,6 @@
1316
op_module,
1417
tf,
1518
)
16-
from deepmd.tf.infer import (
17-
DeepEval,
18-
)
1919
from deepmd.tf.infer.deep_dipole import DeepDipoleOld as DeepDipole
2020
from deepmd.tf.infer.ewald_recp import (
2121
EwaldRecp,
@@ -30,6 +30,11 @@
3030
run_sess,
3131
)
3232

33+
if TYPE_CHECKING:
34+
from deepmd.tf.infer import (
35+
DeepEval,
36+
)
37+
3338

3439
@BaseModifier.register("dipole_charge")
3540
class DipoleChargeModifier(DeepDipole, BaseModifier):
@@ -492,7 +497,24 @@ def modify_data(self, data: dict, data_sys: DeepmdData) -> None:
492497
data["virial"] -= tot_v.reshape(data["virial"].shape)
493498

494499
@staticmethod
495-
def get_params(model: DeepEval):
500+
def get_params_from_frozen_model(model: "DeepEval") -> dict:
501+
"""Extract modifier parameters from a DeepEval model.
502+
503+
Parameters
504+
----------
505+
model : DeepEval
506+
The DeepEval model instance containing the modifier tensors.
507+
508+
Returns
509+
-------
510+
dict
511+
Dictionary containing modifier parameters:
512+
- model_name : str
513+
- model_charge_map : list[int]
514+
- sys_charge_map : list[int]
515+
- ewald_h : float
516+
- ewald_beta : float
517+
"""
496518
t_mdl_name = model._get_tensor("modifier_attr/mdl_name:0")
497519
t_mdl_charge_map = model._get_tensor("modifier_attr/mdl_charge_map:0")
498520
t_sys_charge_map = model._get_tensor("modifier_attr/sys_charge_map:0")

0 commit comments

Comments
 (0)