Skip to content

Commit b8c0859

Browse files
Committed by author
feat(pt): add default_fparam
"numb_fparam": 2, "default_fparam": [0.0, 1.0],
1 parent 272f573 commit b8c0859

10 files changed

Lines changed: 59 additions & 5 deletions

File tree

deepmd/entrypoints/test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,11 @@ def test_ener(
298298
data.add("atom_ener", 1, atomic=True, must=True, high_prec=False)
299299
if dp.get_dim_fparam() > 0:
300300
data.add(
301-
"fparam", dp.get_dim_fparam(), atomic=False, must=True, high_prec=False
301+
"fparam",
302+
dp.get_dim_fparam(),
303+
atomic=False,
304+
must=not dp.has_default_fparam(),
305+
high_prec=False,
302306
)
303307
if dp.get_dim_aparam() > 0:
304308
data.add("aparam", dp.get_dim_aparam(), atomic=True, must=True, high_prec=False)
@@ -334,7 +338,7 @@ def test_ener(
334338
atype = test_data["type"][:numb_test].reshape([numb_test, -1])
335339
else:
336340
atype = test_data["type"][0]
337-
if dp.get_dim_fparam() > 0:
341+
if dp.get_dim_fparam() > 0 and test_data["find_fparam"] != 0.0:
338342
fparam = test_data["fparam"][:numb_test]
339343
else:
340344
fparam = None

deepmd/infer/deep_eval.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ def get_type_map(self) -> list[str]:
160160
def get_dim_fparam(self) -> int:
161161
"""Get the number (dimension) of frame parameters of this DP."""
162162

163+
def has_default_fparam(self) -> bool:
164+
return False
165+
163166
@abstractmethod
164167
def get_dim_aparam(self) -> int:
165168
"""Get the number (dimension) of atomic parameters of this DP."""
@@ -370,6 +373,9 @@ def get_dim_fparam(self) -> int:
370373
"""Get the number (dimension) of frame parameters of this DP."""
371374
return self.deep_eval.get_dim_fparam()
372375

376+
def has_default_fparam(self) -> bool:
377+
return self.deep_eval.has_default_fparam()
378+
373379
def get_dim_aparam(self) -> int:
374380
"""Get the number (dimension) of atomic parameters of this DP."""
375381
return self.deep_eval.get_dim_aparam()

deepmd/pt/infer/deep_eval.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,9 @@ def get_dim_aparam(self) -> int:
183183
"""Get the number (dimension) of atomic parameters of this DP."""
184184
return self.dp.model["Default"].get_dim_aparam()
185185

186+
def has_default_fparam(self) -> bool:
187+
return self.dp.model["Default"].has_default_fparam()
188+
186189
def get_intensive(self) -> bool:
187190
return self.dp.model["Default"].get_intensive()
188191

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,9 @@ def get_dim_fparam(self) -> int:
308308
"""Get the number (dimension) of frame parameters of this atomic model."""
309309
return self.fitting_net.get_dim_fparam()
310310

311+
def has_default_fparam(self) -> bool:
312+
return self.fitting_net.has_default_fparam()
313+
311314
def get_dim_aparam(self) -> int:
312315
"""Get the number (dimension) of atomic parameters of this atomic model."""
313316
return self.fitting_net.get_dim_aparam()

deepmd/pt/model/model/make_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,10 @@ def get_dim_fparam(self) -> int:
522522
"""Get the number (dimension) of frame parameters of this atomic model."""
523523
return self.atomic_model.get_dim_fparam()
524524

525+
@torch.jit.export
526+
def has_default_fparam(self) -> bool:
527+
return self.atomic_model.has_default_fparam()
528+
525529
@torch.jit.export
526530
def get_dim_aparam(self) -> int:
527531
"""Get the number (dimension) of atomic parameters of this atomic model."""

deepmd/pt/model/task/ener.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(
5656
mixed_types: bool = True,
5757
seed: Optional[Union[int, list[int]]] = None,
5858
type_map: Optional[list[str]] = None,
59+
default_fparam: Optional[list] = None,
5960
**kwargs,
6061
) -> None:
6162
super().__init__(
@@ -74,6 +75,7 @@ def __init__(
7475
mixed_types=mixed_types,
7576
seed=seed,
7677
type_map=type_map,
78+
default_fparam=default_fparam,
7779
**kwargs,
7880
)
7981

deepmd/pt/model/task/fitting.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ def __init__(
227227
remove_vaccum_contribution: Optional[list[bool]] = None,
228228
type_map: Optional[list[str]] = None,
229229
use_aparam_as_mask: bool = False,
230+
default_fparam: Optional[list] = None,
230231
**kwargs,
231232
) -> None:
232233
super().__init__()
@@ -238,6 +239,7 @@ def __init__(
238239
self.resnet_dt = resnet_dt
239240
self.numb_fparam = numb_fparam
240241
self.numb_aparam = numb_aparam
242+
self.default_fparam = default_fparam
241243
self.dim_case_embd = dim_case_embd
242244
self.activation_function = activation_function
243245
self.precision = precision
@@ -299,6 +301,20 @@ def __init__(
299301
else:
300302
self.case_embd = None
301303

304+
if self.default_fparam is not None:
305+
if self.numb_fparam > 0:
306+
assert (
307+
len(self.default_fparam) == self.numb_fparam
308+
), "default_fparam length mismatch!"
309+
self.register_buffer(
310+
"default_fparam_tensor",
311+
torch.tensor(
312+
np.array(self.default_fparam), dtype=self.prec, device=device
313+
),
314+
)
315+
else:
316+
self.default_fparam_tensor = None
317+
302318
in_dim = (
303319
self.dim_descrpt
304320
+ self.numb_fparam
@@ -415,6 +431,9 @@ def get_dim_fparam(self) -> int:
415431
"""Get the number (dimension) of frame parameters of this atomic model."""
416432
return self.numb_fparam
417433

434+
def has_default_fparam(self) -> bool:
435+
return self.default_fparam is not None
436+
418437
def get_dim_aparam(self) -> int:
419438
"""Get the number (dimension) of atomic parameters of this atomic model."""
420439
return self.numb_aparam
@@ -509,6 +528,13 @@ def _forward_common(
509528
):
510529
# cast the input to internal precsion
511530
xx = descriptor.to(self.prec)
531+
nf, nloc, nd = xx.shape
532+
533+
if self.numb_fparam > 0 and fparam is None:
534+
# use default fparam
535+
assert self.default_fparam_tensor is not None
536+
fparam = torch.tile(self.default_fparam_tensor.unsqueeze(0), [nf, 1])
537+
512538
fparam = fparam.to(self.prec) if fparam is not None else None
513539
aparam = aparam.to(self.prec) if aparam is not None else None
514540

@@ -521,7 +547,6 @@ def _forward_common(
521547
xx_zeros = torch.zeros_like(xx)
522548
else:
523549
xx_zeros = None
524-
nf, nloc, nd = xx.shape
525550
net_dim_out = self._net_out_dim()
526551

527552
if nd != self.dim_descrpt:

deepmd/pt/model/task/invar_fitting.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def __init__(
103103
atom_ener: Optional[list[Optional[torch.Tensor]]] = None,
104104
type_map: Optional[list[str]] = None,
105105
use_aparam_as_mask: bool = False,
106+
default_fparam: Optional[list] = None,
106107
**kwargs,
107108
) -> None:
108109
self.dim_out = dim_out
@@ -128,6 +129,7 @@ def __init__(
128129
else [x is not None for x in atom_ener],
129130
type_map=type_map,
130131
use_aparam_as_mask=use_aparam_as_mask,
132+
default_fparam=default_fparam,
131133
**kwargs,
132134
)
133135

deepmd/pt/train/training.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,7 +1126,8 @@ def get_data(self, is_train=True, task_key="Default"):
11261126
label_dict = {}
11271127
for item_key in batch_data:
11281128
if item_key in input_keys:
1129-
input_dict[item_key] = batch_data[item_key]
1129+
if item_key != "fparam" or batch_data["find_fparam"] != 0.0:
1130+
input_dict[item_key] = batch_data[item_key]
11301131
else:
11311132
if item_key not in ["sid", "fid"]:
11321133
label_dict[item_key] = batch_data[item_key]
@@ -1205,7 +1206,10 @@ def get_additional_data_requirement(_model):
12051206
if _model.get_dim_fparam() > 0:
12061207
fparam_requirement_items = [
12071208
DataRequirementItem(
1208-
"fparam", _model.get_dim_fparam(), atomic=False, must=True
1209+
"fparam",
1210+
_model.get_dim_fparam(),
1211+
atomic=False,
1212+
must=not _model.has_default_fparam(),
12091213
)
12101214
]
12111215
additional_data_requirement += fparam_requirement_items

deepmd/utils/argcheck.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1773,6 +1773,7 @@ def fitting_ener():
17731773
return [
17741774
Argument("numb_fparam", int, optional=True, default=0, doc=doc_numb_fparam),
17751775
Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam),
1776+
Argument("default_fparam", list, optional=True, default=None),
17761777
Argument(
17771778
"dim_case_embd",
17781779
int,

0 commit comments

Comments (0)