Skip to content

Commit 8a96a3f

Browse files
committed
feat(pt): add default_fparam
"numb_fparam": 2, "default_fparam": [0.0, 1.0],
1 parent e71b5c2 commit 8a96a3f

10 files changed

Lines changed: 59 additions & 5 deletions

File tree

deepmd/entrypoints/test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,11 @@ def test_ener(
296296
data.add("atom_ener", 1, atomic=True, must=True, high_prec=False)
297297
if dp.get_dim_fparam() > 0:
298298
data.add(
299-
"fparam", dp.get_dim_fparam(), atomic=False, must=True, high_prec=False
299+
"fparam",
300+
dp.get_dim_fparam(),
301+
atomic=False,
302+
must=not dp.has_default_fparam(),
303+
high_prec=False,
300304
)
301305
if dp.get_dim_aparam() > 0:
302306
data.add("aparam", dp.get_dim_aparam(), atomic=True, must=True, high_prec=False)
@@ -326,7 +330,7 @@ def test_ener(
326330
atype = test_data["type"][:numb_test].reshape([numb_test, -1])
327331
else:
328332
atype = test_data["type"][0]
329-
if dp.get_dim_fparam() > 0:
333+
if dp.get_dim_fparam() > 0 and test_data["find_fparam"] != 0.0:
330334
fparam = test_data["fparam"][:numb_test]
331335
else:
332336
fparam = None

deepmd/infer/deep_eval.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ def get_type_map(self) -> list[str]:
159159
def get_dim_fparam(self) -> int:
160160
"""Get the number (dimension) of frame parameters of this DP."""
161161

162+
def has_default_fparam(self) -> bool:
163+
return False
164+
162165
@abstractmethod
163166
def get_dim_aparam(self) -> int:
164167
"""Get the number (dimension) of atomic parameters of this DP."""
@@ -361,6 +364,9 @@ def get_dim_fparam(self) -> int:
361364
"""Get the number (dimension) of frame parameters of this DP."""
362365
return self.deep_eval.get_dim_fparam()
363366

367+
def has_default_fparam(self) -> bool:
368+
return self.deep_eval.has_default_fparam()
369+
364370
def get_dim_aparam(self) -> int:
365371
"""Get the number (dimension) of atomic parameters of this DP."""
366372
return self.deep_eval.get_dim_aparam()

deepmd/pt/infer/deep_eval.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,9 @@ def get_dim_aparam(self) -> int:
181181
"""Get the number (dimension) of atomic parameters of this DP."""
182182
return self.dp.model["Default"].get_dim_aparam()
183183

184+
def has_default_fparam(self) -> bool:
185+
return self.dp.model["Default"].has_default_fparam()
186+
184187
def get_intensive(self) -> bool:
185188
return self.dp.model["Default"].get_intensive()
186189

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,9 @@ def get_dim_fparam(self) -> int:
333333
"""Get the number (dimension) of frame parameters of this atomic model."""
334334
return self.fitting_net.get_dim_fparam()
335335

336+
def has_default_fparam(self) -> bool:
337+
return self.fitting_net.has_default_fparam()
338+
336339
def get_dim_aparam(self) -> int:
337340
"""Get the number (dimension) of atomic parameters of this atomic model."""
338341
return self.fitting_net.get_dim_aparam()

deepmd/pt/model/model/make_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,10 @@ def get_dim_fparam(self) -> int:
526526
"""Get the number (dimension) of frame parameters of this atomic model."""
527527
return self.atomic_model.get_dim_fparam()
528528

529+
@torch.jit.export
530+
def has_default_fparam(self) -> bool:
531+
return self.atomic_model.has_default_fparam()
532+
529533
@torch.jit.export
530534
def get_dim_aparam(self) -> int:
531535
"""Get the number (dimension) of atomic parameters of this atomic model."""

deepmd/pt/model/task/ener.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def __init__(
6666
mixed_types: bool = True,
6767
seed: Optional[Union[int, list[int]]] = None,
6868
type_map: Optional[list[str]] = None,
69+
default_fparam: Optional[list] = None,
6970
**kwargs,
7071
) -> None:
7172
super().__init__(
@@ -84,6 +85,7 @@ def __init__(
8485
mixed_types=mixed_types,
8586
seed=seed,
8687
type_map=type_map,
88+
default_fparam=default_fparam,
8789
**kwargs,
8890
)
8991

deepmd/pt/model/task/fitting.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def __init__(
145145
remove_vaccum_contribution: Optional[list[bool]] = None,
146146
type_map: Optional[list[str]] = None,
147147
use_aparam_as_mask: bool = False,
148+
default_fparam: Optional[list] = None,
148149
**kwargs,
149150
) -> None:
150151
super().__init__()
@@ -156,6 +157,7 @@ def __init__(
156157
self.resnet_dt = resnet_dt
157158
self.numb_fparam = numb_fparam
158159
self.numb_aparam = numb_aparam
160+
self.default_fparam = default_fparam
159161
self.dim_case_embd = dim_case_embd
160162
self.activation_function = activation_function
161163
self.precision = precision
@@ -217,6 +219,20 @@ def __init__(
217219
else:
218220
self.case_embd = None
219221

222+
if self.default_fparam is not None:
223+
if self.numb_fparam > 0:
224+
assert (
225+
len(self.default_fparam) == self.numb_fparam
226+
), "default_fparam length mismatch!"
227+
self.register_buffer(
228+
"default_fparam_tensor",
229+
torch.tensor(
230+
np.array(self.default_fparam), dtype=self.prec, device=device
231+
),
232+
)
233+
else:
234+
self.default_fparam_tensor = None
235+
220236
in_dim = (
221237
self.dim_descrpt
222238
+ self.numb_fparam
@@ -333,6 +349,9 @@ def get_dim_fparam(self) -> int:
333349
"""Get the number (dimension) of frame parameters of this atomic model."""
334350
return self.numb_fparam
335351

352+
def has_default_fparam(self) -> bool:
353+
return self.default_fparam is not None
354+
336355
def get_dim_aparam(self) -> int:
337356
"""Get the number (dimension) of atomic parameters of this atomic model."""
338357
return self.numb_aparam
@@ -427,6 +446,13 @@ def _forward_common(
427446
):
428447
# cast the input to internal precsion
429448
xx = descriptor.to(self.prec)
449+
nf, nloc, nd = xx.shape
450+
451+
if self.numb_fparam > 0 and fparam is None:
452+
# use default fparam
453+
assert self.default_fparam_tensor is not None
454+
fparam = torch.tile(self.default_fparam_tensor.unsqueeze(0), [nf, 1])
455+
430456
fparam = fparam.to(self.prec) if fparam is not None else None
431457
aparam = aparam.to(self.prec) if aparam is not None else None
432458

@@ -439,7 +465,6 @@ def _forward_common(
439465
xx_zeros = torch.zeros_like(xx)
440466
else:
441467
xx_zeros = None
442-
nf, nloc, nd = xx.shape
443468
net_dim_out = self._net_out_dim()
444469

445470
if nd != self.dim_descrpt:

deepmd/pt/model/task/invar_fitting.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def __init__(
103103
atom_ener: Optional[list[Optional[torch.Tensor]]] = None,
104104
type_map: Optional[list[str]] = None,
105105
use_aparam_as_mask: bool = False,
106+
default_fparam: Optional[list] = None,
106107
**kwargs,
107108
) -> None:
108109
self.dim_out = dim_out
@@ -128,6 +129,7 @@ def __init__(
128129
else [x is not None for x in atom_ener],
129130
type_map=type_map,
130131
use_aparam_as_mask=use_aparam_as_mask,
132+
default_fparam=default_fparam,
131133
**kwargs,
132134
)
133135

deepmd/pt/train/training.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,7 +1174,8 @@ def get_data(self, is_train=True, task_key="Default"):
11741174
label_dict = {}
11751175
for item_key in batch_data:
11761176
if item_key in input_keys:
1177-
input_dict[item_key] = batch_data[item_key]
1177+
if item_key != "fparam" or batch_data["find_fparam"] != 0.0:
1178+
input_dict[item_key] = batch_data[item_key]
11781179
else:
11791180
if item_key not in ["sid", "fid"]:
11801181
label_dict[item_key] = batch_data[item_key]
@@ -1253,7 +1254,10 @@ def get_additional_data_requirement(_model):
12531254
if _model.get_dim_fparam() > 0:
12541255
fparam_requirement_items = [
12551256
DataRequirementItem(
1256-
"fparam", _model.get_dim_fparam(), atomic=False, must=True
1257+
"fparam",
1258+
_model.get_dim_fparam(),
1259+
atomic=False,
1260+
must=not _model.has_default_fparam(),
12571261
)
12581262
]
12591263
additional_data_requirement += fparam_requirement_items

deepmd/utils/argcheck.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1941,6 +1941,7 @@ def fitting_ener():
19411941
return [
19421942
Argument("numb_fparam", int, optional=True, default=0, doc=doc_numb_fparam),
19431943
Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam),
1944+
Argument("default_fparam", list, optional=True, default=None),
19441945
Argument(
19451946
"dim_case_embd",
19461947
int,

0 commit comments

Comments (0)