Skip to content

Commit c337dea

Browse files
wanghan-iapcm (Han Wang)
and authored
feat(pt_expt): full model (#5244)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added a PyTorch experimental energy model with exportable lower-level tracing, descriptor/accessor and output-definition APIs, and output-bias management. * **Bug Fixes** * Fixed device placement for created tensors and added runtime validations for several model accessors. * **Tests** * Expanded test suite with autodiff/derivative validation, export/tracing checks, and cross-backend API consistency tests. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent a0bd530 commit c337dea

26 files changed

Lines changed: 2016 additions & 138 deletions

deepmd/dpmodel/fitting/general_fitting.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,7 @@ def _call_common(
584584
)
585585

586586
# calculate the prediction
587+
results: dict[str, Array] = {}
587588
if not self.mixed_types:
588589
outs = xp.zeros(
589590
[nf, nloc, net_dim_out],
@@ -622,4 +623,5 @@ def _call_common(
622623
exclude_mask = xp.astype(exclude_mask, xp.bool)
623624
# nf x nloc x nod
624625
outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs))
625-
return {self.var_name: outs}
626+
results[self.var_name] = outs
627+
return results

deepmd/dpmodel/model/dp_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,7 @@ def update_sel(
4848
def get_fitting_net(self) -> BaseFitting:
4949
"""Get the fitting network."""
5050
return self.atomic_model.fitting
51+
52+
def get_descriptor(self) -> BaseDescriptor:
53+
"""Get the descriptor."""
54+
return self.atomic_model.descriptor

deepmd/dpmodel/model/ener_model.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,28 @@ def atomic_output_def(self) -> FittingOutputDef:
4747
if self._enable_hessian:
4848
return self.hess_fitting_def
4949
return super().atomic_output_def()
50+
51+
def translated_output_def(self) -> dict[str, Any]:
52+
"""Get the translated output definition.
53+
54+
Maps internal output names to user-facing names, e.g.
55+
``energy_redu`` -> ``energy``, ``energy_derv_r`` -> ``force``.
56+
"""
57+
out_def_data = self.model_output_def().get_data()
58+
output_def = {
59+
"atom_energy": out_def_data["energy"],
60+
"energy": out_def_data["energy_redu"],
61+
}
62+
if self.do_grad_r("energy"):
63+
output_def["force"] = out_def_data["energy_derv_r"]
64+
output_def["force"].squeeze(-2)
65+
if self.do_grad_c("energy"):
66+
output_def["virial"] = out_def_data["energy_derv_c_redu"]
67+
output_def["virial"].squeeze(-2)
68+
output_def["atom_virial"] = out_def_data["energy_derv_c"]
69+
output_def["atom_virial"].squeeze(-2)
70+
if "mask" in out_def_data:
71+
output_def["mask"] = out_def_data["mask"]
72+
if self._enable_hessian:
73+
output_def["hessian"] = out_def_data["energy_derv_r_derv_r"]
74+
return output_def

deepmd/dpmodel/model/make_model.py

Lines changed: 75 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
PRECISION_DICT,
2222
RESERVED_PRECISION_DICT,
2323
NativeOP,
24+
get_xp_precision,
2425
)
2526
from deepmd.dpmodel.model.base_model import (
2627
BaseModel,
@@ -103,7 +104,8 @@ def model_call_from_call_lower(
103104
bb.reshape(nframes, 3, 3),
104105
)
105106
else:
106-
coord_normalized = cc.copy()
107+
xp = array_api_compat.array_namespace(cc)
108+
coord_normalized = xp.reshape(cc, (nframes, nloc, 3))
107109
extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
108110
coord_normalized, atype, bb, rcut
109111
)
@@ -255,7 +257,7 @@ def call(
255257
The keys are defined by the `ModelOutputDef`.
256258
257259
"""
258-
cc, bb, fp, ap, input_prec = self.input_type_cast(
260+
cc, bb, fp, ap, input_prec = self._input_type_cast(
259261
coord, box=box, fparam=fparam, aparam=aparam
260262
)
261263
del coord, box, fparam, aparam
@@ -272,7 +274,7 @@ def call(
272274
aparam=ap,
273275
do_atomic_virial=do_atomic_virial,
274276
)
275-
model_predict = self.output_type_cast(model_predict, input_prec)
277+
model_predict = self._output_type_cast(model_predict, input_prec)
276278
return model_predict
277279

278280
def call_lower(
@@ -321,7 +323,7 @@ def call_lower(
321323
nlist,
322324
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
323325
)
324-
cc_ext, _, fp, ap, input_prec = self.input_type_cast(
326+
cc_ext, _, fp, ap, input_prec = self._input_type_cast(
325327
extended_coord, fparam=fparam, aparam=aparam
326328
)
327329
del extended_coord, fparam, aparam
@@ -334,7 +336,7 @@ def call_lower(
334336
aparam=ap,
335337
do_atomic_virial=do_atomic_virial,
336338
)
337-
model_predict = self.output_type_cast(model_predict, input_prec)
339+
model_predict = self._output_type_cast(model_predict, input_prec)
338340
return model_predict
339341

340342
def forward_common_atomic(
@@ -364,60 +366,107 @@ def forward_common_atomic(
364366
)
365367

366368
forward_lower = call_lower
369+
forward_common = call
370+
forward_common_lower = call_lower
367371

368-
def input_type_cast(
372+
def get_out_bias(self) -> Array:
373+
"""Get the output bias."""
374+
return self.atomic_model.out_bias
375+
376+
def set_out_bias(self, out_bias: Array) -> None:
377+
"""Set the output bias."""
378+
self.atomic_model.out_bias = out_bias
379+
380+
def change_out_bias(
381+
self,
382+
merged: Any,
383+
bias_adjust_mode: str = "change-by-statistic",
384+
) -> None:
385+
"""Change the output bias according to the input data and the pretrained model.
386+
387+
Parameters
388+
----------
389+
merged
390+
The merged data samples.
391+
bias_adjust_mode : str
392+
The mode for changing output bias:
393+
'change-by-statistic' or 'set-by-statistic'.
394+
"""
395+
self.atomic_model.change_out_bias(merged, bias_adjust_mode=bias_adjust_mode)
396+
397+
def _input_type_cast(
369398
self,
370399
coord: Array,
371400
box: Array | None = None,
372401
fparam: Array | None = None,
373402
aparam: Array | None = None,
374-
) -> tuple[Array, Array, np.ndarray | None, np.ndarray | None, str]:
403+
) -> tuple[Array, Array | None, Array | None, Array | None, Any]:
375404
"""Cast the input data to global float type."""
376-
input_prec = RESERVED_PRECISION_DICT[self.precision_dict[coord.dtype.name]]
405+
xp = array_api_compat.array_namespace(coord)
406+
input_dtype = coord.dtype
407+
global_dtype = get_xp_precision(
408+
xp, RESERVED_PRECISION_DICT[self.global_np_float_precision]
409+
)
377410
###
378411
### type checking would not pass jit, convert to coord prec anyway
379412
###
380-
_lst: list[np.ndarray | None] = [
381-
vv.astype(coord.dtype) if vv is not None else None
413+
_lst: list[Array | None] = [
414+
xp.astype(vv, input_dtype) if vv is not None else None
382415
for vv in [box, fparam, aparam]
383416
]
384417
box, fparam, aparam = _lst
385-
if input_prec == RESERVED_PRECISION_DICT[self.global_np_float_precision]:
386-
return coord, box, fparam, aparam, input_prec
418+
if input_dtype == global_dtype:
419+
return coord, box, fparam, aparam, input_dtype
387420
else:
388-
pp = self.global_np_float_precision
389421
return (
390-
coord.astype(pp),
391-
box.astype(pp) if box is not None else None,
392-
fparam.astype(pp) if fparam is not None else None,
393-
aparam.astype(pp) if aparam is not None else None,
394-
input_prec,
422+
xp.astype(coord, global_dtype),
423+
xp.astype(box, global_dtype) if box is not None else None,
424+
xp.astype(fparam, global_dtype) if fparam is not None else None,
425+
xp.astype(aparam, global_dtype) if aparam is not None else None,
426+
input_dtype,
395427
)
396428

397-
def output_type_cast(
429+
def _output_type_cast(
398430
self,
399431
model_ret: dict[str, Array],
400-
input_prec: str,
432+
input_prec: Any,
401433
) -> dict[str, Array]:
402-
"""Convert the model output to the input prec."""
403-
do_cast = (
404-
input_prec != RESERVED_PRECISION_DICT[self.global_np_float_precision]
434+
"""Convert the model output to the input prec.
435+
436+
Parameters
437+
----------
438+
model_ret
439+
The model output.
440+
input_prec
441+
The input dtype returned by ``_input_type_cast``.
442+
"""
443+
model_ret_not_none = [vv for vv in model_ret.values() if vv is not None]
444+
if not model_ret_not_none:
445+
return model_ret
446+
xp = array_api_compat.array_namespace(model_ret_not_none[0])
447+
global_dtype = get_xp_precision(
448+
xp, RESERVED_PRECISION_DICT[self.global_np_float_precision]
449+
)
450+
ener_dtype = get_xp_precision(
451+
xp, RESERVED_PRECISION_DICT[self.global_ener_float_precision]
405452
)
406-
pp = self.precision_dict[input_prec]
453+
do_cast = input_prec != global_dtype
407454
odef = self.model_output_def()
408455
for kk in odef.keys():
409456
if kk not in model_ret.keys():
410457
# do not return energy_derv_c if not do_atomic_virial
411458
continue
412459
if check_operation_applied(odef[kk], OutputVariableOperation.REDU):
413460
model_ret[kk] = (
414-
model_ret[kk].astype(self.global_ener_float_precision)
461+
xp.astype(model_ret[kk], ener_dtype)
415462
if model_ret[kk] is not None
416463
else None
417464
)
418465
elif do_cast:
419466
model_ret[kk] = (
420-
model_ret[kk].astype(pp) if model_ret[kk] is not None else None
467+
xp.astype(model_ret[kk], input_prec)
468+
if model_ret[kk] is not None
469+
else None
421470
)
422471
return model_ret
423472

deepmd/dpmodel/model/transform_output.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def communicate_extended_output(
9898
9999
"""
100100
xp = array_api_compat.get_namespace(mapping)
101+
device = array_api_compat.device(mapping)
101102
mapping_ = mapping
102103
new_ret = {}
103104
for kk in model_output_def.keys_outp():
@@ -117,7 +118,9 @@ def communicate_extended_output(
117118
mapping, tuple(mldims + [1] * len(derv_r_ext_dims))
118119
)
119120
mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims)
120-
force = xp.zeros(vldims + derv_r_ext_dims, dtype=vv.dtype)
121+
force = xp.zeros(
122+
vldims + derv_r_ext_dims, dtype=vv.dtype, device=device
123+
)
121124
force = xp_scatter_sum(
122125
force,
123126
1,
@@ -149,7 +152,9 @@ def communicate_extended_output(
149152
nall = hess_1.shape[1]
150153
# (1) -> [nf, nloc1, nall2, *def, 3(1), 3(2)]
151154
hessian1 = xp.zeros(
152-
[*vldims, nall, *vdef.shape, 3, 3], dtype=vv.dtype
155+
[*vldims, nall, *vdef.shape, 3, 3],
156+
dtype=vv.dtype,
157+
device=device,
153158
)
154159
mapping_hess = xp.reshape(
155160
mapping_, (mldims + [1] * (len(vdef.shape) + 3))
@@ -172,7 +177,9 @@ def communicate_extended_output(
172177
nloc = hessian1.shape[2]
173178
# (2) -> [nf, nloc2, nloc1, *def, 3(1), 3(2)]
174179
hessian = xp.zeros(
175-
[*vldims, nloc, *vdef.shape, 3, 3], dtype=vv.dtype
180+
[*vldims, nloc, *vdef.shape, 3, 3],
181+
dtype=vv.dtype,
182+
device=device,
176183
)
177184
mapping_hess = xp.reshape(
178185
mapping_, (mldims + [1] * (len(vdef.shape) + 3))
@@ -218,6 +225,7 @@ def communicate_extended_output(
218225
virial = xp.zeros(
219226
vldims + derv_c_ext_dims,
220227
dtype=vv.dtype,
228+
device=device,
221229
)
222230
virial = xp_scatter_sum(
223231
virial,

deepmd/dpmodel/utils/network.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -280,11 +280,11 @@ def call(self, x): # noqa: ANN001, ANN201
280280
y = xp.astype(y, x.dtype)
281281
y = fn(y)
282282
if self.idt is not None:
283-
y *= self.idt
283+
y = y * self.idt
284284
if self.resnet and self.w.shape[1] == self.w.shape[0]:
285-
y += x
285+
y = y + x
286286
elif self.resnet and self.w.shape[1] == 2 * self.w.shape[0]:
287-
y += xp.concat([x, x], axis=-1)
287+
y = y + xp.concat([x, x], axis=-1)
288288
return y
289289

290290

deepmd/pd/model/model/ener_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def translated_output_def(self) -> dict:
6060
output_def["virial"] = out_def_data["energy_derv_c_redu"]
6161
output_def["virial"].squeeze(-2)
6262
output_def["atom_virial"] = out_def_data["energy_derv_c"]
63-
output_def["atom_virial"].squeeze(-3)
63+
output_def["atom_virial"].squeeze(-2)
6464
if "mask" in out_def_data:
6565
output_def["mask"] = out_def_data["mask"]
6666
return output_def
@@ -140,7 +140,7 @@ def forward_lower(
140140
if do_atomic_virial:
141141
model_predict["extended_virial"] = model_ret[
142142
"energy_derv_c"
143-
].squeeze(-3)
143+
].squeeze(-2)
144144
else:
145145
model_predict["extended_virial"] = paddle.zeros(
146146
[model_predict["energy"].shape[0], 1, 9], dtype=paddle.float64

deepmd/pd/model/model/make_model.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def forward_common(
162162
The keys are defined by the `ModelOutputDef`.
163163
164164
"""
165-
cc, bb, fp, ap, input_prec = self.input_type_cast(
165+
cc, bb, fp, ap, input_prec = self._input_type_cast(
166166
coord, box=box, fparam=fparam, aparam=aparam
167167
)
168168
del coord, box, fparam, aparam
@@ -196,7 +196,7 @@ def forward_common(
196196
mapping,
197197
do_atomic_virial=do_atomic_virial,
198198
)
199-
model_predict = self.output_type_cast(model_predict, input_prec)
199+
model_predict = self._output_type_cast(model_predict, input_prec)
200200
return model_predict
201201

202202
def get_out_bias(self) -> paddle.Tensor:
@@ -283,7 +283,7 @@ def forward_common_lower(
283283
nlist = self.format_nlist(
284284
extended_coord, extended_atype, nlist, extra_nlist_sort=extra_nlist_sort
285285
)
286-
cc_ext, _, fp, ap, input_prec = self.input_type_cast(
286+
cc_ext, _, fp, ap, input_prec = self._input_type_cast(
287287
extended_coord, fparam=fparam, aparam=aparam
288288
)
289289
del extended_coord, fparam, aparam
@@ -303,10 +303,10 @@ def forward_common_lower(
303303
do_atomic_virial=do_atomic_virial,
304304
create_graph=self.training,
305305
)
306-
model_predict = self.output_type_cast(model_predict, input_prec)
306+
model_predict = self._output_type_cast(model_predict, input_prec)
307307
return model_predict
308308

309-
def input_type_cast(
309+
def _input_type_cast(
310310
self,
311311
coord: paddle.Tensor,
312312
box: paddle.Tensor | None = None,
@@ -351,7 +351,7 @@ def input_type_cast(
351351
input_prec,
352352
)
353353

354-
def output_type_cast(
354+
def _output_type_cast(
355355
self,
356356
model_ret: dict[str, paddle.Tensor],
357357
input_prec: str,

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ def set_eval_descriptor_hook(self, enable: bool) -> None:
8383

8484
def eval_descriptor(self) -> torch.Tensor:
8585
"""Evaluate the descriptor."""
86+
if not self.eval_descriptor_list:
87+
raise RuntimeError(
88+
"eval_descriptor_list is empty. "
89+
"Call set_eval_descriptor_hook(True) and perform a forward pass first."
90+
)
8691
return torch.concat(self.eval_descriptor_list)
8792

8893
def set_eval_fitting_last_layer_hook(self, enable: bool) -> None:
@@ -94,6 +99,11 @@ def set_eval_fitting_last_layer_hook(self, enable: bool) -> None:
9499

95100
def eval_fitting_last_layer(self) -> torch.Tensor:
96101
"""Evaluate the fitting last layer output."""
102+
if not self.eval_fitting_last_layer_list:
103+
raise RuntimeError(
104+
"eval_fitting_last_layer_list is empty. "
105+
"Call set_eval_fitting_last_layer_hook(True) and perform a forward pass first."
106+
)
97107
return torch.concat(self.eval_fitting_last_layer_list)
98108

99109
@torch.jit.export

0 commit comments

Comments (0)