diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py
index fabc39ae96..260be619fd 100644
--- a/deepmd/dpmodel/fitting/general_fitting.py
+++ b/deepmd/dpmodel/fitting/general_fitting.py
@@ -584,6 +584,7 @@ def _call_common(
         )
         # calculate the prediction
+        results: dict[str, Array] = {}
         if not self.mixed_types:
             outs = xp.zeros(
                 [nf, nloc, net_dim_out],
@@ -622,4 +623,5 @@ def _call_common(
         exclude_mask = xp.astype(exclude_mask, xp.bool)
         # nf x nloc x nod
         outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs))
-        return {self.var_name: outs}
+        results[self.var_name] = outs
+        return results
diff --git a/deepmd/dpmodel/infer/deep_eval.py b/deepmd/dpmodel/infer/deep_eval.py
index 8088ba1d2f..c80898ec74 100644
--- a/deepmd/dpmodel/infer/deep_eval.py
+++ b/deepmd/dpmodel/infer/deep_eval.py
@@ -358,9 +358,7 @@ def _eval_model(
         results = []
         for odef in request_defs:
-            # it seems not doing conversion
-            # dp_name = self._OUTDEF_DP2BACKEND[odef.name]
-            dp_name = odef.name
+            dp_name = self._OUTDEF_DP2BACKEND[odef.name]
             if dp_name in batch_output:
                 shape = self._get_output_shape(odef, nframes, natoms)
                 if batch_output[dp_name] is not None:
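The restored lookup routes DP output names through the backend translation table before indexing `batch_output`. A minimal sketch of that pattern; the two table entries below are illustrative assumptions, not copied from the real `_OUTDEF_DP2BACKEND`:

```python
# Illustrative sketch only: the real table lives on the DeepEval class as
# _OUTDEF_DP2BACKEND; these entries are assumed for demonstration.
OUTDEF_DP2BACKEND = {
    "energy": "atom_energy",   # per-atom quantity gets the backend name
    "energy_redu": "energy",   # reduced quantity gets the user-facing name
}

def collect_outputs(request_names: list[str], batch_output: dict) -> list:
    results = []
    for name in request_names:
        dp_name = OUTDEF_DP2BACKEND[name]  # translate before indexing
        if dp_name in batch_output and batch_output[dp_name] is not None:
            results.append(batch_output[dp_name])
    return results
```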
diff --git a/deepmd/dpmodel/model/dipole_model.py b/deepmd/dpmodel/model/dipole_model.py
index d213514551..def0156b2e 100644
--- a/deepmd/dpmodel/model/dipole_model.py
+++ b/deepmd/dpmodel/model/dipole_model.py
@@ -3,6 +3,9 @@
     Any,
 )

+from deepmd.dpmodel.array_api import (
+    Array,
+)
 from deepmd.dpmodel.atomic_model import (
     DPDipoleAtomicModel,
 )
@@ -31,3 +34,90 @@ def __init__(
     ) -> None:
         DPModelCommon.__init__(self)
         DPDipoleModel_.__init__(self, *args, **kwargs)
+
+    def translated_output_def(self) -> dict[str, Any]:
+        out_def_data = self.model_output_def().get_data()
+        output_def = {
+            "dipole": out_def_data["dipole"],
+            "global_dipole": out_def_data["dipole_redu"],
+        }
+        if self.do_grad_r("dipole"):
+            output_def["force"] = out_def_data["dipole_derv_r"]
+            output_def["force"].squeeze(-2)
+        if self.do_grad_c("dipole"):
+            output_def["virial"] = out_def_data["dipole_derv_c_redu"]
+            output_def["virial"].squeeze(-2)
+            output_def["atom_virial"] = out_def_data["dipole_derv_c"]
+            output_def["atom_virial"].squeeze(-2)
+        if "mask" in out_def_data:
+            output_def["mask"] = out_def_data["mask"]
+        return output_def
+
+    def call(
+        self,
+        coord: Array,
+        atype: Array,
+        box: Array | None = None,
+        fparam: Array | None = None,
+        aparam: Array | None = None,
+        do_atomic_virial: bool = False,
+    ) -> dict[str, Array]:
+        model_ret = self.call_common(
+            coord,
+            atype,
+            box,
+            fparam=fparam,
+            aparam=aparam,
+            do_atomic_virial=do_atomic_virial,
+        )
+        if self.get_fitting_net() is not None:
+            model_predict = {}
+            model_predict["dipole"] = model_ret["dipole"]
+            model_predict["global_dipole"] = model_ret["dipole_redu"]
+            if self.do_grad_r("dipole"):
+                model_predict["force"] = model_ret.get("dipole_derv_r")
+            if self.do_grad_c("dipole"):
+                model_predict["virial"] = model_ret.get("dipole_derv_c_redu")
+                if do_atomic_virial:
+                    model_predict["atom_virial"] = model_ret.get("dipole_derv_c")
+            if "mask" in model_ret:
+                model_predict["mask"] = model_ret["mask"]
+        else:
+            model_predict = model_ret
+            model_predict["updated_coord"] += coord
+        return model_predict
+
+    def call_lower(
+        self,
+        extended_coord: Array,
+        extended_atype: Array,
+        nlist: Array,
+        mapping: Array | None = None,
+        fparam: Array | None = None,
+        aparam: Array | None = None,
+        do_atomic_virial: bool = False,
+    ) -> dict[str, Array]:
+        model_ret = self.call_common_lower(
+            extended_coord,
+            extended_atype,
+            nlist,
+            mapping,
+            fparam=fparam,
+            aparam=aparam,
+            do_atomic_virial=do_atomic_virial,
+        )
+        if self.get_fitting_net() is not None:
+            model_predict = {}
+            model_predict["dipole"] = model_ret["dipole"]
+            model_predict["global_dipole"] = model_ret["dipole_redu"]
+            if self.do_grad_r("dipole"):
+                model_predict["extended_force"] = model_ret.get("dipole_derv_r")
+            if self.do_grad_c("dipole"):
+                model_predict["virial"] = model_ret.get("dipole_derv_c_redu")
+                if do_atomic_virial:
+                    model_predict["extended_virial"] = model_ret.get("dipole_derv_c")
+        else:
+            model_predict = model_ret
+        return model_predict
+
+    forward_lower = call_lower
diff --git a/deepmd/dpmodel/model/dos_model.py b/deepmd/dpmodel/model/dos_model.py
index 5c5d2a5e90..977e621525 100644
--- a/deepmd/dpmodel/model/dos_model.py
+++ b/deepmd/dpmodel/model/dos_model.py
@@ -3,6 +3,9 @@
     Any,
 )

+from deepmd.dpmodel.array_api import (
+    Array,
+)
 from deepmd.dpmodel.atomic_model import (
     DPDOSAtomicModel,
 )
@@ -31,3 +34,70 @@ def __init__(
     ) -> None:
         DPModelCommon.__init__(self)
         DPDOSModel_.__init__(self, *args, **kwargs)
+
+    def translated_output_def(self) -> dict[str, Any]:
+        out_def_data = self.model_output_def().get_data()
+        output_def = {
+            "atom_dos": out_def_data["dos"],
+            "dos": out_def_data["dos_redu"],
+        }
+        if "mask" in out_def_data:
+            output_def["mask"] = out_def_data["mask"]
+        return output_def
+
+    def call(
+        self,
+        coord: Array,
+        atype: Array,
+        box: Array | None = None,
+        fparam: Array | None = None,
+        aparam: Array | None = None,
+        do_atomic_virial: bool = False,
+    ) -> dict[str, Array]:
+        model_ret = self.call_common(
+            coord,
+            atype,
+            box,
+            fparam=fparam,
+            aparam=aparam,
+            do_atomic_virial=do_atomic_virial,
+        )
+        if self.get_fitting_net() is not None:
+            model_predict = {}
+            model_predict["atom_dos"] = model_ret["dos"]
+            model_predict["dos"] = model_ret["dos_redu"]
+            if "mask" in model_ret:
+                model_predict["mask"] = model_ret["mask"]
+        else:
+            model_predict = model_ret
+            model_predict["updated_coord"] += coord
+        return model_predict
+
+    def call_lower(
+        self,
+        extended_coord: Array,
+        extended_atype: Array,
+        nlist: Array,
+        mapping: Array | None = None,
+        fparam: Array | None = None,
+        aparam: Array | None = None,
+        do_atomic_virial: bool = False,
+    ) -> dict[str, Array]:
+        model_ret = self.call_common_lower(
+            extended_coord,
+            extended_atype,
+            nlist,
+            mapping,
+            fparam=fparam,
+            aparam=aparam,
+            do_atomic_virial=do_atomic_virial,
+        )
+        if self.get_fitting_net() is not None:
+            model_predict = {}
+            model_predict["atom_dos"] = model_ret["dos"]
+            model_predict["dos"] = model_ret["dos_redu"]
+        else:
+            model_predict = model_ret
+        return model_predict
+
+    forward_lower = call_lower
diff --git a/deepmd/dpmodel/model/dp_model.py b/deepmd/dpmodel/model/dp_model.py
index 063533f2a7..0dcf6358f9 100644
--- a/deepmd/dpmodel/model/dp_model.py
+++ b/deepmd/dpmodel/model/dp_model.py
@@ -48,3 +48,7 @@ def update_sel(
     def get_fitting_net(self) -> BaseFitting:
         """Get the fitting network."""
         return self.atomic_model.fitting
+
+    def get_descriptor(self) -> BaseDescriptor:
+        """Get the descriptor."""
+        return self.atomic_model.descriptor
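The new `translated_output_def` methods above all follow one pattern: read the raw output definitions and re-key them for users. A minimal sketch of that shared pattern (the mapping table covers only the dipole and DOS cases shown above):

```python
# Sketch of the raw -> user-facing key translation behind the new
# translated_output_def methods; the table below is illustrative.
RAW_TO_USER = {
    "dipole": "dipole",
    "dipole_redu": "global_dipole",
    "dos": "atom_dos",
    "dos_redu": "dos",
}

def translate_output_def(out_def_data: dict) -> dict:
    translated = {
        user: out_def_data[raw]
        for raw, user in RAW_TO_USER.items()
        if raw in out_def_data
    }
    if "mask" in out_def_data:
        translated["mask"] = out_def_data["mask"]
    return translated
```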
diff --git a/deepmd/dpmodel/model/dp_zbl_model.py b/deepmd/dpmodel/model/dp_zbl_model.py
index b5940f4707..c04ae5be64 100644
--- a/deepmd/dpmodel/model/dp_zbl_model.py
+++ b/deepmd/dpmodel/model/dp_zbl_model.py
@@ -3,6 +3,9 @@
     Any,
 )

+from deepmd.dpmodel.array_api import (
+    Array,
+)
 from deepmd.dpmodel.atomic_model.linear_atomic_model import (
     DPZBLLinearEnergyAtomicModel,
 )
@@ -34,6 +37,114 @@ def __init__(
     ) -> None:
         super().__init__(*args, **kwargs)

+    def translated_output_def(self) -> dict[str, Any]:
+        out_def_data = self.model_output_def().get_data()
+        output_def = {
+            "atom_energy": out_def_data["energy"],
+            "energy": out_def_data["energy_redu"],
+        }
+        if self.do_grad_r("energy"):
+            output_def["force"] = out_def_data["energy_derv_r"]
+            output_def["force"].squeeze(-2)
+        if self.do_grad_c("energy"):
+            output_def["virial"] = out_def_data["energy_derv_c_redu"]
+            output_def["virial"].squeeze(-2)
+            output_def["atom_virial"] = out_def_data["energy_derv_c"]
+            output_def["atom_virial"].squeeze(-2)
+        if "mask" in out_def_data:
+            output_def["mask"] = out_def_data["mask"]
+        return output_def
+
+    def call(
+        self,
+        coord: Array,
+        atype: Array,
+        box: Array | None = None,
+        fparam: Array | None = None,
+        aparam: Array | None = None,
+        do_atomic_virial: bool = False,
+    ) -> dict[str, Array]:
+        model_ret = self.call_common(
+            coord,
+            atype,
+            box,
+            fparam=fparam,
+            aparam=aparam,
+            do_atomic_virial=do_atomic_virial,
+        )
+        model_predict = {}
+        model_predict["atom_energy"] = model_ret["energy"]
+        model_predict["energy"] = model_ret["energy_redu"]
+        if self.do_grad_r("energy"):
+            if model_ret.get("energy_derv_r") is not None:
+                model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
+            else:
+                model_predict["force"] = model_ret.get("energy_derv_r")
+        if self.do_grad_c("energy"):
+            derv_c_redu = model_ret.get("energy_derv_c_redu")
+            if derv_c_redu is not None:
+                model_predict["virial"] = derv_c_redu.squeeze(-2)
+            else:
+                model_predict["virial"] = derv_c_redu
+            if do_atomic_virial:
+                derv_c = model_ret.get("energy_derv_c")
+                if derv_c is not None:
+                    model_predict["atom_virial"] = derv_c.squeeze(-2)
+                else:
+                    model_predict["atom_virial"] = derv_c
+        else:
+            if model_ret.get("dforce") is not None:
+                model_predict["force"] = model_ret["dforce"]
+        if "mask" in model_ret:
+            model_predict["mask"] = model_ret["mask"]
+        return model_predict
+
+    def call_lower(
+        self,
+        extended_coord: Array,
+        extended_atype: Array,
+        nlist: Array,
+        mapping: Array | None = None,
+        fparam: Array | None = None,
+        aparam: Array | None = None,
+        do_atomic_virial: bool = False,
+    ) -> dict[str, Array]:
+        model_ret = self.call_common_lower(
+            extended_coord,
+            extended_atype,
+            nlist,
+            mapping,
+            fparam=fparam,
+            aparam=aparam,
+            do_atomic_virial=do_atomic_virial,
+        )
+        model_predict = {}
+        model_predict["atom_energy"] = model_ret["energy"]
+        model_predict["energy"] = model_ret["energy_redu"]
+        if self.do_grad_r("energy"):
+            if model_ret.get("energy_derv_r") is not None:
+                model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2)
+            else:
+                model_predict["extended_force"] = model_ret.get("energy_derv_r")
+        if self.do_grad_c("energy"):
+            derv_c_redu = model_ret.get("energy_derv_c_redu")
+            if derv_c_redu is not None:
+                model_predict["virial"] = derv_c_redu.squeeze(-2)
+            else:
+                model_predict["virial"] = derv_c_redu
+            if do_atomic_virial:
+                derv_c = model_ret.get("energy_derv_c")
+                if derv_c is not None:
+                    model_predict["extended_virial"] = derv_c.squeeze(-2)
+                else:
+                    model_predict["extended_virial"] = derv_c
+        else:
+            if model_ret.get("dforce") is not None:
+                model_predict["dforce"] = model_ret["dforce"]
+        return model_predict
+
+    forward_lower = call_lower
+
     @classmethod
     def update_sel(
         cls,
diff --git a/deepmd/dpmodel/model/ener_model.py b/deepmd/dpmodel/model/ener_model.py
index 9d38a17513..bda12d8f4b 100644
--- a/deepmd/dpmodel/model/ener_model.py
+++ b/deepmd/dpmodel/model/ener_model.py
@@ -6,6 +6,9 @@
     Any,
 )

+from deepmd.dpmodel.array_api import (
+    Array,
+)
 from deepmd.dpmodel.atomic_model import (
     DPEnergyAtomicModel,
 )
@@ -47,3 +50,125 @@ def atomic_output_def(self) -> FittingOutputDef:
         if self._enable_hessian:
             return self.hess_fitting_def
         return super().atomic_output_def()
+
+    def call_lower(
+        self,
+        extended_coord: Array,
+        extended_atype: Array,
+        nlist: Array,
+        mapping: Array | None = None,
+        fparam: Array | None = None,
+        aparam: Array | None = None,
+        do_atomic_virial: bool = False,
+    ) -> dict[str, Array]:
+        model_ret = self.call_common_lower(
+            extended_coord,
+            extended_atype,
+            nlist,
+            mapping,
+            fparam=fparam,
+            aparam=aparam,
+            do_atomic_virial=do_atomic_virial,
+        )
+        model_predict = {}
+        model_predict["atom_energy"] = model_ret["energy"]
+        model_predict["energy"] = model_ret["energy_redu"]
+        if self.do_grad_r("energy"):
+            if model_ret["energy_derv_r"] is not None:
+                model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2)
+            else:
+                model_predict["extended_force"] = model_ret["energy_derv_r"]
+        if self.do_grad_c("energy"):
+            derv_c_redu = model_ret.get("energy_derv_c_redu")
+            if derv_c_redu is not None:
+                model_predict["virial"] = derv_c_redu.squeeze(-2)
+            else:
+                model_predict["virial"] = derv_c_redu
+            if do_atomic_virial:
+                if model_ret["energy_derv_c"] is not None:
+                    model_predict["extended_virial"] = model_ret[
+                        "energy_derv_c"
+                    ].squeeze(-2)
+                else:
+                    model_predict["extended_virial"] = model_ret["energy_derv_c"]
+        else:
+            if model_ret.get("dforce") is not None:
+                model_predict["dforce"] = model_ret["dforce"]
+        if "mask" in model_ret:
+            model_predict["mask"] = model_ret["mask"]
+        if self._enable_hessian:
+            model_predict["hessian"] = model_ret["energy_derv_r_derv_r"]
+        return model_predict
+
+    def call(
+        self,
+        coord: Array,
+        atype: Array,
+        box: Array | None = None,
+        fparam: Array | None = None,
+        aparam: Array | None = None,
+        do_atomic_virial: bool = False,
+    ) -> dict[str, Array]:
+        model_ret = self.call_common(
+            coord,
+            atype,
+            box,
+            fparam=fparam,
+            aparam=aparam,
+            do_atomic_virial=do_atomic_virial,
+        )
+        model_predict = {}
+        model_predict["atom_energy"] = model_ret["energy"]
+        model_predict["energy"] = model_ret["energy_redu"]
+        if self.do_grad_r("energy"):
+            if model_ret.get("energy_derv_r") is not None:
+                model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
+            else:
+                model_predict["force"] = model_ret.get("energy_derv_r")
+        if self.do_grad_c("energy"):
+            derv_c_redu = model_ret.get("energy_derv_c_redu")
+            if derv_c_redu is not None:
+                model_predict["virial"] = derv_c_redu.squeeze(-2)
+            else:
+                model_predict["virial"] = derv_c_redu
+            if do_atomic_virial:
+                derv_c = model_ret.get("energy_derv_c")
+                if derv_c is not None:
+                    model_predict["atom_virial"] = derv_c.squeeze(-2)
+                else:
+                    model_predict["atom_virial"] = derv_c
+        else:
+            if model_ret.get("dforce") is not None:
+                model_predict["force"] = model_ret["dforce"]
+        if "mask" in model_ret:
+            model_predict["mask"] = model_ret["mask"]
+        if self._enable_hessian:
+            model_predict["hessian"] = model_ret["energy_derv_r_derv_r"]
+        return model_predict
+
+    forward_lower = call_lower
+
+    def translated_output_def(self) -> dict[str, Any]:
+        """Get the translated output definition.
+
+        Maps internal output names to user-facing names, e.g.
+        ``energy_redu`` -> ``energy``, ``energy_derv_r`` -> ``force``.
+ """ + out_def_data = self.model_output_def().get_data() + output_def = { + "atom_energy": out_def_data["energy"], + "energy": out_def_data["energy_redu"], + } + if self.do_grad_r("energy"): + output_def["force"] = out_def_data["energy_derv_r"] + output_def["force"].squeeze(-2) + if self.do_grad_c("energy"): + output_def["virial"] = out_def_data["energy_derv_c_redu"] + output_def["virial"].squeeze(-2) + output_def["atom_virial"] = out_def_data["energy_derv_c"] + output_def["atom_virial"].squeeze(-2) + if "mask" in out_def_data: + output_def["mask"] = out_def_data["mask"] + if self._enable_hessian: + output_def["hessian"] = out_def_data["energy_derv_r_derv_r"] + return output_def diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index cc9dd12fc5..886518be9d 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -21,6 +21,7 @@ PRECISION_DICT, RESERVED_PRECISION_DICT, NativeOP, + get_xp_precision, ) from deepmd.dpmodel.model.base_model import ( BaseModel, @@ -103,7 +104,8 @@ def model_call_from_call_lower( bb.reshape(nframes, 3, 3), ) else: - coord_normalized = cc.copy() + xp = array_api_compat.array_namespace(cc) + coord_normalized = xp.reshape(cc, (nframes, nloc, 3)) extended_coord, extended_atype, mapping = extend_coord_with_ghosts( coord_normalized, atype, bb, rcut ) @@ -221,7 +223,7 @@ def enable_compression( check_frequency, ) - def call( + def call_common( self, coord: Array, atype: Array, @@ -230,7 +232,7 @@ def call( aparam: Array | None = None, do_atomic_virial: bool = False, ) -> dict[str, Array]: - """Return model prediction. + """Return model prediction with raw internal keys. Parameters ---------- @@ -255,12 +257,12 @@ def call( The keys are defined by the `ModelOutputDef`. 
""" - cc, bb, fp, ap, input_prec = self.input_type_cast( + cc, bb, fp, ap, input_prec = self._input_type_cast( coord, box=box, fparam=fparam, aparam=aparam ) del coord, box, fparam, aparam model_predict = model_call_from_call_lower( - call_lower=self.call_lower, + call_lower=self.call_common_lower, rcut=self.get_rcut(), sel=self.get_sel(), mixed_types=self.mixed_types(), @@ -272,10 +274,10 @@ def call( aparam=ap, do_atomic_virial=do_atomic_virial, ) - model_predict = self.output_type_cast(model_predict, input_prec) + model_predict = self._output_type_cast(model_predict, input_prec) return model_predict - def call_lower( + def call_common_lower( self, extended_coord: Array, extended_atype: Array, @@ -321,7 +323,7 @@ def call_lower( nlist, extra_nlist_sort=self.need_sorted_nlist_for_lower(), ) - cc_ext, _, fp, ap, input_prec = self.input_type_cast( + cc_ext, _, fp, ap, input_prec = self._input_type_cast( extended_coord, fparam=fparam, aparam=aparam ) del extended_coord, fparam, aparam @@ -334,7 +336,7 @@ def call_lower( aparam=ap, do_atomic_virial=do_atomic_virial, ) - model_predict = self.output_type_cast(model_predict, input_prec) + model_predict = self._output_type_cast(model_predict, input_prec) return model_predict def forward_common_atomic( @@ -363,47 +365,93 @@ def forward_common_atomic( mask=atomic_ret["mask"] if "mask" in atomic_ret else None, ) - forward_lower = call_lower + call = call_common + call_lower = call_common_lower + forward_common = call_common + forward_common_lower = call_common_lower - def input_type_cast( + def get_out_bias(self) -> Array: + """Get the output bias.""" + return self.atomic_model.out_bias + + def set_out_bias(self, out_bias: Array) -> None: + """Set the output bias.""" + self.atomic_model.out_bias = out_bias + + def change_out_bias( + self, + merged: Any, + bias_adjust_mode: str = "change-by-statistic", + ) -> None: + """Change the output bias according to the input data and the pretrained model. + + Parameters + ---------- + merged + The merged data samples. + bias_adjust_mode : str + The mode for changing output bias: + 'change-by-statistic' or 'set-by-statistic'. 
+ """ + self.atomic_model.change_out_bias(merged, bias_adjust_mode=bias_adjust_mode) + + def _input_type_cast( self, coord: Array, box: Array | None = None, fparam: Array | None = None, aparam: Array | None = None, - ) -> tuple[Array, Array, np.ndarray | None, np.ndarray | None, str]: + ) -> tuple[Array, Array | None, Array | None, Array | None, Any]: """Cast the input data to global float type.""" - input_prec = RESERVED_PRECISION_DICT[self.precision_dict[coord.dtype.name]] + xp = array_api_compat.array_namespace(coord) + input_dtype = coord.dtype + global_dtype = get_xp_precision( + xp, RESERVED_PRECISION_DICT[self.global_np_float_precision] + ) ### ### type checking would not pass jit, convert to coord prec anyway ### - _lst: list[np.ndarray | None] = [ - vv.astype(coord.dtype) if vv is not None else None + _lst: list[Array | None] = [ + xp.astype(vv, input_dtype) if vv is not None else None for vv in [box, fparam, aparam] ] box, fparam, aparam = _lst - if input_prec == RESERVED_PRECISION_DICT[self.global_np_float_precision]: - return coord, box, fparam, aparam, input_prec + if input_dtype == global_dtype: + return coord, box, fparam, aparam, input_dtype else: - pp = self.global_np_float_precision return ( - coord.astype(pp), - box.astype(pp) if box is not None else None, - fparam.astype(pp) if fparam is not None else None, - aparam.astype(pp) if aparam is not None else None, - input_prec, + xp.astype(coord, global_dtype), + xp.astype(box, global_dtype) if box is not None else None, + xp.astype(fparam, global_dtype) if fparam is not None else None, + xp.astype(aparam, global_dtype) if aparam is not None else None, + input_dtype, ) - def output_type_cast( + def _output_type_cast( self, model_ret: dict[str, Array], - input_prec: str, + input_prec: Any, ) -> dict[str, Array]: - """Convert the model output to the input prec.""" - do_cast = ( - input_prec != RESERVED_PRECISION_DICT[self.global_np_float_precision] + """Convert the model output to the input prec. + + Parameters + ---------- + model_ret + The model output. + input_prec + The input dtype returned by ``_input_type_cast``. 
+ """ + model_ret_not_none = [vv for vv in model_ret.values() if vv is not None] + if not model_ret_not_none: + return model_ret + xp = array_api_compat.array_namespace(model_ret_not_none[0]) + global_dtype = get_xp_precision( + xp, RESERVED_PRECISION_DICT[self.global_np_float_precision] + ) + ener_dtype = get_xp_precision( + xp, RESERVED_PRECISION_DICT[self.global_ener_float_precision] ) - pp = self.precision_dict[input_prec] + do_cast = input_prec != global_dtype odef = self.model_output_def() for kk in odef.keys(): if kk not in model_ret.keys(): @@ -411,13 +459,15 @@ def output_type_cast( continue if check_operation_applied(odef[kk], OutputVariableOperation.REDU): model_ret[kk] = ( - model_ret[kk].astype(self.global_ener_float_precision) + xp.astype(model_ret[kk], ener_dtype) if model_ret[kk] is not None else None ) elif do_cast: model_ret[kk] = ( - model_ret[kk].astype(pp) if model_ret[kk] is not None else None + xp.astype(model_ret[kk], input_prec) + if model_ret[kk] is not None + else None ) return model_ret diff --git a/deepmd/dpmodel/model/polar_model.py b/deepmd/dpmodel/model/polar_model.py index b898eababd..e6a4cb304e 100644 --- a/deepmd/dpmodel/model/polar_model.py +++ b/deepmd/dpmodel/model/polar_model.py @@ -3,6 +3,9 @@ Any, ) +from deepmd.dpmodel.array_api import ( + Array, +) from deepmd.dpmodel.atomic_model import ( DPPolarAtomicModel, ) @@ -31,3 +34,70 @@ def __init__( ) -> None: DPModelCommon.__init__(self) DPPolarModel_.__init__(self, *args, **kwargs) + + def translated_output_def(self) -> dict[str, Any]: + out_def_data = self.model_output_def().get_data() + output_def = { + "polar": out_def_data["polarizability"], + "global_polar": out_def_data["polarizability_redu"], + } + if "mask" in out_def_data: + output_def["mask"] = out_def_data["mask"] + return output_def + + def call( + self, + coord: Array, + atype: Array, + box: Array | None = None, + fparam: Array | None = None, + aparam: Array | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, Array]: + model_ret = self.call_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + if self.get_fitting_net() is not None: + model_predict = {} + model_predict["polar"] = model_ret["polarizability"] + model_predict["global_polar"] = model_ret["polarizability_redu"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + else: + model_predict = model_ret + model_predict["updated_coord"] += coord + return model_predict + + def call_lower( + self, + extended_coord: Array, + extended_atype: Array, + nlist: Array, + mapping: Array | None = None, + fparam: Array | None = None, + aparam: Array | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, Array]: + model_ret = self.call_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + if self.get_fitting_net() is not None: + model_predict = {} + model_predict["polar"] = model_ret["polarizability"] + model_predict["global_polar"] = model_ret["polarizability_redu"] + else: + model_predict = model_ret + return model_predict + + forward_lower = call_lower diff --git a/deepmd/dpmodel/model/property_model.py b/deepmd/dpmodel/model/property_model.py index 20c884cd20..8324561045 100644 --- a/deepmd/dpmodel/model/property_model.py +++ b/deepmd/dpmodel/model/property_model.py @@ -3,6 +3,9 @@ Any, ) +from deepmd.dpmodel.array_api import ( + Array, +) from deepmd.dpmodel.atomic_model import ( DPPropertyAtomicModel, ) 
diff --git a/deepmd/dpmodel/model/polar_model.py b/deepmd/dpmodel/model/polar_model.py
index b898eababd..e6a4cb304e 100644
--- a/deepmd/dpmodel/model/polar_model.py
+++ b/deepmd/dpmodel/model/polar_model.py
@@ -3,6 +3,9 @@
     Any,
 )

+from deepmd.dpmodel.array_api import (
+    Array,
+)
 from deepmd.dpmodel.atomic_model import (
     DPPolarAtomicModel,
 )
@@ -31,3 +34,70 @@ def __init__(
     ) -> None:
         DPModelCommon.__init__(self)
         DPPolarModel_.__init__(self, *args, **kwargs)
+
+    def translated_output_def(self) -> dict[str, Any]:
+        out_def_data = self.model_output_def().get_data()
+        output_def = {
+            "polar": out_def_data["polarizability"],
+            "global_polar": out_def_data["polarizability_redu"],
+        }
+        if "mask" in out_def_data:
+            output_def["mask"] = out_def_data["mask"]
+        return output_def
+
+    def call(
+        self,
+        coord: Array,
+        atype: Array,
+        box: Array | None = None,
+        fparam: Array | None = None,
+        aparam: Array | None = None,
+        do_atomic_virial: bool = False,
+    ) -> dict[str, Array]:
+        model_ret = self.call_common(
+            coord,
+            atype,
+            box,
+            fparam=fparam,
+            aparam=aparam,
+            do_atomic_virial=do_atomic_virial,
+        )
+        if self.get_fitting_net() is not None:
+            model_predict = {}
+            model_predict["polar"] = model_ret["polarizability"]
+            model_predict["global_polar"] = model_ret["polarizability_redu"]
+            if "mask" in model_ret:
+                model_predict["mask"] = model_ret["mask"]
+        else:
+            model_predict = model_ret
+            model_predict["updated_coord"] += coord
+        return model_predict
+
+    def call_lower(
+        self,
+        extended_coord: Array,
+        extended_atype: Array,
+        nlist: Array,
+        mapping: Array | None = None,
+        fparam: Array | None = None,
+        aparam: Array | None = None,
+        do_atomic_virial: bool = False,
+    ) -> dict[str, Array]:
+        model_ret = self.call_common_lower(
+            extended_coord,
+            extended_atype,
+            nlist,
+            mapping,
+            fparam=fparam,
+            aparam=aparam,
+            do_atomic_virial=do_atomic_virial,
+        )
+        if self.get_fitting_net() is not None:
+            model_predict = {}
+            model_predict["polar"] = model_ret["polarizability"]
+            model_predict["global_polar"] = model_ret["polarizability_redu"]
+        else:
+            model_predict = model_ret
+        return model_predict
+
+    forward_lower = call_lower
diff --git a/deepmd/dpmodel/model/property_model.py b/deepmd/dpmodel/model/property_model.py
index 20c884cd20..8324561045 100644
--- a/deepmd/dpmodel/model/property_model.py
+++ b/deepmd/dpmodel/model/property_model.py
@@ -3,6 +3,9 @@
     Any,
 )

+from deepmd.dpmodel.array_api import (
+    Array,
+)
 from deepmd.dpmodel.atomic_model import (
     DPPropertyAtomicModel,
 )
@@ -33,3 +36,68 @@ def __init__(
     def get_var_name(self) -> str:
         """Get the name of the property."""
         return self.get_fitting_net().var_name
+
+    def translated_output_def(self) -> dict[str, Any]:
+        out_def_data = self.model_output_def().get_data()
+        var_name = self.get_var_name()
+        output_def = {
+            f"atom_{var_name}": out_def_data[var_name],
+            var_name: out_def_data[f"{var_name}_redu"],
+        }
+        if "mask" in out_def_data:
+            output_def["mask"] = out_def_data["mask"]
+        return output_def
+
+    def call(
+        self,
+        coord: Array,
+        atype: Array,
+        box: Array | None = None,
+        fparam: Array | None = None,
+        aparam: Array | None = None,
+        do_atomic_virial: bool = False,
+    ) -> dict[str, Array]:
+        model_ret = self.call_common(
+            coord,
+            atype,
+            box,
+            fparam=fparam,
+            aparam=aparam,
+            do_atomic_virial=do_atomic_virial,
+        )
+        var_name = self.get_var_name()
+        model_predict = {}
+        model_predict[f"atom_{var_name}"] = model_ret[var_name]
+        model_predict[var_name] = model_ret[f"{var_name}_redu"]
+        if "mask" in model_ret:
+            model_predict["mask"] = model_ret["mask"]
+        return model_predict
+
+    def call_lower(
+        self,
+        extended_coord: Array,
+        extended_atype: Array,
+        nlist: Array,
+        mapping: Array | None = None,
+        fparam: Array | None = None,
+        aparam: Array | None = None,
+        do_atomic_virial: bool = False,
+    ) -> dict[str, Array]:
+        model_ret = self.call_common_lower(
+            extended_coord,
+            extended_atype,
+            nlist,
+            mapping,
+            fparam=fparam,
+            aparam=aparam,
+            do_atomic_virial=do_atomic_virial,
+        )
+        var_name = self.get_var_name()
+        model_predict = {}
+        model_predict[f"atom_{var_name}"] = model_ret[var_name]
+        model_predict[var_name] = model_ret[f"{var_name}_redu"]
+        if "mask" in model_ret:
+            model_predict["mask"] = model_ret["mask"]
+        return model_predict
+
+    forward_lower = call_lower
diff --git a/deepmd/dpmodel/model/spin_model.py b/deepmd/dpmodel/model/spin_model.py
index 521978bdde..72bc9cdbbd 100644
--- a/deepmd/dpmodel/model/spin_model.py
+++ b/deepmd/dpmodel/model/spin_model.py
@@ -309,12 +309,19 @@ def model_output_def(self) -> ModelOutputDef:
                 backbone_model_atomic_output_def[var_name].magnetic = True
         return ModelOutputDef(backbone_model_atomic_output_def)

+    def translated_output_def(self) -> dict:
+        """Get the translated output definition.
+
+        SpinModel returns raw keys from call(), so translated_output_def
+        returns the raw model output definition.
+        """
+        return self.model_output_def().get_data()
+
     def __getattr__(self, name: str) -> Any:
         """Get attribute from the wrapped model."""
-        if name in self.__dict__:
-            return self.__dict__[name]
-        else:
-            return getattr(self.backbone_model, name)
+        if "backbone_model" not in self.__dict__:
+            raise AttributeError(name)
+        return getattr(self.backbone_model, name)
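The rewritten `__getattr__` fixes two problems: `name in self.__dict__` was dead code (Python only calls `__getattr__` after normal lookup fails), and delegating before `backbone_model` exists recurses infinitely during unpickling or `copy.deepcopy`, which populate `__dict__` without calling `__init__`. A standalone repro of the guard:

```python
import pickle

class Wrapper:
    def __init__(self, backbone):
        self.backbone_model = backbone

    def __getattr__(self, name):
        if "backbone_model" not in self.__dict__:  # the guard added here
            raise AttributeError(name)
        return getattr(self.backbone_model, name)

w = Wrapper(backbone=3.0)
# During unpickling, lookups like __setstate__ hit __getattr__ before
# __dict__ is restored; the guard raises AttributeError instead of recursing.
w2 = pickle.loads(pickle.dumps(w))
assert w2.real == 3.0  # delegation to the backbone still works
```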
+ """ + return self.model_output_def().get_data() + def __getattr__(self, name: str) -> Any: """Get attribute from the wrapped model.""" - if name in self.__dict__: - return self.__dict__[name] - else: - return getattr(self.backbone_model, name) + if "backbone_model" not in self.__dict__: + raise AttributeError(name) + return getattr(self.backbone_model, name) def serialize(self) -> dict: return { @@ -377,7 +384,7 @@ def call( coord_updated, atype_updated = self.process_spin_input(coord, atype, spin) if aparam is not None: aparam = self.expand_aparam(aparam, nloc * 2) - model_predict = self.backbone_model.call( + model_predict = self.backbone_model.call_common( coord_updated, atype_updated, box, @@ -447,7 +454,7 @@ def call_lower( ) if aparam is not None: aparam = self.expand_aparam(aparam, nloc * 2) - model_predict = self.backbone_model.call_lower( + model_predict = self.backbone_model.call_common_lower( extended_coord_updated, extended_atype_updated, nlist_updated, diff --git a/deepmd/dpmodel/model/transform_output.py b/deepmd/dpmodel/model/transform_output.py index 028f3e5f0f..b697898896 100644 --- a/deepmd/dpmodel/model/transform_output.py +++ b/deepmd/dpmodel/model/transform_output.py @@ -98,6 +98,7 @@ def communicate_extended_output( """ xp = array_api_compat.get_namespace(mapping) + device = array_api_compat.device(mapping) mapping_ = mapping new_ret = {} for kk in model_output_def.keys_outp(): @@ -117,7 +118,9 @@ def communicate_extended_output( mapping, tuple(mldims + [1] * len(derv_r_ext_dims)) ) mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims) - force = xp.zeros(vldims + derv_r_ext_dims, dtype=vv.dtype) + force = xp.zeros( + vldims + derv_r_ext_dims, dtype=vv.dtype, device=device + ) force = xp_scatter_sum( force, 1, @@ -149,7 +152,9 @@ def communicate_extended_output( nall = hess_1.shape[1] # (1) -> [nf, nloc1, nall2, *def, 3(1), 3(2)] hessian1 = xp.zeros( - [*vldims, nall, *vdef.shape, 3, 3], dtype=vv.dtype + [*vldims, nall, *vdef.shape, 3, 3], + dtype=vv.dtype, + device=device, ) mapping_hess = xp.reshape( mapping_, (mldims + [1] * (len(vdef.shape) + 3)) @@ -172,7 +177,9 @@ def communicate_extended_output( nloc = hessian1.shape[2] # (2) -> [nf, nloc2, nloc1, *def, 3(1), 3(2)] hessian = xp.zeros( - [*vldims, nloc, *vdef.shape, 3, 3], dtype=vv.dtype + [*vldims, nloc, *vdef.shape, 3, 3], + dtype=vv.dtype, + device=device, ) mapping_hess = xp.reshape( mapping_, (mldims + [1] * (len(vdef.shape) + 3)) @@ -218,6 +225,7 @@ def communicate_extended_output( virial = xp.zeros( vldims + derv_c_ext_dims, dtype=vv.dtype, + device=device, ) virial = xp_scatter_sum( virial, diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 785a42e248..b385ce6005 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -280,11 +280,11 @@ def call(self, x): # noqa: ANN001, ANN201 y = xp.astype(y, x.dtype) y = fn(y) if self.idt is not None: - y *= self.idt + y = y * self.idt if self.resnet and self.w.shape[1] == self.w.shape[0]: - y += x + y = y + x elif self.resnet and self.w.shape[1] == 2 * self.w.shape[0]: - y += xp.concat([x, x], axis=-1) + y = y + xp.concat([x, x], axis=-1) return y diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index 4008d75a53..184a1e6a3e 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -388,8 +388,6 @@ def _eval_model( results = [] for odef in request_defs: - # it seems not doing conversion - # dp_name = self._OUTDEF_DP2BACKEND[odef.name] dp_name = 
diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py
index 4008d75a53..184a1e6a3e 100644
--- a/deepmd/jax/infer/deep_eval.py
+++ b/deepmd/jax/infer/deep_eval.py
@@ -388,8 +388,6 @@ def _eval_model(
         results = []
         for odef in request_defs:
-            # it seems not doing conversion
-            # dp_name = self._OUTDEF_DP2BACKEND[odef.name]
             dp_name = odef.name
             if dp_name in batch_output:
                 shape = self._get_output_shape(odef, nframes, natoms)
diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py
index 31d0d7eb82..4881ca98f8 100644
--- a/deepmd/jax/jax2tf/serialization.py
+++ b/deepmd/jax/jax2tf/serialization.py
@@ -34,7 +34,7 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
     if model_file.endswith(".savedmodel"):
         model = BaseModel.deserialize(data["model"])
         model_def_script = data["model_def_script"]
-        call_lower = model.call_lower
+        call_lower = model.call_common_lower

         tf_model = tf.Module()
diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py
index 5d3432aab8..14386d9f3d 100644
--- a/deepmd/jax/utils/serialization.py
+++ b/deepmd/jax/utils/serialization.py
@@ -49,7 +49,7 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
     elif model_file.endswith(".hlo"):
         model = BaseModel.deserialize(data["model"])
         model_def_script = data["model_def_script"]
-        call_lower = model.call_lower
+        call_lower = model.call_common_lower

         nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost")
diff --git a/deepmd/pd/model/model/ener_model.py b/deepmd/pd/model/model/ener_model.py
index 3a57e79d3a..8813fe3e94 100644
--- a/deepmd/pd/model/model/ener_model.py
+++ b/deepmd/pd/model/model/ener_model.py
@@ -60,7 +60,7 @@ def translated_output_def(self) -> dict:
             output_def["virial"] = out_def_data["energy_derv_c_redu"]
             output_def["virial"].squeeze(-2)
             output_def["atom_virial"] = out_def_data["energy_derv_c"]
-            output_def["atom_virial"].squeeze(-3)
+            output_def["atom_virial"].squeeze(-2)
         if "mask" in out_def_data:
             output_def["mask"] = out_def_data["mask"]
         return output_def
@@ -92,7 +92,7 @@ def forward(
                 model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
                 if do_atomic_virial:
                     model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(
-                        -3
+                        -2
                     )
                 else:
                     model_predict["atom_virial"] = paddle.zeros(
@@ -140,7 +140,7 @@ def forward_lower(
                 if do_atomic_virial:
                     model_predict["extended_virial"] = model_ret[
                         "energy_derv_c"
-                    ].squeeze(-3)
+                    ].squeeze(-2)
                 else:
                     model_predict["extended_virial"] = paddle.zeros(
                         [model_predict["energy"].shape[0], 1, 9], dtype=paddle.float64
diff --git a/deepmd/pd/model/model/make_model.py b/deepmd/pd/model/model/make_model.py
index 72811c9e1c..321c939061 100644
--- a/deepmd/pd/model/model/make_model.py
+++ b/deepmd/pd/model/model/make_model.py
@@ -162,7 +162,7 @@ def forward_common(
             The keys are defined by the `ModelOutputDef`.
""" - cc, bb, fp, ap, input_prec = self.input_type_cast( + cc, bb, fp, ap, input_prec = self._input_type_cast( coord, box=box, fparam=fparam, aparam=aparam ) del coord, box, fparam, aparam @@ -196,7 +196,7 @@ def forward_common( mapping, do_atomic_virial=do_atomic_virial, ) - model_predict = self.output_type_cast(model_predict, input_prec) + model_predict = self._output_type_cast(model_predict, input_prec) return model_predict def get_out_bias(self) -> paddle.Tensor: @@ -283,7 +283,7 @@ def forward_common_lower( nlist = self.format_nlist( extended_coord, extended_atype, nlist, extra_nlist_sort=extra_nlist_sort ) - cc_ext, _, fp, ap, input_prec = self.input_type_cast( + cc_ext, _, fp, ap, input_prec = self._input_type_cast( extended_coord, fparam=fparam, aparam=aparam ) del extended_coord, fparam, aparam @@ -303,10 +303,10 @@ def forward_common_lower( do_atomic_virial=do_atomic_virial, create_graph=self.training, ) - model_predict = self.output_type_cast(model_predict, input_prec) + model_predict = self._output_type_cast(model_predict, input_prec) return model_predict - def input_type_cast( + def _input_type_cast( self, coord: paddle.Tensor, box: paddle.Tensor | None = None, @@ -351,7 +351,7 @@ def input_type_cast( input_prec, ) - def output_type_cast( + def _output_type_cast( self, model_ret: dict[str, paddle.Tensor], input_prec: str, diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index af2e8954df..a71427d5e9 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -83,6 +83,11 @@ def set_eval_descriptor_hook(self, enable: bool) -> None: def eval_descriptor(self) -> torch.Tensor: """Evaluate the descriptor.""" + if not self.eval_descriptor_list: + raise RuntimeError( + "eval_descriptor_list is empty. " + "Call set_eval_descriptor_hook(True) and perform a forward pass first." + ) return torch.concat(self.eval_descriptor_list) def set_eval_fitting_last_layer_hook(self, enable: bool) -> None: @@ -94,6 +99,11 @@ def set_eval_fitting_last_layer_hook(self, enable: bool) -> None: def eval_fitting_last_layer(self) -> torch.Tensor: """Evaluate the fitting last layer output.""" + if not self.eval_fitting_last_layer_list: + raise RuntimeError( + "eval_fitting_last_layer_list is empty. " + "Call set_eval_fitting_last_layer_hook(True) and perform a forward pass first." 
diff --git a/deepmd/pt/model/model/dipole_model.py b/deepmd/pt/model/model/dipole_model.py
index c6813ce079..0ea2a6544c 100644
--- a/deepmd/pt/model/model/dipole_model.py
+++ b/deepmd/pt/model/model/dipole_model.py
@@ -47,7 +47,7 @@ def translated_output_def(self) -> dict[str, Any]:
             output_def["virial"] = out_def_data["dipole_derv_c_redu"]
             output_def["virial"].squeeze(-2)
             output_def["atom_virial"] = out_def_data["dipole_derv_c"]
-            output_def["atom_virial"].squeeze(-3)
+            output_def["atom_virial"].squeeze(-2)
         if "mask" in out_def_data:
             output_def["mask"] = out_def_data["mask"]
         return output_def
@@ -79,7 +79,7 @@ def forward(
                 model_predict["virial"] = model_ret["dipole_derv_c_redu"].squeeze(-2)
                 if do_atomic_virial:
                     model_predict["atom_virial"] = model_ret["dipole_derv_c"].squeeze(
-                        -3
+                        -2
                     )
         if "mask" in model_ret:
             model_predict["mask"] = model_ret["mask"]
@@ -122,7 +122,7 @@ def forward_lower(
                 if do_atomic_virial:
                     model_predict["extended_virial"] = model_ret[
                         "dipole_derv_c"
-                    ].squeeze(-3)
+                    ].squeeze(-2)
         else:
             model_predict = model_ret
         return model_predict
diff --git a/deepmd/pt/model/model/dp_linear_model.py b/deepmd/pt/model/model/dp_linear_model.py
index b43f849258..fadcc71fdd 100644
--- a/deepmd/pt/model/model/dp_linear_model.py
+++ b/deepmd/pt/model/model/dp_linear_model.py
@@ -52,7 +52,7 @@ def translated_output_def(self) -> dict[str, OutputVariableDef]:
             output_def["virial"] = out_def_data["energy_derv_c_redu"]
             output_def["virial"].squeeze(-2)
             output_def["atom_virial"] = out_def_data["energy_derv_c"]
-            output_def["atom_virial"].squeeze(-3)
+            output_def["atom_virial"].squeeze(-2)
         if "mask" in out_def_data:
             output_def["mask"] = out_def_data["mask"]
         return output_def
@@ -83,7 +83,7 @@ def forward(
         if self.do_grad_c("energy"):
             model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
             if do_atomic_virial:
-                model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3)
+                model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-2)
         else:
             model_predict["force"] = model_ret["dforce"]
         if "mask" in model_ret:
@@ -123,7 +123,7 @@ def forward_lower(
             model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
             if do_atomic_virial:
                 model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze(
-                    -3
+                    -2
                 )
         else:
             assert model_ret["dforce"] is not None
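Why `-2` is the right axis: under the shape convention used in this PR's `take_deriv`, the atomic virial is laid out as `[nf, nall, *def, 9]` with `def == [1]` for energy, so the singleton sits at `-2`; `squeeze(-3)` targeted the `nall` axis and was a silent no-op whenever `nall != 1`. A quick check:

```python
import torch

nf, nall = 2, 5
energy_derv_c = torch.zeros(nf, nall, 1, 9)                 # [nf, nall, *def, 9]
assert energy_derv_c.squeeze(-2).shape == (nf, nall, 9)     # intended result
assert energy_derv_c.squeeze(-3).shape == (nf, nall, 1, 9)  # old code: no-op
```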
model_ret["energy_derv_c"].squeeze( - -3 + -2 ) else: assert model_ret["dforce"] is not None diff --git a/deepmd/pt/model/model/ener_model.py b/deepmd/pt/model/model/ener_model.py index 8f8a3cbad7..36beb33ff6 100644 --- a/deepmd/pt/model/model/ener_model.py +++ b/deepmd/pt/model/model/ener_model.py @@ -83,7 +83,7 @@ def translated_output_def(self) -> dict[str, Any]: output_def["virial"] = out_def_data["energy_derv_c_redu"] output_def["virial"].squeeze(-2) output_def["atom_virial"] = out_def_data["energy_derv_c"] - output_def["atom_virial"].squeeze(-3) + output_def["atom_virial"].squeeze(-2) if "mask" in out_def_data: output_def["mask"] = out_def_data["mask"] if self._hessian_enabled: @@ -117,7 +117,7 @@ def forward( model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze( - -3 + -2 ) else: model_predict["force"] = model_ret["dforce"] @@ -164,7 +164,7 @@ def forward_lower( if do_atomic_virial: model_predict["extended_virial"] = model_ret[ "energy_derv_c" - ].squeeze(-3) + ].squeeze(-2) else: assert model_ret["dforce"] is not None model_predict["dforce"] = model_ret["dforce"] diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index c958a62bf6..87a1d6b9c5 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -164,7 +164,7 @@ def forward_common( The keys are defined by the `ModelOutputDef`. """ - cc, bb, fp, ap, input_prec = self.input_type_cast( + cc, bb, fp, ap, input_prec = self._input_type_cast( coord, box=box, fparam=fparam, aparam=aparam ) del coord, box, fparam, aparam @@ -198,7 +198,7 @@ def forward_common( mapping, do_atomic_virial=do_atomic_virial, ) - model_predict = self.output_type_cast(model_predict, input_prec) + model_predict = self._output_type_cast(model_predict, input_prec) return model_predict def get_out_bias(self) -> torch.Tensor: @@ -285,7 +285,7 @@ def forward_common_lower( nlist = self.format_nlist( extended_coord, extended_atype, nlist, extra_nlist_sort=extra_nlist_sort ) - cc_ext, _, fp, ap, input_prec = self.input_type_cast( + cc_ext, _, fp, ap, input_prec = self._input_type_cast( extended_coord, fparam=fparam, aparam=aparam ) del extended_coord, fparam, aparam @@ -306,10 +306,10 @@ def forward_common_lower( create_graph=self.training, mask=atomic_ret["mask"] if "mask" in atomic_ret else None, ) - model_predict = self.output_type_cast(model_predict, input_prec) + model_predict = self._output_type_cast(model_predict, input_prec) return model_predict - def input_type_cast( + def _input_type_cast( self, coord: torch.Tensor, box: torch.Tensor | None = None, @@ -354,7 +354,7 @@ def input_type_cast( input_prec, ) - def output_type_cast( + def _output_type_cast( self, model_ret: dict[str, torch.Tensor], input_prec: str, diff --git a/deepmd/pt_expt/atomic_model/dp_atomic_model.py b/deepmd/pt_expt/atomic_model/dp_atomic_model.py index 4e7a178557..b87935bd09 100644 --- a/deepmd/pt_expt/atomic_model/dp_atomic_model.py +++ b/deepmd/pt_expt/atomic_model/dp_atomic_model.py @@ -1,14 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) import torch from deepmd.dpmodel.atomic_model.dp_atomic_model import DPAtomicModel as DPAtomicModelDP from deepmd.pt_expt.common import ( - dpmodel_setattr, register_dpmodel_mapping, + torch_module, ) from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, @@ -18,38 +15,11 @@ ) -class DPAtomicModel(DPAtomicModelDP, 
-class DPAtomicModel(DPAtomicModelDP, torch.nn.Module):
+@torch_module
+class DPAtomicModel(DPAtomicModelDP):
     base_descriptor_cls = BaseDescriptor
     base_fitting_cls = BaseFitting

-    def __init__(
-        self, descriptor: Any, fitting: Any, *args: Any, **kwargs: Any
-    ) -> None:
-        torch.nn.Module.__init__(self)
-        # Convert descriptor and fitting to pt_expt versions if they are dpmodel instances
-        # The dpmodel_setattr mechanism will handle this automatically via registry
-        from deepmd.pt_expt.common import (
-            try_convert_module,
-        )
-
-        descriptor_pt = try_convert_module(descriptor)
-        fitting_pt = try_convert_module(fitting)
-        # If conversion failed (not registered), use original (assume already pt_expt)
-        if descriptor_pt is None:
-            descriptor_pt = descriptor
-        if fitting_pt is None:
-            fitting_pt = fitting
-        DPAtomicModelDP.__init__(self, descriptor_pt, fitting_pt, *args, **kwargs)
-
-    def __call__(self, *args: Any, **kwargs: Any) -> Any:
-        # Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
-        return torch.nn.Module.__call__(self, *args, **kwargs)
-
-    def __setattr__(self, name: str, value: Any) -> None:
-        handled, value = dpmodel_setattr(self, name, value)
-        if not handled:
-            super().__setattr__(name, value)
-
     def forward(
         self,
         extended_coord: torch.Tensor,
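The `torch_module` decorator replaces the per-class boilerplate deleted above and below. Its implementation is not part of this diff; the following is only a plausible sketch of what it must provide, inferred from the removed code (everything here is an assumption about `deepmd.pt_expt.common`):

```python
import torch

def torch_module(cls):
    """Illustrative stand-in: mix torch.nn.Module into a dpmodel class."""
    class Wrapped(cls, torch.nn.Module):
        def __init__(self, *args, **kwargs):
            torch.nn.Module.__init__(self)       # set up Module state first
            cls.__init__(self, *args, **kwargs)  # then run the dpmodel __init__

        def __call__(self, *args, **kwargs):
            # Let torch.nn.Module.__call__ drive forward() for export/tracing.
            return torch.nn.Module.__call__(self, *args, **kwargs)

    Wrapped.__name__ = cls.__name__
    Wrapped.__qualname__ = cls.__qualname__
    return Wrapped
```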
diff --git a/deepmd/pt_expt/fitting/ener_fitting.py b/deepmd/pt_expt/fitting/ener_fitting.py
index 425040ae75..1c91f09526 100644
--- a/deepmd/pt_expt/fitting/ener_fitting.py
+++ b/deepmd/pt_expt/fitting/ener_fitting.py
@@ -1,17 +1,11 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
-from typing import (
-    Any,
-)

 import torch

 from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP
 from deepmd.pt_expt.common import (
-    dpmodel_setattr,
     register_dpmodel_mapping,
-)
-from deepmd.pt_expt.utils.network import (
-    NetworkCollection,
+    torch_module,
 )

 from .base_fitting import (
@@ -20,27 +14,13 @@


 @BaseFitting.register("ener")
-class EnergyFittingNet(EnergyFittingNetDP, torch.nn.Module):
+@torch_module
+class EnergyFittingNet(EnergyFittingNetDP):
     """Energy fitting net for pt_expt backend.

     This inherits from dpmodel EnergyFittingNet to get the correct
     serialize() method.
     """

-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        torch.nn.Module.__init__(self)
-        EnergyFittingNetDP.__init__(self, *args, **kwargs)
-        # Convert dpmodel NetworkCollection to pt_expt NetworkCollection
-        self.nets = NetworkCollection.deserialize(self.nets.serialize())
-
-    def __call__(self, *args: Any, **kwargs: Any) -> Any:
-        # Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
-        return torch.nn.Module.__call__(self, *args, **kwargs)
-
-    def __setattr__(self, name: str, value: Any) -> None:
-        handled, value = dpmodel_setattr(self, name, value)
-        if not handled:
-            super().__setattr__(name, value)
-
     def forward(
         self,
         descriptor: torch.Tensor,
diff --git a/deepmd/pt_expt/fitting/invar_fitting.py b/deepmd/pt_expt/fitting/invar_fitting.py
index aa37026284..640afe232e 100644
--- a/deepmd/pt_expt/fitting/invar_fitting.py
+++ b/deepmd/pt_expt/fitting/invar_fitting.py
@@ -1,40 +1,20 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
-from typing import (
-    Any,
-)

 import torch

 from deepmd.dpmodel.fitting.invar_fitting import InvarFitting as InvarFittingDP
 from deepmd.pt_expt.common import (
-    dpmodel_setattr,
     register_dpmodel_mapping,
+    torch_module,
 )
 from deepmd.pt_expt.fitting.base_fitting import (
     BaseFitting,
 )
-from deepmd.pt_expt.utils.network import (
-    NetworkCollection,
-)


 @BaseFitting.register("invar")
-class InvarFitting(InvarFittingDP, torch.nn.Module):
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        torch.nn.Module.__init__(self)
-        InvarFittingDP.__init__(self, *args, **kwargs)
-        # Convert dpmodel NetworkCollection to pt_expt NetworkCollection
-        self.nets = NetworkCollection.deserialize(self.nets.serialize())
-
-    def __call__(self, *args: Any, **kwargs: Any) -> Any:
-        # Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
-        return torch.nn.Module.__call__(self, *args, **kwargs)
-
-    def __setattr__(self, name: str, value: Any) -> None:
-        handled, value = dpmodel_setattr(self, name, value)
-        if not handled:
-            super().__setattr__(name, value)
-
+@torch_module
+class InvarFitting(InvarFittingDP):
     def forward(
         self,
         descriptor: torch.Tensor,
diff --git a/deepmd/pt_expt/model/__init__.py b/deepmd/pt_expt/model/__init__.py
new file mode 100644
index 0000000000..5d1c5ffb5d
--- /dev/null
+++ b/deepmd/pt_expt/model/__init__.py
@@ -0,0 +1,8 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from .ener_model import (
+    EnergyModel,
+)
+
+__all__ = [
+    "EnergyModel",
+]
diff --git a/deepmd/pt_expt/model/ener_model.py b/deepmd/pt_expt/model/ener_model.py
new file mode 100644
index 0000000000..5f30f3a227
--- /dev/null
+++ b/deepmd/pt_expt/model/ener_model.py
@@ -0,0 +1,171 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Any,
+)
+
+import torch
+from torch.fx.experimental.proxy_tensor import (
+    make_fx,
+)
+
+from deepmd.dpmodel.model.dp_model import (
+    DPModelCommon,
+)
+from deepmd.pt_expt.atomic_model import (
+    DPEnergyAtomicModel,
+)
+
+from .make_model import (
+    make_model,
+)
+
+DPEnergyModel_ = make_model(DPEnergyAtomicModel)
+
+
+class EnergyModel(DPModelCommon, DPEnergyModel_):
+    model_type = "ener"
+
+    def __init__(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        DPModelCommon.__init__(self)
+        DPEnergyModel_.__init__(self, *args, **kwargs)
+
+    def forward(
+        self,
+        coord: torch.Tensor,
+        atype: torch.Tensor,
+        box: torch.Tensor | None = None,
+        fparam: torch.Tensor | None = None,
+        aparam: torch.Tensor | None = None,
+        do_atomic_virial: bool = False,
+    ) -> dict[str, torch.Tensor]:
+        model_ret = self.call_common(
+            coord,
+            atype,
+            box,
+            fparam=fparam,
+            aparam=aparam,
+            do_atomic_virial=do_atomic_virial,
+        )
+        model_predict = {}
+        model_predict["atom_energy"] = model_ret["energy"]
+        model_predict["energy"] = model_ret["energy_redu"]
+        if self.do_grad_r("energy"):
+            model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
+        if self.do_grad_c("energy"):
model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-2) + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict + + def _forward_lower( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + model_ret = self.call_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad_r("energy"): + model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze( + -2 + ) + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict + + def forward_lower( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + return self._forward_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + + def forward_lower_exportable( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> torch.nn.Module: + """Trace ``_forward_lower`` into an exportable module. + + Uses ``make_fx`` to trace through ``torch.autograd.grad``, + decomposing the backward pass into primitive ops. The returned + module can be passed directly to ``torch.export.export``. + + Parameters + ---------- + extended_coord, extended_atype, nlist, mapping, fparam, aparam, do_atomic_virial + Sample inputs with representative shapes (used for tracing). + + Returns + ------- + torch.nn.Module + A traced module whose ``forward`` accepts + ``(extended_coord, extended_atype, nlist, mapping, fparam, aparam)`` + and returns a dict with the same keys as ``forward_lower``. 
+ """ + model = self + + def fn( + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, + ) -> dict[str, torch.Tensor]: + extended_coord = extended_coord.detach().requires_grad_(True) + return model._forward_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + + return make_fx(fn)( + extended_coord, extended_atype, nlist, mapping, fparam, aparam + ) diff --git a/deepmd/pt_expt/model/make_model.py b/deepmd/pt_expt/model/make_model.py new file mode 100644 index 0000000000..d26733696d --- /dev/null +++ b/deepmd/pt_expt/model/make_model.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.atomic_model.base_atomic_model import ( + BaseAtomicModel, +) +from deepmd.dpmodel.model.make_model import make_model as make_model_dp +from deepmd.pt_expt.common import ( + torch_module, +) + +from .transform_output import ( + fit_output_to_model_output, +) + + +def make_model(T_AtomicModel: type[BaseAtomicModel]) -> type: + """Make a model as a derived class of an atomic model. + + Wraps dpmodel's make_model with torch.nn.Module and overrides + forward_common_atomic to use autograd-based derivatives. + + Parameters + ---------- + T_AtomicModel + The atomic model. + + Returns + ------- + CM + The model. + + """ + DPModel = make_model_dp(T_AtomicModel) + + @torch_module + class CM(DPModel): + def forward(self, *args: Any, **kwargs: Any) -> dict[str, torch.Tensor]: + """Default forward delegates to call(). + + Subclasses (e.g. EnergyModel) override this with output translation. + """ + return self.call(*args, **kwargs) + + def forward_common_atomic( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + atomic_ret = self.atomic_model.forward_common_atomic( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) + return fit_output_to_model_output( + atomic_ret, + self.atomic_output_def(), + extended_coord, + do_atomic_virial=do_atomic_virial, + create_graph=self.training, + mask=atomic_ret.get("mask"), + ) + + return CM diff --git a/deepmd/pt_expt/model/transform_output.py b/deepmd/pt_expt/model/transform_output.py new file mode 100644 index 0000000000..5fb1ac4e46 --- /dev/null +++ b/deepmd/pt_expt/model/transform_output.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import torch + +from deepmd.dpmodel import ( + FittingOutputDef, + OutputVariableDef, + get_deriv_name, + get_reduce_name, +) +from deepmd.pt_expt.utils import ( + env, +) + + +def atomic_virial_corr( + extended_coord: torch.Tensor, + atom_energy: torch.Tensor, +) -> torch.Tensor: + nall = extended_coord.shape[1] + nf = extended_coord.shape[0] + nloc = atom_energy.shape[1] + coord, _ = torch.split(extended_coord, [nloc, nall - nloc], dim=1) + # no derivative with respect to the loc coord. 
diff --git a/deepmd/pt_expt/model/make_model.py b/deepmd/pt_expt/model/make_model.py
new file mode 100644
index 0000000000..d26733696d
--- /dev/null
+++ b/deepmd/pt_expt/model/make_model.py
@@ -0,0 +1,76 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Any,
+)
+
+import torch
+
+from deepmd.dpmodel.atomic_model.base_atomic_model import (
+    BaseAtomicModel,
+)
+from deepmd.dpmodel.model.make_model import make_model as make_model_dp
+from deepmd.pt_expt.common import (
+    torch_module,
+)
+
+from .transform_output import (
+    fit_output_to_model_output,
+)
+
+
+def make_model(T_AtomicModel: type[BaseAtomicModel]) -> type:
+    """Make a model as a derived class of an atomic model.
+
+    Wraps dpmodel's make_model with torch.nn.Module and overrides
+    forward_common_atomic to use autograd-based derivatives.
+
+    Parameters
+    ----------
+    T_AtomicModel
+        The atomic model.
+
+    Returns
+    -------
+    CM
+        The model.
+
+    """
+    DPModel = make_model_dp(T_AtomicModel)
+
+    @torch_module
+    class CM(DPModel):
+        def forward(self, *args: Any, **kwargs: Any) -> dict[str, torch.Tensor]:
+            """Default forward delegates to call().
+
+            Subclasses (e.g. EnergyModel) override this with output translation.
+            """
+            return self.call(*args, **kwargs)
+
+        def forward_common_atomic(
+            self,
+            extended_coord: torch.Tensor,
+            extended_atype: torch.Tensor,
+            nlist: torch.Tensor,
+            mapping: torch.Tensor | None = None,
+            fparam: torch.Tensor | None = None,
+            aparam: torch.Tensor | None = None,
+            do_atomic_virial: bool = False,
+        ) -> dict[str, torch.Tensor]:
+            atomic_ret = self.atomic_model.forward_common_atomic(
+                extended_coord,
+                extended_atype,
+                nlist,
+                mapping=mapping,
+                fparam=fparam,
+                aparam=aparam,
+            )
+            return fit_output_to_model_output(
+                atomic_ret,
+                self.atomic_output_def(),
+                extended_coord,
+                do_atomic_virial=do_atomic_virial,
+                create_graph=self.training,
+                mask=atomic_ret.get("mask"),
+            )
+
+    return CM
diff --git a/deepmd/pt_expt/model/transform_output.py b/deepmd/pt_expt/model/transform_output.py
new file mode 100644
index 0000000000..5fb1ac4e46
--- /dev/null
+++ b/deepmd/pt_expt/model/transform_output.py
@@ -0,0 +1,184 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+
+import torch
+
+from deepmd.dpmodel import (
+    FittingOutputDef,
+    OutputVariableDef,
+    get_deriv_name,
+    get_reduce_name,
+)
+from deepmd.pt_expt.utils import (
+    env,
+)
+
+
+def atomic_virial_corr(
+    extended_coord: torch.Tensor,
+    atom_energy: torch.Tensor,
+) -> torch.Tensor:
+    nall = extended_coord.shape[1]
+    nf = extended_coord.shape[0]
+    nloc = atom_energy.shape[1]
+    coord, _ = torch.split(extended_coord, [nloc, nall - nloc], dim=1)
+    # no derivative with respect to the loc coord.
+    coord = coord.detach()
+    ce = coord * atom_energy
+    sumce = torch.sum(ce, dim=1)  # [nf, 3]
+
+    # Use vmap to batch the 3 backward passes (one per spatial component)
+    basis = torch.eye(3, dtype=sumce.dtype, device=sumce.device)  # [3, 3]
+    basis = basis.unsqueeze(1).expand(3, nf, 3)  # [3, nf, 3]
+
+    def grad_fn(grad_output: torch.Tensor) -> torch.Tensor:
+        result = torch.autograd.grad(
+            [sumce],
+            [extended_coord],
+            grad_outputs=[grad_output],
+            create_graph=False,
+            retain_graph=True,
+        )[0]
+        assert result is not None
+        return result
+
+    # [3, nf, nall, 3] — batched over the 3 spatial components
+    extended_virial_corr = torch.vmap(grad_fn)(basis)
+    # [3, nf, nall, 3] -> [nf, nall, 3, 3]
+    return extended_virial_corr.permute(1, 2, 3, 0)
+
+
+def task_deriv_one(
+    atom_energy: torch.Tensor,
+    energy: torch.Tensor,
+    extended_coord: torch.Tensor,
+    do_virial: bool = True,
+    do_atomic_virial: bool = False,
+    create_graph: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor | None]:
+    faked_grad = torch.ones_like(energy)
+    lst: list[torch.Tensor | None] = [faked_grad]
+    extended_force = torch.autograd.grad(
+        [energy],
+        [extended_coord],
+        grad_outputs=lst,
+        create_graph=create_graph,
+        retain_graph=True,
+    )[0]
+    assert extended_force is not None
+    extended_force = -extended_force
+    if do_virial:
+        extended_virial = torch.einsum(
+            "...ik,...ij->...ikj", extended_force, extended_coord
+        )
+        # the correction sums to zero, which does not contribute to global virial
+        if do_atomic_virial:
+            extended_virial_corr = atomic_virial_corr(extended_coord, atom_energy)
+            extended_virial = extended_virial + extended_virial_corr
+        # to [...,3,3] -> [...,9]
+        extended_virial = extended_virial.view(list(extended_virial.shape[:-2]) + [9])  # noqa:RUF005
+    else:
+        extended_virial = None
+    return extended_force, extended_virial
+
+
+def get_leading_dims(
+    vv: torch.Tensor,
+    vdef: OutputVariableDef,
+) -> list[int]:
+    """Get the dimensions of nf x nloc."""
+    vshape = vv.shape
+    return list(vshape[: (len(vshape) - len(vdef.shape))])
+
+
+def take_deriv(
+    vv: torch.Tensor,
+    svv: torch.Tensor,
+    vdef: OutputVariableDef,
+    coord_ext: torch.Tensor,
+    do_virial: bool = False,
+    do_atomic_virial: bool = False,
+    create_graph: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor | None]:
+    size = 1
+    for ii in vdef.shape:
+        size *= ii
+    vv1 = vv.view(list(get_leading_dims(vv, vdef)) + [size])  # noqa: RUF005
+    svv1 = svv.view(list(get_leading_dims(svv, vdef)) + [size])  # noqa: RUF005
+    split_vv1 = torch.split(vv1, [1] * size, dim=-1)
+    split_svv1 = torch.split(svv1, [1] * size, dim=-1)
+    split_ff, split_avir = [], []
+    for vvi, svvi in zip(split_vv1, split_svv1):
+        # nf x nloc x 3, nf x nloc x 9
+        ffi, aviri = task_deriv_one(
+            vvi,
+            svvi,
+            coord_ext,
+            do_virial=do_virial,
+            do_atomic_virial=do_atomic_virial,
+            create_graph=create_graph,
+        )
+        # nf x nloc x 1 x 3, nf x nloc x 1 x 9
+        ffi = ffi.unsqueeze(-2)
+        split_ff.append(ffi)
+        if do_virial:
+            assert aviri is not None
+            aviri = aviri.unsqueeze(-2)
+            split_avir.append(aviri)
+    # nf x nall x v_dim x 3, nf x nall x v_dim x 9
+    out_lead_shape = list(coord_ext.shape[:-1]) + vdef.shape
+    ff = torch.concat(split_ff, dim=-2).view(out_lead_shape + [3])  # noqa: RUF005
+    if do_virial:
+        avir = torch.concat(split_avir, dim=-2).view(out_lead_shape + [9])  # noqa: RUF005
+    else:
+        avir = None
+    return ff, avir
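The vmap-over-`autograd.grad` trick in `atomic_virial_corr` above generalizes. A self-contained sketch of the pattern, outside any deepmd types:

```python
import torch

x = torch.randn(4, 3, requires_grad=True)
s = (x * x).sum(dim=0)  # [3], plays the role of sumce

# One one-hot grad_output per output component, batched instead of looped.
basis = torch.eye(3)

def grad_fn(g):
    return torch.autograd.grad([s], [x], grad_outputs=[g], retain_graph=True)[0]

jac = torch.vmap(grad_fn)(basis)  # [3, 4, 3]: d s[i] / d x, batched over i
assert jac.shape == (3, 4, 3)
```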
False, + create_graph: bool = True, + mask: torch.Tensor | None = None, +) -> dict[str, torch.Tensor]: + """Transform the output of the fitting network to + the model output. + + """ + redu_prec = env.GLOBAL_PT_ENER_FLOAT_PRECISION + model_ret = dict(fit_ret.items()) + for kk, vv in fit_ret.items(): + vdef = fit_output_def[kk] + shap = vdef.shape + atom_axis = -(len(shap) + 1) + if vdef.reducible: + kk_redu = get_reduce_name(kk) + if vdef.intensive: + if mask is not None: + model_ret[kk_redu] = torch.sum( + vv.to(redu_prec), dim=atom_axis + ) / torch.sum(mask, dim=-1, keepdim=True) + else: + model_ret[kk_redu] = torch.mean(vv.to(redu_prec), dim=atom_axis) + else: + model_ret[kk_redu] = torch.sum(vv.to(redu_prec), dim=atom_axis) + if vdef.r_differentiable: + kk_derv_r, kk_derv_c = get_deriv_name(kk) + dr, dc = take_deriv( + vv, + model_ret[kk_redu], + vdef, + coord_ext, + do_virial=vdef.c_differentiable, + do_atomic_virial=do_atomic_virial, + create_graph=create_graph, + ) + model_ret[kk_derv_r] = dr + if vdef.c_differentiable: + assert dc is not None + model_ret[kk_derv_c] = dc + model_ret[kk_derv_c + "_redu"] = torch.sum( + model_ret[kk_derv_c].to(redu_prec), dim=1 + ) + return model_ret diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 1611ab53d2..929907c2f3 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -26,6 +26,21 @@ class TorchArrayParam(torch.nn.Parameter): + """Parameter subclass that supports ``np.array(param)`` conversion. + + Note: this class is intentionally NOT used for model parameters. + ``make_fx`` (``torch.fx.experimental.proxy_tensor``) uses + ``ProxyTorchDispatchMode`` to intercept tensor operations. When an + operand is a *subclass* of ``torch.Tensor`` (including subclasses of + ``torch.nn.Parameter``), PyTorch invokes the ``__torch_function__`` + protocol which the proxy dispatch mode does not handle, causing + ``aten.mm`` and other ops to fail with "Multiple dispatch failed … + returned NotImplemented". Using plain ``torch.nn.Parameter`` avoids + this because the proxy mode is designed to work with the base + ``Parameter`` type. ``TorchArrayParam`` is kept only for backward + compatibility and should not be used for new code. + """ + def __new__( # noqa: PYI034 cls, data: Any = None, requires_grad: bool = True ) -> "TorchArrayParam": @@ -40,6 +55,31 @@ def __array__(self, dtype: Any | None = None) -> np.ndarray: # do not apply torch_module until its setattr working to register parameters class NativeLayer(NativeLayerDP, torch.nn.Module): + """PyTorch layer wrapping dpmodel's ``NativeLayer``. + + Two aspects of the inherited dpmodel ``call()`` are incompatible with + ``make_fx`` tracing (used to export ``forward_lower`` with + ``autograd.grad``-based force/virial computation): + + 1. **Ellipsis indexing** (``self.w[...]``): On a ``torch.Tensor`` + this triggers ``aten.alias``, an op that ``ProxyTorchDispatchMode`` + does not support, resulting in "Multiple dispatch failed for + ``aten.alias.default``". + 2. **``array_api_compat`` wrappers** (``xp = array_api_compat + .array_namespace(x); xp.matmul(…)``): The wrappers re-enter + ``torch.matmul`` through Python, which goes through the + ``__torch_function__`` protocol. Under the proxy dispatch mode + this path also fails with "Multiple dispatch failed". + + This class therefore overrides ``call()`` with an implementation that + uses plain ``torch`` ops exclusively (``torch.matmul``, ``torch.tanh``, + etc.), avoiding both issues. 
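+
+    A minimal sketch of the first failure mode (illustrative only; this
+    repro is not part of deepmd, and the exact error text varies with the
+    PyTorch version)::
+
+        import torch
+        from torch.fx.experimental.proxy_tensor import make_fx
+
+        w = torch.nn.Parameter(torch.ones(2, 2))
+
+        def f(x):
+            # ellipsis indexing takes a view -> aten.alias under tracing
+            return torch.matmul(x, w[...])
+
+        make_fx(f)(torch.ones(1, 2))  # may fail: "Multiple dispatch failed"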
+ + Trainable weights are stored as plain ``torch.nn.Parameter`` (not + ``TorchArrayParam``) for the same ``make_fx`` compatibility reason — + see the ``TorchArrayParam`` docstring. + """ + def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) NativeLayerDP.__init__(self, *args, **kwargs) @@ -61,8 +101,8 @@ def __setattr__(self, name: str, value: Any) -> None: if getattr(self, "trainable", False): param = ( value - if isinstance(value, TorchArrayParam) - else TorchArrayParam(val, requires_grad=True) + if isinstance(value, torch.nn.Parameter) + else torch.nn.Parameter(val, requires_grad=True) ) if name in self._parameters: self._parameters[name] = param @@ -76,10 +116,78 @@ def __setattr__(self, name: str, value: Any) -> None: return return super().__setattr__(name, value) + def call(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass using pure torch ops. + + Overrides dpmodel's ``call()`` to ensure compatibility with + ``make_fx`` (``torch.fx.experimental.proxy_tensor``). + + The dpmodel implementation uses ``self.w[...]`` and + ``array_api_compat.array_namespace(x).matmul(…)`` for + backend-agnostic array operations. Both patterns break under + ``make_fx``'s ``ProxyTorchDispatchMode``: + + - ``self.w[...]`` emits ``aten.alias`` which the proxy mode + cannot dispatch. + - ``array_api_compat`` re-enters ``torch.matmul`` via Python, + hitting ``__torch_function__`` which the proxy mode returns + ``NotImplemented`` for. + + This override uses ``torch.matmul``, ``torch.cat``, and + ``_torch_activation`` directly, sidestepping both issues. + """ + if self.w is None or self.activation_function is None: + raise ValueError("w, b, and activation_function must be set") + y = ( + torch.matmul(x, self.w) + self.b + if self.b is not None + else torch.matmul(x, self.w) + ) + if y.dtype != x.dtype: + y = y.to(x.dtype) + y = _torch_activation(y, self.activation_function) + if self.idt is not None: + y = y * self.idt + if self.resnet and self.w.shape[1] == self.w.shape[0]: + y = y + x + elif self.resnet and self.w.shape[1] == 2 * self.w.shape[0]: + y = y + torch.cat([x, x], dim=-1) + return y + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.call(x) +def _torch_activation(x: torch.Tensor, name: str) -> torch.Tensor: + """Apply activation function using native torch ops. + + The dpmodel ``get_activation_fn`` returns closures that call + ``array_api_compat.array_namespace(x).tanh(x)`` etc. Under + ``make_fx`` proxy tracing, the ``array_api_compat`` indirection + triggers ``__torch_function__`` dispatch failures. This function + calls ``torch.tanh`` and friends directly to avoid the issue. 
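+
+    A usage sketch (illustrative; both names are handled by the mapping
+    below)::
+
+        _torch_activation(torch.zeros(2), "gelu")    # tanh-approximated GELU
+        _torch_activation(torch.zeros(2), "linear")  # identity passthrough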
+ """ + name = name.lower() + if name == "tanh": + return torch.tanh(x) + elif name == "relu": + return torch.relu(x) + elif name in ("gelu", "gelu_tf"): + return torch.nn.functional.gelu(x, approximate="tanh") + elif name == "relu6": + return torch.clamp(x, min=0.0, max=6.0) + elif name == "softplus": + return torch.nn.functional.softplus(x) + elif name == "sigmoid": + return torch.sigmoid(x) + elif name == "silu": + return torch.nn.functional.silu(x) + elif name in ("none", "linear"): + return x + else: + raise NotImplementedError(name) + + @torch_module class NativeNet(make_multilayer_network(NativeLayer, NativeOP)): def __init__(self, layers: list[dict] | None = None) -> None: diff --git a/source/tests/common/dpmodel/test_dp_model.py b/source/tests/common/dpmodel/test_dp_model.py index af4eea624d..ade365188d 100644 --- a/source/tests/common/dpmodel/test_dp_model.py +++ b/source/tests/common/dpmodel/test_dp_model.py @@ -46,8 +46,8 @@ def test_self_consistency( md0 = EnergyModel(ds, ft, type_map=type_map) md1 = EnergyModel.deserialize(md0.serialize()) - ret0 = md0.call_lower(self.coord_ext, self.atype_ext, self.nlist) - ret1 = md1.call_lower(self.coord_ext, self.atype_ext, self.nlist) + ret0 = md0.call_common_lower(self.coord_ext, self.atype_ext, self.nlist) + ret1 = md1.call_common_lower(self.coord_ext, self.atype_ext, self.nlist) np.testing.assert_allclose(ret0["energy"], ret1["energy"]) np.testing.assert_allclose(ret0["energy_redu"], ret1["energy_redu"]) @@ -80,8 +80,8 @@ def test_prec_consistency(self) -> None: args32 = [self.coord_ext, self.atype_ext, self.nlist] args32[0] = args32[0].astype(np.float32) - model_l_ret_64 = md1.call_lower(*args64, fparam=fparam, aparam=aparam) - model_l_ret_32 = md1.call_lower(*args32, fparam=fparam, aparam=aparam) + model_l_ret_64 = md1.call_common_lower(*args64, fparam=fparam, aparam=aparam) + model_l_ret_32 = md1.call_common_lower(*args32, fparam=fparam, aparam=aparam) for ii in model_l_ret_32.keys(): if model_l_ret_32[ii] is None: @@ -134,8 +134,8 @@ def test_prec_consistency(self) -> None: args32[0] = args32[0].astype(np.float32) args32[2] = args32[2].astype(np.float32) - model_l_ret_64 = md1.call(*args64, fparam=fparam, aparam=aparam) - model_l_ret_32 = md1.call(*args32, fparam=fparam, aparam=aparam) + model_l_ret_64 = md1.call_common(*args64, fparam=fparam, aparam=aparam) + model_l_ret_32 = md1.call_common(*args32, fparam=fparam, aparam=aparam) for ii in model_l_ret_32.keys(): if model_l_ret_32[ii] is None: diff --git a/source/tests/common/dpmodel/test_padding_atoms.py b/source/tests/common/dpmodel/test_padding_atoms.py index d4ea39f598..29e34c09a9 100644 --- a/source/tests/common/dpmodel/test_padding_atoms.py +++ b/source/tests/common/dpmodel/test_padding_atoms.py @@ -69,8 +69,8 @@ def test_padding_atoms_consistency(self): result = model.call(*args) # test intensive np.testing.assert_allclose( - result[f"{var_name}_redu"], - np.mean(result[f"{var_name}"], axis=1), + result[var_name], + np.mean(result[f"atom_{var_name}"], axis=1), atol=self.atol, ) # test padding atoms @@ -93,8 +93,8 @@ def test_padding_atoms_consistency(self): args = [coord_padding, atype_padding, self.cell] result_padding = model.call(*args) np.testing.assert_allclose( - result[f"{var_name}_redu"], - result_padding[f"{var_name}_redu"], + result[var_name], + result_padding[var_name], atol=self.atol, ) diff --git a/source/tests/consistent/model/common.py b/source/tests/consistent/model/common.py index 778ae519c6..04966c02a1 100644 --- a/source/tests/consistent/model/common.py 
+++ b/source/tests/consistent/model/common.py @@ -14,6 +14,7 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, ) @@ -30,6 +31,8 @@ from deepmd.jax.env import ( jnp, ) +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.common import to_torch_array as pt_expt_numpy_to_torch if INSTALLED_PD: from deepmd.pd.utils.utils import to_numpy_array as paddle_to_numpy from deepmd.pd.utils.utils import to_paddle_tensor as numpy_to_paddle @@ -104,6 +107,19 @@ def eval_pt_model(self, pt_obj: Any, natoms, coords, atype, box) -> Any: ).items() } + def eval_pt_expt_model(self, pt_expt_obj: Any, natoms, coords, atype, box) -> Any: + coord_tensor = pt_expt_numpy_to_torch(coords) + coord_tensor.requires_grad_(True) + return { + kk: vv.detach().cpu().numpy() + for kk, vv in pt_expt_obj( + coord_tensor, + pt_expt_numpy_to_torch(atype), + box=pt_expt_numpy_to_torch(box), + do_atomic_virial=True, + ).items() + } + def eval_jax_model(self, jax_obj: Any, natoms, coords, atype, box) -> Any: def assert_jax_array(arr): assert isinstance(arr, jnp.ndarray) or arr is None diff --git a/source/tests/consistent/model/test_dipole.py b/source/tests/consistent/model/test_dipole.py index 339dcae7c3..5a57bf5904 100644 --- a/source/tests/consistent/model/test_dipole.py +++ b/source/tests/consistent/model/test_dipole.py @@ -190,7 +190,7 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... if backend in {self.RefBackend.DP, self.RefBackend.JAX}: return ( - ret["dipole_redu"].ravel(), + ret["global_dipole"].ravel(), ret["dipole"].ravel(), ) elif backend is self.RefBackend.PT: diff --git a/source/tests/consistent/model/test_dos.py b/source/tests/consistent/model/test_dos.py index ef72e9096b..c6d1fde425 100644 --- a/source/tests/consistent/model/test_dos.py +++ b/source/tests/consistent/model/test_dos.py @@ -184,8 +184,8 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... if backend in {self.RefBackend.DP, self.RefBackend.JAX}: return ( - ret["dos_redu"].ravel(), ret["dos"].ravel(), + ret["atom_dos"].ravel(), ) elif backend is self.RefBackend.PT: return ( diff --git a/source/tests/consistent/model/test_dpa1.py b/source/tests/consistent/model/test_dpa1.py index 8b8fab7ae1..83aa33a332 100644 --- a/source/tests/consistent/model/test_dpa1.py +++ b/source/tests/consistent/model/test_dpa1.py @@ -223,8 +223,8 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... 
if backend is self.RefBackend.DP: return ( - ret["energy_redu"].ravel(), ret["energy"].ravel(), + ret["atom_energy"].ravel(), SKIP_FLAG, SKIP_FLAG, SKIP_FLAG, @@ -255,10 +255,10 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ) elif backend is self.RefBackend.JAX: return ( - ret["energy_redu"].ravel(), ret["energy"].ravel(), - ret["energy_derv_r"].ravel(), - ret["energy_derv_c_redu"].ravel(), - ret["energy_derv_c"].ravel(), + ret["atom_energy"].ravel(), + ret["force"].ravel(), + ret["virial"].ravel(), + ret["atom_virial"].ravel(), ) raise ValueError(f"Unknown backend: {backend}") diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index d56b9a257b..29a84a9363 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -26,6 +26,7 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, SKIP_FLAG, CommonTest, @@ -53,6 +54,11 @@ from deepmd.pd.utils.utils import to_paddle_tensor as numpy_to_paddle else: EnergyModelPD = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.common import to_torch_array as pt_expt_numpy_to_torch + from deepmd.pt_expt.model import EnergyModel as EnergyModelPTExpt +else: + EnergyModelPTExpt = None from deepmd.utils.argcheck import ( model_args, ) @@ -115,8 +121,8 @@ def data(self) -> dict: dp_class = EnergyModelDP pt_class = EnergyModelPT pd_class = EnergyModelPD + pt_expt_class = EnergyModelPTExpt jax_class = EnergyModelJAX - pd_class = EnergyModelPD args = model_args() def get_reference_backend(self): @@ -128,6 +134,8 @@ def get_reference_backend(self): return self.RefBackend.PT if not self.skip_tf: return self.RefBackend.TF + if not self.skip_pt_expt and self.pt_expt_class is not None: + return self.RefBackend.PT_EXPT if not self.skip_jax: return self.RefBackend.JAX if not self.skip_pd: @@ -156,6 +164,9 @@ def pass_data_to_cls(self, cls, data) -> Any: model = get_model_pt(data) model.atomic_model.out_bias.uniform_() return model + elif cls is EnergyModelPTExpt: + dp_model = get_model_dp(data) + return EnergyModelPTExpt.deserialize(dp_model.serialize()) elif cls is EnergyModelJAX: return get_model_jax(data) elif cls is EnergyModelPD: @@ -229,6 +240,15 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_model( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def eval_jax(self, jax_obj: Any) -> Any: return self.eval_jax_model( jax_obj, @@ -251,8 +271,8 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... 
if backend is self.RefBackend.DP: return ( - ret["energy_redu"].ravel(), ret["energy"].ravel(), + ret["atom_energy"].ravel(), SKIP_FLAG, SKIP_FLAG, SKIP_FLAG, @@ -265,6 +285,14 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret["virial"].ravel(), ret["atom_virial"].ravel(), ) + elif backend is self.RefBackend.PT_EXPT: + return ( + ret["energy"].ravel(), + ret["atom_energy"].ravel(), + ret["force"].ravel(), + ret["virial"].ravel(), + ret["atom_virial"].ravel(), + ) elif backend is self.RefBackend.TF: return ( ret[0].ravel(), @@ -275,11 +303,11 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ) elif backend is self.RefBackend.JAX: return ( - ret["energy_redu"].ravel(), ret["energy"].ravel(), - ret["energy_derv_r"].ravel(), - ret["energy_derv_c_redu"].ravel(), - ret["energy_derv_c"].ravel(), + ret["atom_energy"].ravel(), + ret["force"].ravel(), + ret["virial"].ravel(), + ret["atom_virial"].ravel(), ) elif backend is self.RefBackend.PD: return ( @@ -339,6 +367,7 @@ def data(self) -> dict: tf_class = EnergyModelTF dp_class = EnergyModelDP pt_class = EnergyModelPT + pt_expt_class = EnergyModelPTExpt jax_class = EnergyModelJAX pd_class = EnergyModelPD args = model_args() @@ -350,6 +379,8 @@ def get_reference_backend(self): """ if not self.skip_pt: return self.RefBackend.PT + if not self.skip_pt_expt and self.pt_expt_class is not None: + return self.RefBackend.PT_EXPT if not self.skip_jax: return self.RefBackend.JAX if not self.skip_dp: @@ -374,6 +405,9 @@ def pass_data_to_cls(self, cls, data) -> Any: return get_model_dp(data) elif cls is EnergyModelPT: return get_model_pt(data) + elif cls is EnergyModelPTExpt: + dp_model = get_model_dp(data) + return EnergyModelPTExpt.deserialize(dp_model.serialize()) elif cls is EnergyModelJAX: return get_model_jax(data) elif cls is EnergyModelPD: @@ -460,6 +494,20 @@ def eval_pt(self, pt_obj: Any) -> Any: ).items() } + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + coord_tensor = pt_expt_numpy_to_torch(self.extended_coord) + coord_tensor.requires_grad_(True) + return { + kk: vv.detach().cpu().numpy() if vv is not None else None + for kk, vv in pt_expt_obj.forward_lower( + coord_tensor, + pt_expt_numpy_to_torch(self.extended_atype), + pt_expt_numpy_to_torch(self.nlist), + pt_expt_numpy_to_torch(self.mapping), + do_atomic_virial=True, + ).items() + } + def eval_jax(self, jax_obj: Any) -> Any: return { kk: to_numpy_array(vv) @@ -488,8 +536,8 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... 
if backend is self.RefBackend.DP: return ( - ret["energy_redu"].ravel(), ret["energy"].ravel(), + ret["atom_energy"].ravel(), SKIP_FLAG, SKIP_FLAG, SKIP_FLAG, @@ -502,13 +550,21 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret["virial"].ravel(), ret["extended_virial"].ravel(), ) + elif backend is self.RefBackend.PT_EXPT: + return ( + ret["energy"].ravel(), + ret["atom_energy"].ravel(), + ret["extended_force"].ravel(), + ret["virial"].ravel(), + ret["extended_virial"].ravel(), + ) elif backend is self.RefBackend.JAX: return ( - ret["energy_redu"].ravel(), ret["energy"].ravel(), - ret["energy_derv_r"].ravel(), - ret["energy_derv_c_redu"].ravel(), - ret["energy_derv_c"].ravel(), + ret["atom_energy"].ravel(), + ret["extended_force"].ravel(), + ret["virial"].ravel(), + ret["extended_virial"].ravel(), ) elif backend is self.RefBackend.PD: return ( @@ -519,3 +575,668 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret["extended_virial"].flatten(), ) raise ValueError(f"Unknown backend: {backend}") + + +@unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed") +class TestEnerModelAPIs(unittest.TestCase): + """Test consistency of model-level APIs between pt and dpmodel backends. + + Both models are constructed from the same serialized weights + (dpmodel -> serialize -> pt deserialize) so that numerical outputs + can be compared directly. + """ + + def setUp(self) -> None: + from deepmd.utils.argcheck import ( + model_args, + ) + + data = model_args().normalize_value( + { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "sel": [20, 20], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [3, 6], + "resnet_dt": False, + "axis_neuron": 2, + "precision": "float64", + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [5, 5], + "resnet_dt": True, + "precision": "float64", + "seed": 1, + }, + }, + trim_pattern="_*", + ) + # Build dpmodel first, then deserialize into pt to share weights + self.dp_model = get_model_dp(data) + serialized = self.dp_model.serialize() + self.pt_model = EnergyModelPT.deserialize(serialized) + + # Coords / atype / box + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, 9) + + # Build extended coords + nlist for lower-level calls + rcut = 6.0 + nframes, nloc = self.atype.shape[:2] + coord_normalized = normalize_coord( + self.coords.reshape(nframes, nloc, 3), + self.box.reshape(nframes, 3, 3), + ) + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, self.atype, self.box, rcut + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + rcut, + [20, 20], + distinguish_types=True, + ) + self.extended_coord = extended_coord.reshape(nframes, -1, 3) + self.extended_atype = extended_atype + self.mapping = mapping + self.nlist = nlist + + def test_translated_output_def(self) -> None: + """translated_output_def should return the same keys on dp and pt.""" + dp_def = self.dp_model.translated_output_def() + pt_def = self.pt_model.translated_output_def() + self.assertEqual(set(dp_def.keys()), set(pt_def.keys())) + for key in dp_def: + 
self.assertEqual(dp_def[key].shape, pt_def[key].shape) + + def test_get_descriptor(self) -> None: + """get_descriptor should return a non-None object on both backends.""" + self.assertIsNotNone(self.dp_model.get_descriptor()) + self.assertIsNotNone(self.pt_model.get_descriptor()) + + def test_get_fitting_net(self) -> None: + """get_fitting_net should return a non-None object on both backends.""" + self.assertIsNotNone(self.dp_model.get_fitting_net()) + self.assertIsNotNone(self.pt_model.get_fitting_net()) + + def test_get_out_bias(self) -> None: + """get_out_bias should return numerically equal values on dp and pt. + + Freshly constructed models have zero bias; the shape (n_output x ntypes x odim) + is verified. Non-zero bias round-trip is covered by test_set_out_bias. + """ + dp_bias = to_numpy_array(self.dp_model.get_out_bias()) + pt_bias = torch_to_numpy(self.pt_model.get_out_bias()) + np.testing.assert_allclose(dp_bias, pt_bias, rtol=1e-10, atol=1e-10) + # Verify shape is sensible (n_output_keys x ntypes x odim) + self.assertEqual(dp_bias.shape[1], 2) # ntypes + self.assertGreater(dp_bias.shape[0], 0) # at least one output key + + def test_set_out_bias(self) -> None: + """set_out_bias should update the bias on both backends.""" + dp_bias = to_numpy_array(self.dp_model.get_out_bias()) + new_bias = dp_bias + 1.0 + # dp + self.dp_model.set_out_bias(new_bias) + np.testing.assert_allclose( + to_numpy_array(self.dp_model.get_out_bias()), + new_bias, + rtol=1e-10, + atol=1e-10, + ) + # pt + self.pt_model.set_out_bias(numpy_to_torch(new_bias)) + np.testing.assert_allclose( + torch_to_numpy(self.pt_model.get_out_bias()), + new_bias, + rtol=1e-10, + atol=1e-10, + ) + + def test_forward_common_alias(self) -> None: + """forward_common should be the same as call_common on dpmodel.""" + ret_call = self.dp_model.call_common( + self.coords, + self.atype, + box=self.box, + ) + ret_fc = self.dp_model.forward_common( + self.coords, + self.atype, + box=self.box, + ) + for key in ret_call: + np.testing.assert_equal(ret_call[key], ret_fc[key]) + + def test_forward_common_lower_alias(self) -> None: + """forward_common_lower should be the same as call_common_lower on dpmodel.""" + ret_call = self.dp_model.call_common_lower( + self.extended_coord, + self.extended_atype, + self.nlist, + self.mapping, + ) + ret_fc = self.dp_model.forward_common_lower( + self.extended_coord, + self.extended_atype, + self.nlist, + self.mapping, + ) + for key in ret_call: + np.testing.assert_equal(ret_call[key], ret_fc[key]) + + def test_model_output_def(self) -> None: + """model_output_def should return the same keys and shapes on dp and pt.""" + dp_def = self.dp_model.model_output_def().get_data() + pt_def = self.pt_model.model_output_def().get_data() + self.assertEqual(set(dp_def.keys()), set(pt_def.keys())) + for key in dp_def: + self.assertEqual(dp_def[key].shape, pt_def[key].shape) + + def test_model_output_type(self) -> None: + """model_output_type should return the same list on dp and pt.""" + self.assertEqual( + self.dp_model.model_output_type(), + self.pt_model.model_output_type(), + ) + + def test_do_grad_r(self) -> None: + """do_grad_r should return the same value on dp and pt.""" + self.assertEqual( + self.dp_model.do_grad_r("energy"), + self.pt_model.do_grad_r("energy"), + ) + self.assertTrue(self.dp_model.do_grad_r("energy")) + + def test_do_grad_c(self) -> None: + """do_grad_c should return the same value on dp and pt.""" + self.assertEqual( + self.dp_model.do_grad_c("energy"), + self.pt_model.do_grad_c("energy"), + 
)
+        self.assertTrue(self.dp_model.do_grad_c("energy"))
+
+    def test_get_rcut(self) -> None:
+        """get_rcut should return the same value on dp and pt."""
+        self.assertEqual(self.dp_model.get_rcut(), self.pt_model.get_rcut())
+        self.assertAlmostEqual(self.dp_model.get_rcut(), 6.0)
+
+    def test_get_type_map(self) -> None:
+        """get_type_map should return the same list on dp and pt."""
+        self.assertEqual(self.dp_model.get_type_map(), self.pt_model.get_type_map())
+        self.assertEqual(self.dp_model.get_type_map(), ["O", "H"])
+
+    def test_get_sel(self) -> None:
+        """get_sel should return the same list on dp and pt."""
+        self.assertEqual(self.dp_model.get_sel(), self.pt_model.get_sel())
+        self.assertEqual(self.dp_model.get_sel(), [20, 20])
+
+    def test_get_nsel(self) -> None:
+        """get_nsel should return the same value on dp and pt."""
+        self.assertEqual(self.dp_model.get_nsel(), self.pt_model.get_nsel())
+        self.assertEqual(self.dp_model.get_nsel(), 40)
+
+    def test_get_nnei(self) -> None:
+        """get_nnei should return the same value on dp and pt."""
+        self.assertEqual(self.dp_model.get_nnei(), self.pt_model.get_nnei())
+        self.assertEqual(self.dp_model.get_nnei(), 40)
+
+    def test_mixed_types(self) -> None:
+        """mixed_types should return the same value on dp and pt."""
+        self.assertEqual(self.dp_model.mixed_types(), self.pt_model.mixed_types())
+        # se_e2_a is not mixed-types
+        self.assertFalse(self.dp_model.mixed_types())
+
+    def test_has_message_passing(self) -> None:
+        """has_message_passing should return the same value on dp and pt."""
+        self.assertEqual(
+            self.dp_model.has_message_passing(),
+            self.pt_model.has_message_passing(),
+        )
+        self.assertFalse(self.dp_model.has_message_passing())
+
+    def test_need_sorted_nlist_for_lower(self) -> None:
+        """need_sorted_nlist_for_lower should return the same value on dp and pt."""
+        self.assertEqual(
+            self.dp_model.need_sorted_nlist_for_lower(),
+            self.pt_model.need_sorted_nlist_for_lower(),
+        )
+        self.assertFalse(self.dp_model.need_sorted_nlist_for_lower())
+
+    def test_get_dim_fparam(self) -> None:
+        """get_dim_fparam should return the same value on dp and pt."""
+        self.assertEqual(self.dp_model.get_dim_fparam(), self.pt_model.get_dim_fparam())
+        self.assertEqual(self.dp_model.get_dim_fparam(), 0)
+
+    def test_get_dim_aparam(self) -> None:
+        """get_dim_aparam should return the same value on dp and pt."""
+        self.assertEqual(self.dp_model.get_dim_aparam(), self.pt_model.get_dim_aparam())
+        self.assertEqual(self.dp_model.get_dim_aparam(), 0)
+
+    def test_get_sel_type(self) -> None:
+        """get_sel_type should return the same list on dp and pt."""
+        self.assertEqual(self.dp_model.get_sel_type(), self.pt_model.get_sel_type())
+        # For this model config, all types are selected: [0, 1]
+        self.assertEqual(self.dp_model.get_sel_type(), [0, 1])
+
+    def test_is_aparam_nall(self) -> None:
+        """is_aparam_nall should return the same value on dp and pt."""
+        self.assertEqual(self.dp_model.is_aparam_nall(), self.pt_model.is_aparam_nall())
+        self.assertFalse(self.dp_model.is_aparam_nall())
+
+    def test_atomic_output_def(self) -> None:
+        """atomic_output_def should return the same keys and shapes on dp and pt."""
+        dp_def = self.dp_model.atomic_output_def()
+        pt_def = self.pt_model.atomic_output_def()
+        self.assertEqual(set(dp_def.keys()), set(pt_def.keys()))
+        for key in dp_def.keys():
+            self.assertEqual(dp_def[key].shape, pt_def[key].shape)
+
+    def test_format_nlist(self) -> None:
+        """format_nlist should produce the same result on dp and pt."""
+        dp_nlist =
self.dp_model.format_nlist( + self.extended_coord, + self.extended_atype, + self.nlist, + ) + pt_nlist = torch_to_numpy( + self.pt_model.format_nlist( + numpy_to_torch(self.extended_coord), + numpy_to_torch(self.extended_atype), + numpy_to_torch(self.nlist), + ) + ) + np.testing.assert_equal(dp_nlist, pt_nlist) + + def test_forward_common_atomic(self) -> None: + """forward_common_atomic should produce consistent results on dp and pt. + + Compares at the atomic_model level, where both backends define this method. + """ + dp_ret = self.dp_model.atomic_model.forward_common_atomic( + self.extended_coord, + self.extended_atype, + self.nlist, + mapping=self.mapping, + ) + pt_ret = self.pt_model.atomic_model.forward_common_atomic( + numpy_to_torch(self.extended_coord), + numpy_to_torch(self.extended_atype), + numpy_to_torch(self.nlist), + mapping=numpy_to_torch(self.mapping), + ) + # Compare the common keys + common_keys = set(dp_ret.keys()) & set(pt_ret.keys()) + self.assertTrue(len(common_keys) > 0) + for key in common_keys: + if dp_ret[key] is not None and pt_ret[key] is not None: + np.testing.assert_allclose( + dp_ret[key], + torch_to_numpy(pt_ret[key]), + rtol=1e-10, + atol=1e-10, + err_msg=f"Mismatch in forward_common_atomic key '{key}'", + ) + + def test_has_default_fparam(self) -> None: + """has_default_fparam should return the same value on dp and pt.""" + self.assertEqual( + self.dp_model.has_default_fparam(), + self.pt_model.has_default_fparam(), + ) + self.assertFalse(self.dp_model.has_default_fparam()) + + def test_get_default_fparam(self) -> None: + """get_default_fparam should return None on both dp and pt (no fparam configured).""" + dp_val = self.dp_model.get_default_fparam() + pt_val = self.pt_model.get_default_fparam() + self.assertIsNone(dp_val) + self.assertIsNone(pt_val) + # Note: both return None because no default_fparam is configured. + # A non-trivial return requires configuring default_fparam in the fitting net. 
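+        #
+        # Illustrative sketch (not executed by this suite): a non-None
+        # default would come from a config like the one below. The values
+        # are assumptions for illustration; only the key names follow the
+        # argument schema mentioned above.
+        #
+        #     data = model_args().normalize_value(
+        #         {
+        #             "type_map": ["O", "H"],
+        #             "descriptor": {...},  # as in setUp
+        #             "fitting_net": {
+        #                 "neuron": [5, 5],
+        #                 "numb_fparam": 1,
+        #                 "default_fparam": [0.5],
+        #             },
+        #         },
+        #         trim_pattern="_*",
+        #     )
+        #     assert get_model_dp(data).get_default_fparam() is not None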
+ + def test_change_out_bias(self) -> None: + """change_out_bias should produce consistent bias on dp and pt.""" + nframes = 2 + # Use realistic coords (from setUp, tiled for 2 frames) + coords_2f = np.tile(self.coords, (nframes, 1, 1)) # (2, 6, 3) + atype_2f = np.array([[0, 0, 1, 1, 1, 1], [0, 1, 1, 0, 1, 1]], dtype=np.int32) + box_2f = np.tile(self.box.reshape(1, 3, 3), (nframes, 1, 1)) + natoms_data = np.array([[6, 6, 2, 4], [6, 6, 2, 4]], dtype=np.int32) + energy_data = np.array([10.0, 20.0]).reshape(nframes, 1) + + # dpmodel stat data (numpy) + dp_merged = [ + { + "coord": coords_2f, + "atype": atype_2f, + "atype_ext": atype_2f, + "box": box_2f, + "natoms": natoms_data, + "energy": energy_data, + "find_energy": np.float32(1.0), + } + ] + # pt stat data (torch tensors) + pt_merged = [ + { + "coord": numpy_to_torch(coords_2f), + "atype": numpy_to_torch(atype_2f), + "atype_ext": numpy_to_torch(atype_2f), + "box": numpy_to_torch(box_2f), + "natoms": numpy_to_torch(natoms_data), + "energy": numpy_to_torch(energy_data), + "find_energy": np.float32(1.0), + } + ] + + # Save initial (zero) bias + dp_bias_init = to_numpy_array(self.dp_model.get_out_bias()).copy() + + # Test "set-by-statistic" mode + self.dp_model.change_out_bias(dp_merged, bias_adjust_mode="set-by-statistic") + self.pt_model.change_out_bias(pt_merged, bias_adjust_mode="set-by-statistic") + dp_bias = to_numpy_array(self.dp_model.get_out_bias()) + pt_bias = torch_to_numpy(self.pt_model.get_out_bias()) + np.testing.assert_allclose(dp_bias, pt_bias, rtol=1e-10, atol=1e-10) + # Verify bias actually changed from initial zeros + self.assertFalse( + np.allclose(dp_bias, dp_bias_init), + "set-by-statistic did not change the bias from initial values", + ) + + # Test "change-by-statistic" mode (adjusts bias based on model predictions) + dp_bias_before = dp_bias.copy() + self.dp_model.change_out_bias(dp_merged, bias_adjust_mode="change-by-statistic") + self.pt_model.change_out_bias(pt_merged, bias_adjust_mode="change-by-statistic") + dp_bias2 = to_numpy_array(self.dp_model.get_out_bias()) + pt_bias2 = torch_to_numpy(self.pt_model.get_out_bias()) + np.testing.assert_allclose(dp_bias2, pt_bias2, rtol=1e-10, atol=1e-10) + # Verify change-by-statistic further modified the bias + self.assertFalse( + np.allclose(dp_bias2, dp_bias_before), + "change-by-statistic did not further change the bias", + ) + + def test_change_type_map(self) -> None: + """change_type_map should produce consistent results on dp and pt. + + Uses a DPA1 (se_atten) descriptor since se_e2_a does not support + change_type_map (non-mixed-types descriptors raise NotImplementedError). 
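+
+        The reindexing verified below, as a sketch (illustrative;
+        ``bias_under_old_map`` is a placeholder for the bias set before
+        the swap)::
+
+            old_map, new_map = ["O", "H"], ["H", "O"]
+            remap = [old_map.index(t) for t in new_map]  # -> [1, 0]
+            expected_bias = bias_under_old_map[:, remap, :]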
+ """ + from deepmd.utils.argcheck import model_args as model_args_fn + + data = model_args_fn().normalize_value( + { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_atten", + "sel": 20, + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [3, 6], + "resnet_dt": False, + "axis_neuron": 2, + "precision": "float64", + "seed": 1, + "attn": 6, + "attn_layer": 0, + }, + "fitting_net": { + "neuron": [5, 5], + "resnet_dt": True, + "precision": "float64", + "seed": 1, + }, + }, + trim_pattern="_*", + ) + dp_model = get_model_dp(data) + pt_model = EnergyModelPT.deserialize(dp_model.serialize()) + + # Set non-zero out_bias so the swap is non-trivial + dp_bias_orig = to_numpy_array(dp_model.get_out_bias()).copy() + new_bias = dp_bias_orig.copy() + new_bias[:, 0, :] = 1.5 # type 0 ("O") + new_bias[:, 1, :] = -3.7 # type 1 ("H") + dp_model.set_out_bias(new_bias) + pt_model.set_out_bias(numpy_to_torch(new_bias)) + + new_type_map = ["H", "O"] + dp_model.change_type_map(new_type_map) + pt_model.change_type_map(new_type_map) + + # Both should have the new type_map + self.assertEqual(dp_model.get_type_map(), new_type_map) + self.assertEqual(pt_model.get_type_map(), new_type_map) + + # Out_bias should be reordered consistently between backends + dp_bias_new = to_numpy_array(dp_model.get_out_bias()) + pt_bias_new = torch_to_numpy(pt_model.get_out_bias()) + np.testing.assert_allclose(dp_bias_new, pt_bias_new, rtol=1e-10, atol=1e-10) + + # Verify the reorder is correct: old type 0 -> new type 1, old type 1 -> new type 0 + np.testing.assert_allclose( + dp_bias_new[:, 0, :], + new_bias[:, 1, :], + rtol=1e-10, + atol=1e-10, + ) + np.testing.assert_allclose( + dp_bias_new[:, 1, :], + new_bias[:, 0, :], + rtol=1e-10, + atol=1e-10, + ) + + def test_update_sel(self) -> None: + """update_sel should return the same result on dp and pt.""" + from unittest.mock import ( + patch, + ) + + from deepmd.dpmodel.model.dp_model import DPModelCommon as DPModelCommonDP + from deepmd.pt.model.model.dp_model import DPModelCommon as DPModelCommonPT + + mock_min_nbor_dist = 0.5 + mock_sel = [10, 20] + local_jdata = { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "sel": "auto", + "rcut_smth": 0.50, + "rcut": 6.00, + }, + "fitting_net": { + "neuron": [5, 5], + }, + } + type_map = ["O", "H"] + + with patch( + "deepmd.dpmodel.utils.update_sel.UpdateSel.get_nbor_stat", + return_value=(mock_min_nbor_dist, mock_sel), + ): + dp_result, dp_min_dist = DPModelCommonDP.update_sel( + None, type_map, local_jdata + ) + + with patch( + "deepmd.pt.utils.update_sel.UpdateSel.get_nbor_stat", + return_value=(mock_min_nbor_dist, mock_sel), + ): + pt_result, pt_min_dist = DPModelCommonPT.update_sel( + None, type_map, local_jdata + ) + + self.assertEqual(dp_result, pt_result) + self.assertEqual(dp_min_dist, pt_min_dist) + # Verify sel was actually updated (not still "auto") + self.assertIsInstance(dp_result["descriptor"]["sel"], list) + self.assertNotEqual(dp_result["descriptor"]["sel"], "auto") + + def test_get_ntypes(self) -> None: + """get_ntypes should return the same value on dp and pt.""" + self.assertEqual(self.dp_model.get_ntypes(), self.pt_model.get_ntypes()) + self.assertEqual(self.dp_model.get_ntypes(), 2) + + def test_compute_or_load_out_stat(self) -> None: + """compute_or_load_out_stat should produce consistent bias on dp and pt. + + Tests both the compute path (from data) and the load path (from file). + Both backends should save the same stat file content and load identical + biases from file. 
+ """ + import tempfile + from pathlib import ( + Path, + ) + + import h5py + + from deepmd.utils.path import ( + DPPath, + ) + + nframes = 2 + coords_2f = np.tile(self.coords, (nframes, 1, 1)) + atype_2f = np.array([[0, 0, 1, 1, 1, 1], [0, 1, 1, 0, 1, 1]], dtype=np.int32) + box_2f = np.tile(self.box.reshape(1, 3, 3), (nframes, 1, 1)) + natoms_data = np.array([[6, 6, 2, 4], [6, 6, 2, 4]], dtype=np.int32) + energy_data = np.array([10.0, 20.0]).reshape(nframes, 1) + + dp_merged = [ + { + "coord": coords_2f, + "atype": atype_2f, + "atype_ext": atype_2f, + "box": box_2f, + "natoms": natoms_data, + "energy": energy_data, + "find_energy": np.float32(1.0), + } + ] + pt_merged = [ + { + "coord": numpy_to_torch(coords_2f), + "atype": numpy_to_torch(atype_2f), + "atype_ext": numpy_to_torch(atype_2f), + "box": numpy_to_torch(box_2f), + "natoms": numpy_to_torch(natoms_data), + "energy": numpy_to_torch(energy_data), + "find_energy": np.float32(1.0), + } + ] + + # Verify bias is initially zero (or at least identical) + dp_bias_before = to_numpy_array(self.dp_model.get_out_bias()).copy() + pt_bias_before = torch_to_numpy(self.pt_model.get_out_bias()).copy() + np.testing.assert_allclose( + dp_bias_before, pt_bias_before, rtol=1e-10, atol=1e-10 + ) + + with tempfile.TemporaryDirectory() as tmpdir: + # Create separate h5 files for dp and pt + dp_h5 = str((Path(tmpdir) / "dp_stat.h5").resolve()) + pt_h5 = str((Path(tmpdir) / "pt_stat.h5").resolve()) + with h5py.File(dp_h5, "w"): + pass + with h5py.File(pt_h5, "w"): + pass + dp_stat_path = DPPath(dp_h5, "a") + pt_stat_path = DPPath(pt_h5, "a") + + # 1. Compute stats and save to file + self.dp_model.atomic_model.compute_or_load_out_stat( + dp_merged, stat_file_path=dp_stat_path + ) + self.pt_model.atomic_model.compute_or_load_out_stat( + pt_merged, stat_file_path=pt_stat_path + ) + + dp_bias_after = to_numpy_array(self.dp_model.get_out_bias()) + pt_bias_after = torch_to_numpy(self.pt_model.get_out_bias()) + np.testing.assert_allclose( + dp_bias_after, pt_bias_after, rtol=1e-10, atol=1e-10 + ) + + # Verify bias actually changed (not still all zeros) + self.assertFalse( + np.allclose(dp_bias_after, dp_bias_before), + "compute_or_load_out_stat did not change the bias", + ) + + # 2. Verify both backends saved the same file content + with h5py.File(dp_h5, "r") as dp_f, h5py.File(pt_h5, "r") as pt_f: + dp_keys = sorted(dp_f.keys()) + pt_keys = sorted(pt_f.keys()) + self.assertEqual(dp_keys, pt_keys) + for key in dp_keys: + np.testing.assert_allclose( + np.array(dp_f[key]), + np.array(pt_f[key]), + rtol=1e-10, + atol=1e-10, + err_msg=f"Stat file content mismatch for key {key}", + ) + + # 3. 
Reset biases to zero, then load from file + zero_bias = np.zeros_like(dp_bias_after) + self.dp_model.set_out_bias(zero_bias) + self.pt_model.set_out_bias(numpy_to_torch(zero_bias)) + + # Use a callable that raises to ensure it loads from file, not recomputes + def raise_error(): + raise RuntimeError("Should not recompute — should load from file") + + self.dp_model.atomic_model.compute_or_load_out_stat( + raise_error, stat_file_path=dp_stat_path + ) + self.pt_model.atomic_model.compute_or_load_out_stat( + raise_error, stat_file_path=pt_stat_path + ) + + dp_bias_loaded = to_numpy_array(self.dp_model.get_out_bias()) + pt_bias_loaded = torch_to_numpy(self.pt_model.get_out_bias()) + + # Loaded biases should match between backends + np.testing.assert_allclose( + dp_bias_loaded, pt_bias_loaded, rtol=1e-10, atol=1e-10 + ) + # Loaded biases should match the originally computed biases + np.testing.assert_allclose( + dp_bias_loaded, dp_bias_after, rtol=1e-10, atol=1e-10 + ) diff --git a/source/tests/consistent/model/test_frozen.py b/source/tests/consistent/model/test_frozen.py index ff7d651e7e..422fcd567a 100644 --- a/source/tests/consistent/model/test_frozen.py +++ b/source/tests/consistent/model/test_frozen.py @@ -151,7 +151,7 @@ def eval_pt(self, pt_obj: Any) -> Any: def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... if backend is self.RefBackend.DP: - return (ret["energy_redu"].ravel(), ret["energy"].ravel()) + return (ret["energy"].ravel(), ret["atom_energy"].ravel()) elif backend is self.RefBackend.PT: return (ret["energy"].ravel(), ret["atom_energy"].ravel()) elif backend is self.RefBackend.TF: diff --git a/source/tests/consistent/model/test_polar.py b/source/tests/consistent/model/test_polar.py index 1405814f03..6a3d7b7443 100644 --- a/source/tests/consistent/model/test_polar.py +++ b/source/tests/consistent/model/test_polar.py @@ -184,8 +184,8 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... if backend in {self.RefBackend.DP, self.RefBackend.JAX}: return ( - ret["polarizability_redu"].ravel(), - ret["polarizability"].ravel(), + ret["global_polar"].ravel(), + ret["polar"].ravel(), ) elif backend is self.RefBackend.PT: return ( diff --git a/source/tests/consistent/model/test_property.py b/source/tests/consistent/model/test_property.py index 75aded98fd..35859d86cb 100644 --- a/source/tests/consistent/model/test_property.py +++ b/source/tests/consistent/model/test_property.py @@ -186,8 +186,8 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: property_name = self.data["fitting_net"]["property_name"] if backend in {self.RefBackend.DP, self.RefBackend.JAX}: return ( - ret[f"{property_name}_redu"].ravel(), ret[property_name].ravel(), + ret[f"atom_{property_name}"].ravel(), ) elif backend is self.RefBackend.PT: return ( diff --git a/source/tests/consistent/model/test_zbl_ener.py b/source/tests/consistent/model/test_zbl_ener.py index 6fb44a59ed..83545696b3 100644 --- a/source/tests/consistent/model/test_zbl_ener.py +++ b/source/tests/consistent/model/test_zbl_ener.py @@ -209,8 +209,8 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... 
if backend is self.RefBackend.DP: return ( - ret["energy_redu"].ravel(), ret["energy"].ravel(), + ret["atom_energy"].ravel(), SKIP_FLAG, SKIP_FLAG, ) @@ -225,9 +225,9 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel(), ret[3].ravel()) elif backend is self.RefBackend.JAX: return ( - ret["energy_redu"].ravel(), ret["energy"].ravel(), - ret["energy_derv_r"].ravel(), - ret["energy_derv_c_redu"].ravel(), + ret["atom_energy"].ravel(), + ret["force"].ravel(), + ret["virial"].ravel(), ) raise ValueError(f"Unknown backend: {backend}") diff --git a/source/tests/jax/test_dp_hessian_model.py b/source/tests/jax/test_dp_hessian_model.py index 89c066e980..00d9a7adee 100644 --- a/source/tests/jax/test_dp_hessian_model.py +++ b/source/tests/jax/test_dp_hessian_model.py @@ -82,34 +82,34 @@ def test_self_consistency(self): ret0 = md0.call(*args) ret1 = md1.call(*args) np.testing.assert_allclose( - to_numpy_array(ret0["energy"]), - to_numpy_array(ret1["energy"]), + to_numpy_array(ret0["atom_energy"]), + to_numpy_array(ret1["atom_energy"]), atol=self.atol, ) np.testing.assert_allclose( - to_numpy_array(ret0["energy_redu"]), - to_numpy_array(ret1["energy_redu"]), + to_numpy_array(ret0["energy"]), + to_numpy_array(ret1["energy"]), atol=self.atol, ) np.testing.assert_allclose( - to_numpy_array(ret0["energy_derv_r"]), - to_numpy_array(ret1["energy_derv_r"]), + to_numpy_array(ret0["force"]), + to_numpy_array(ret1["force"]), atol=self.atol, ) np.testing.assert_allclose( - to_numpy_array(ret0["energy_derv_c_redu"]), - to_numpy_array(ret1["energy_derv_c_redu"]), + to_numpy_array(ret0["virial"]), + to_numpy_array(ret1["virial"]), atol=self.atol, ) np.testing.assert_allclose( - to_numpy_array(ret0["energy_derv_r_derv_r"]), - to_numpy_array(ret1["energy_derv_r_derv_r"]), + to_numpy_array(ret0["hessian"]), + to_numpy_array(ret1["hessian"]), atol=self.atol, ) ret0 = md0.call(*args, do_atomic_virial=True) ret1 = md1.call(*args, do_atomic_virial=True) np.testing.assert_allclose( - to_numpy_array(ret0["energy_derv_c"]), - to_numpy_array(ret1["energy_derv_c"]), + to_numpy_array(ret0["atom_virial"]), + to_numpy_array(ret1["atom_virial"]), atol=self.atol, ) diff --git a/source/tests/jax/test_make_hessian_model.py b/source/tests/jax/test_make_hessian_model.py index 8666ff4ad4..679c0c37ce 100644 --- a/source/tests/jax/test_make_hessian_model.py +++ b/source/tests/jax/test_make_hessian_model.py @@ -100,7 +100,7 @@ def test( ) # compare hess and value models np.testing.assert_allclose(ret_dict0["energy"], ret_dict1["energy"]) - ana_hess = ret_dict0["energy_derv_r_derv_r"] + ana_hess = ret_dict0["hessian"] # compute finite difference fnt_hess = [] @@ -121,7 +121,7 @@ def np_infer( return ret def ff(xx): - return np_infer(xx)["energy_redu"] + return np_infer(xx)["energy"] xx = to_numpy_array(coord[ii]) fnt_hess.append(finite_hessian(ff, xx, delta=delta).squeeze()) diff --git a/source/tests/jax/test_padding_atoms.py b/source/tests/jax/test_padding_atoms.py index b63b464721..0f1b569821 100644 --- a/source/tests/jax/test_padding_atoms.py +++ b/source/tests/jax/test_padding_atoms.py @@ -89,8 +89,8 @@ def test_padding_atoms_consistency(self): result = model.call(*args) # test intensive np.testing.assert_allclose( - to_numpy_array(result[f"{var_name}_redu"]), - np.mean(to_numpy_array(result[f"{var_name}"]), axis=1), + to_numpy_array(result[var_name]), + np.mean(to_numpy_array(result[f"atom_{var_name}"]), axis=1), atol=self.atol, ) # test padding atoms @@ -115,8 
+115,8 @@ def test_padding_atoms_consistency(self): ] result_padding = model.call(*args) np.testing.assert_allclose( - to_numpy_array(result[f"{var_name}_redu"]), - to_numpy_array(result_padding[f"{var_name}_redu"]), + to_numpy_array(result[var_name]), + to_numpy_array(result_padding[var_name]), atol=self.atol, ) diff --git a/source/tests/pd/model/test_dp_model.py b/source/tests/pd/model/test_dp_model.py index a281851f14..5e30b5ebaa 100644 --- a/source/tests/pd/model/test_dp_model.py +++ b/source/tests/pd/model/test_dp_model.py @@ -140,7 +140,7 @@ def test_dp_consistency(self): args1 = [to_paddle_tensor(ii) for ii in [self.coord, self.atype, self.cell]] kwargs0 = {"fparam": fparam, "aparam": aparam} kwargs1 = {kk: to_paddle_tensor(vv) for kk, vv in kwargs0.items()} - ret0 = md0.call(*args0, **kwargs0) + ret0 = md0.call_common(*args0, **kwargs0) ret1 = md1.forward_common(*args1, **kwargs1) np.testing.assert_allclose( ret0["energy"], @@ -179,7 +179,7 @@ def test_dp_consistency_nopbc(self): args1 = [to_paddle_tensor(ii) for ii in args0] kwargs0 = {"fparam": fparam, "aparam": aparam} kwargs1 = {kk: to_paddle_tensor(vv) for kk, vv in kwargs0.items()} - ret0 = md0.call(*args0, **kwargs0) + ret0 = md0.call_common(*args0, **kwargs0) ret1 = md1.forward_common(*args1, **kwargs1) np.testing.assert_allclose( ret0["energy"], @@ -313,7 +313,7 @@ def test_dp_consistency(self): args1 = [ to_paddle_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] ] - ret0 = md0.call_lower(*args0) + ret0 = md0.call_common_lower(*args0) ret1 = md1.forward_common_lower(*args1) np.testing.assert_allclose( ret0["energy"], diff --git a/source/tests/pt/model/test_dp_model.py b/source/tests/pt/model/test_dp_model.py index 93153ce6d5..f4e350869a 100644 --- a/source/tests/pt/model/test_dp_model.py +++ b/source/tests/pt/model/test_dp_model.py @@ -140,7 +140,7 @@ def test_dp_consistency(self) -> None: args1 = [to_torch_tensor(ii) for ii in [self.coord, self.atype, self.cell]] kwargs0 = {"fparam": fparam, "aparam": aparam} kwargs1 = {kk: to_torch_tensor(vv) for kk, vv in kwargs0.items()} - ret0 = md0.call(*args0, **kwargs0) + ret0 = md0.call_common(*args0, **kwargs0) ret1 = md1.forward_common(*args1, **kwargs1) np.testing.assert_allclose( ret0["energy"], @@ -179,7 +179,7 @@ def test_dp_consistency_nopbc(self) -> None: args1 = [to_torch_tensor(ii) for ii in args0] kwargs0 = {"fparam": fparam, "aparam": aparam} kwargs1 = {kk: to_torch_tensor(vv) for kk, vv in kwargs0.items()} - ret0 = md0.call(*args0, **kwargs0) + ret0 = md0.call_common(*args0, **kwargs0) ret1 = md1.forward_common(*args1, **kwargs1) np.testing.assert_allclose( ret0["energy"], @@ -313,7 +313,7 @@ def test_dp_consistency(self) -> None: args1 = [ to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] ] - ret0 = md0.call_lower(*args0) + ret0 = md0.call_common_lower(*args0) ret1 = md1.forward_common_lower(*args1) np.testing.assert_allclose( ret0["energy"], diff --git a/source/tests/pt_expt/model/test_autodiff.py b/source/tests/pt_expt/model/test_autodiff.py new file mode 100644 index 0000000000..de404b5b95 --- /dev/null +++ b/source/tests/pt_expt/model/test_autodiff.py @@ -0,0 +1,185 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + InvarFitting, +) +from deepmd.pt_expt.model import ( + EnergyModel, +) +from deepmd.pt_expt.utils import ( + env, +) + +from ...seed import ( + 
GLOBAL_SEED, +) + +dtype = torch.float64 + + +def finite_difference(f, x, delta=1e-6): + in_shape = x.shape + y0 = f(x) + out_shape = y0.shape + res = np.empty(out_shape + in_shape) + for idx in np.ndindex(*in_shape): + diff = np.zeros(in_shape) + diff[idx] += delta + y1p = f(x + diff) + y1n = f(x - diff) + res[(Ellipsis, *idx)] = (y1p - y1n) / (2 * delta) + return res + + +def stretch_box(old_coord, old_box, new_box): + ocoord = old_coord.reshape(-1, 3) + obox = old_box.reshape(3, 3) + nbox = new_box.reshape(3, 3) + ncoord = ocoord @ np.linalg.inv(obox) @ nbox + return ncoord.reshape(old_coord.shape) + + +def eval_model(model, coord, cell, atype): + """Evaluate the pt_expt EnergyModel. + + Parameters + ---------- + model : EnergyModel + The model to evaluate. + coord : torch.Tensor + Coordinates, shape [nf, natoms, 3]. + cell : torch.Tensor + Cell, shape [nf, 3, 3]. + atype : torch.Tensor + Atom types, shape [natoms]. + + Returns + ------- + dict + Model predictions with keys: energy, force, virial. + """ + nframes = coord.shape[0] + if len(atype.shape) == 1: + atype = atype.unsqueeze(0).expand(nframes, -1) + coord_input = coord.to(dtype=dtype, device=env.DEVICE) + cell_input = cell.reshape(nframes, 9).to(dtype=dtype, device=env.DEVICE) + atype_input = atype.to(dtype=torch.long, device=env.DEVICE) + coord_input.requires_grad_(True) + result = model(coord_input, atype_input, cell_input) + return result + + +class ForceTest: + def test(self) -> None: + places = 5 + delta = 1e-5 + natoms = 5 + generator = torch.Generator(device="cpu").manual_seed(GLOBAL_SEED) + cell = torch.rand([3, 3], dtype=dtype, device="cpu", generator=generator) + cell = (cell + cell.T) + 5.0 * torch.eye(3, device="cpu") + coord = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator) + coord = torch.matmul(coord, cell) + atype = torch.IntTensor([0, 0, 0, 1, 1]) + coord = coord.numpy() + + def np_infer_coord(coord): + result = eval_model( + self.model, + torch.tensor(coord, device=env.DEVICE).unsqueeze(0), + cell.unsqueeze(0), + atype, + ) + ret = { + key: result[key].squeeze(0).detach().cpu().numpy() + for key in ["energy", "force", "virial"] + } + return ret + + def ff_coord(_coord): + return np_infer_coord(_coord)["energy"] + + fdf = -finite_difference(ff_coord, coord, delta=delta).squeeze() + rff = np_infer_coord(coord)["force"] + np.testing.assert_almost_equal(fdf, rff, decimal=places) + + +class VirialTest: + def test(self) -> None: + places = 5 + delta = 1e-4 + natoms = 5 + generator = torch.Generator(device="cpu").manual_seed(GLOBAL_SEED) + cell = torch.rand([3, 3], dtype=dtype, device="cpu", generator=generator) + cell = (cell) + 5.0 * torch.eye(3, device="cpu") + coord = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator) + coord = torch.matmul(coord, cell) + atype = torch.IntTensor([0, 0, 0, 1, 1]) + coord = coord.numpy() + cell = cell.numpy() + + def np_infer(new_cell): + result = eval_model( + self.model, + torch.tensor( + stretch_box(coord, cell, new_cell), device="cpu" + ).unsqueeze(0), + torch.tensor(new_cell, device="cpu").unsqueeze(0), + atype, + ) + ret = { + key: result[key].squeeze(0).detach().cpu().numpy() + for key in ["energy", "force", "virial"] + } + return ret + + def ff(bb): + return np_infer(bb)["energy"] + + fdv = ( + -(finite_difference(ff, cell, delta=delta).transpose(0, 2, 1) @ cell) + .squeeze() + .reshape(9) + ) + rfv = np_infer(cell)["virial"] + np.testing.assert_almost_equal(fdv, rfv, decimal=places) + + +class 
TestEnergyModelSeAForce(unittest.TestCase, ForceTest): + def setUp(self) -> None: + ds = DescrptSeA(4.0, 0.5, [8, 6]).to(env.DEVICE) + ft = InvarFitting( + "energy", + 2, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ).to(env.DEVICE) + self.model = EnergyModel(ds, ft, type_map=["foo", "bar"]).to(env.DEVICE) + self.model.eval() + + +class TestEnergyModelSeAVirial(unittest.TestCase, VirialTest): + def setUp(self) -> None: + ds = DescrptSeA(4.0, 0.5, [8, 6]).to(env.DEVICE) + ft = InvarFitting( + "energy", + 2, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ).to(env.DEVICE) + self.model = EnergyModel(ds, ft, type_map=["foo", "bar"]).to(env.DEVICE) + self.model.eval() + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt_expt/model/test_ener_model.py b/source/tests/pt_expt/model/test_ener_model.py new file mode 100644 index 0000000000..d65548996f --- /dev/null +++ b/source/tests/pt_expt/model/test_ener_model.py @@ -0,0 +1,381 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA +from deepmd.dpmodel.fitting import InvarFitting as DPInvarFitting +from deepmd.dpmodel.model.ener_model import EnergyModel as DPEnergyModel +from deepmd.dpmodel.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.dpmodel.utils.region import ( + normalize_coord, +) +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + InvarFitting, +) +from deepmd.pt_expt.model import ( + EnergyModel, +) +from deepmd.pt_expt.utils import ( + env, +) + +from ...seed import ( + GLOBAL_SEED, +) + + +class TestEnergyModel(unittest.TestCase): + def setUp(self) -> None: + self.device = env.DEVICE + self.natoms = 5 + self.rcut = 4.0 + self.rcut_smth = 0.5 + self.sel = [8, 6] + self.nt = 2 + self.type_map = ["foo", "bar"] + + generator = torch.Generator(device=self.device).manual_seed(GLOBAL_SEED) + cell = torch.rand( + [3, 3], dtype=torch.float64, device=self.device, generator=generator + ) + cell = (cell + cell.T) + 5.0 * torch.eye(3, device=self.device) + self.cell = cell.unsqueeze(0) # [1, 3, 3] + coord = torch.rand( + [self.natoms, 3], + dtype=torch.float64, + device=self.device, + generator=generator, + ) + coord = torch.matmul(coord, cell) + self.coord = coord.unsqueeze(0).to(self.device) # [1, natoms, 3] + self.atype = torch.tensor( + [[0, 0, 0, 1, 1]], dtype=torch.int64, device=self.device + ) + + def _make_model( + self, + numb_fparam: int = 0, + numb_aparam: int = 0, + ) -> EnergyModel: + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + numb_fparam=numb_fparam, + numb_aparam=numb_aparam, + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ).to(self.device) + return EnergyModel(ds, ft, type_map=self.type_map).to(self.device) + + def test_output_keys(self) -> None: + """Test that EnergyModel produces expected output keys.""" + md = self._make_model() + md.eval() + coord = self.coord.clone().requires_grad_(True) + ret = md(coord, self.atype, self.cell.reshape(1, 9)) + self.assertIn("energy", ret) + self.assertIn("atom_energy", ret) + self.assertIn("force", ret) + self.assertIn("virial", ret) + + def test_output_shapes(self) -> None: + """Test that output shapes are correct.""" + md = self._make_model() + md.eval() + coord = 
self.coord.clone().requires_grad_(True) + ret = md(coord, self.atype, self.cell.reshape(1, 9)) + self.assertEqual(ret["energy"].shape, (1, 1)) + self.assertEqual(ret["atom_energy"].shape, (1, self.natoms, 1)) + self.assertEqual(ret["force"].shape, (1, self.natoms, 3)) + self.assertEqual(ret["virial"].shape, (1, 9)) + + def _prepare_lower_inputs(self): + """Build extended coords, atype, nlist, mapping as torch tensors.""" + coord_np = self.coord.detach().cpu().numpy() + atype_np = self.atype.detach().cpu().numpy() + cell_np = self.cell.reshape(1, 9).detach().cpu().numpy() + coord_normalized = normalize_coord( + coord_np.reshape(1, self.natoms, 3), + cell_np.reshape(1, 3, 3), + ) + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, + atype_np, + cell_np, + self.rcut, + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + self.natoms, + self.rcut, + self.sel, + distinguish_types=True, + ) + extended_coord = extended_coord.reshape(1, -1, 3) + ext_coord = torch.tensor( + extended_coord, + dtype=torch.float64, + device=self.device, + ) + ext_atype = torch.tensor( + extended_atype, + dtype=torch.int64, + device=self.device, + ) + nlist_t = torch.tensor(nlist, dtype=torch.int64, device=self.device) + mapping_t = torch.tensor(mapping, dtype=torch.int64, device=self.device) + return ext_coord, ext_atype, nlist_t, mapping_t + + def test_forward_lower_exportable(self) -> None: + """Test that EnergyModel.forward_lower_exportable returns an exportable module. + + forward_lower_exportable() uses make_fx to trace through + torch.autograd.grad, decomposing the backward pass into primitive ops. + The returned module can be passed directly to torch.export.export. + + The test builds a model with numb_fparam > 0 and numb_aparam > 0 and + verifies that: + 1. The traced / exported module reproduces eager results (zero params). + 2. The traced / exported module reproduces eager results with non-zero + fparam and aparam (ruling out baked-in constants). + 3. Changing fparam or aparam at runtime actually changes the output. 
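+
+        Once exported, the program could also be serialized and reloaded
+        (illustrative; not exercised by this test, and ``example_inputs``
+        stands for the positional tensors prepared below)::
+
+            torch.export.save(exported, "model.pt2")
+            reloaded = torch.export.load("model.pt2")
+            reloaded.module()(*example_inputs)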
+ """ + numb_fparam = 2 + numb_aparam = 3 + md = self._make_model( + numb_fparam=numb_fparam, + numb_aparam=numb_aparam, + ) + md.eval() + + ext_coord, ext_atype, nlist_t, mapping_t = self._prepare_lower_inputs() + nframes = ext_coord.shape[0] + nloc = self.natoms + output_keys = ("energy", "extended_force", "virial", "extended_virial") + + fparam_zero = torch.zeros( + nframes, + numb_fparam, + dtype=torch.float64, + device=self.device, + ) + aparam_zero = torch.zeros( + nframes, + nloc, + numb_aparam, + dtype=torch.float64, + device=self.device, + ) + + # --- eager reference with zero params --- + ret_eager_zero = md.forward_lower( + ext_coord.requires_grad_(True), + ext_atype, + nlist_t, + mapping_t, + fparam=fparam_zero, + aparam=aparam_zero, + do_atomic_virial=True, + ) + for key in output_keys: + self.assertIn(key, ret_eager_zero) + + # --- trace and export --- + traced = md.forward_lower_exportable( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam=fparam_zero, + aparam=aparam_zero, + do_atomic_virial=True, + ) + self.assertIsInstance(traced, torch.nn.Module) + + exported = torch.export.export( + traced, + (ext_coord, ext_atype, nlist_t, mapping_t, fparam_zero, aparam_zero), + strict=False, + ) + self.assertIsNotNone(exported) + + # --- verify traced/exported match eager (zero params) --- + ret_traced_zero = traced( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam_zero, + aparam_zero, + ) + ret_exported_zero = exported.module()( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam_zero, + aparam_zero, + ) + for key in output_keys: + np.testing.assert_allclose( + ret_eager_zero[key].detach().cpu().numpy(), + ret_traced_zero[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"traced vs eager (zero params): {key}", + ) + np.testing.assert_allclose( + ret_eager_zero[key].detach().cpu().numpy(), + ret_exported_zero[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"exported vs eager (zero params): {key}", + ) + + # --- verify traced/exported match eager (non-zero params) --- + fparam_nz = torch.ones( + nframes, + numb_fparam, + dtype=torch.float64, + device=self.device, + ) + aparam_nz = torch.ones( + nframes, + nloc, + numb_aparam, + dtype=torch.float64, + device=self.device, + ) + ret_eager_nz = md.forward_lower( + ext_coord.requires_grad_(True), + ext_atype, + nlist_t, + mapping_t, + fparam=fparam_nz, + aparam=aparam_nz, + do_atomic_virial=True, + ) + ret_traced_nz = traced( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam_nz, + aparam_nz, + ) + ret_exported_nz = exported.module()( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam_nz, + aparam_nz, + ) + for key in output_keys: + np.testing.assert_allclose( + ret_eager_nz[key].detach().cpu().numpy(), + ret_traced_nz[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"traced vs eager (non-zero params): {key}", + ) + np.testing.assert_allclose( + ret_eager_nz[key].detach().cpu().numpy(), + ret_exported_nz[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"exported vs eager (non-zero params): {key}", + ) + + # --- verify fparam is dynamic (changing it changes the output) --- + self.assertFalse( + np.allclose( + ret_traced_zero["energy"].detach().cpu().numpy(), + ret_traced_nz["energy"].detach().cpu().numpy(), + ), + "Changing fparam did not change output — " + "fparam may be baked in as a constant", + ) + + # --- verify aparam is dynamic (changing it changes the output) --- + ret_traced_ap = traced( + ext_coord, + ext_atype, + 
nlist_t, + mapping_t, + fparam_zero, + aparam_nz, + ) + self.assertFalse( + np.allclose( + ret_traced_zero["energy"].detach().cpu().numpy(), + ret_traced_ap["energy"].detach().cpu().numpy(), + ), + "Changing aparam did not change output — " + "aparam may be baked in as a constant", + ) + + def test_dp_consistency(self) -> None: + """Test numerical consistency with dpmodel (energy values).""" + # Build dpmodel version + ds_dp = DPDescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ) + ft_dp = DPInvarFitting( + "energy", + self.nt, + ds_dp.get_dim_out(), + 1, + mixed_types=ds_dp.mixed_types(), + seed=GLOBAL_SEED, + ) + md_dp = DPEnergyModel(ds_dp, ft_dp, type_map=self.type_map) + + # Build pt_expt version from serialized dpmodel + md_pt = EnergyModel.deserialize(md_dp.serialize()).to(self.device) + md_pt.eval() + + # dpmodel inference + coord_np = self.coord.detach().cpu().numpy() + atype_np = self.atype.detach().cpu().numpy() + cell_np = self.cell.reshape(1, 9).detach().cpu().numpy() + ret_dp = md_dp(coord_np.reshape(1, -1), atype_np, cell_np) + + # pt_expt inference + coord = self.coord.clone().requires_grad_(True) + ret_pt = md_pt(coord, self.atype, self.cell.reshape(1, 9)) + + np.testing.assert_allclose( + ret_dp["energy"], + ret_pt["energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + np.testing.assert_allclose( + ret_dp["atom_energy"], + ret_pt["atom_energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/universal/dpmodel/model/test_model.py b/source/tests/universal/dpmodel/model/test_model.py index 815c612bb0..c82074c601 100644 --- a/source/tests/universal/dpmodel/model/test_model.py +++ b/source/tests/universal/dpmodel/model/test_model.py @@ -164,7 +164,7 @@ def setUpClass(cls) -> None: ft, type_map=cls.expected_type_map, ) - cls.output_def = cls.module.model_output_def().get_data() + cls.output_def = cls.module.translated_output_def() cls.expected_has_message_passing = ds.has_message_passing() cls.expected_sel_type = ft.get_sel_type() cls.expected_dim_fparam = ft.get_dim_fparam() @@ -271,7 +271,7 @@ def setUpClass(cls) -> None: pair_exclude_types=pair_exclude_types, ) cls.module = SpinModel(backbone_model=backbone_model, spin=spin) - cls.output_def = cls.module.model_output_def().get_data() + cls.output_def = cls.module.translated_output_def() cls.expected_has_message_passing = ds.has_message_passing() cls.expected_sel_type = ft.get_sel_type() cls.expected_dim_fparam = ft.get_dim_fparam()
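
A note on the final test_model.py hunk: model_output_def().get_data() keys the definitions by internal variable names, while translated_output_def() keys them by the user-facing names that the model's call() actually returns, so the universal tests now validate the same dictionary a user receives. The snippet below illustrates the assumed invariant for the energy model; the key set is inferred from test_output_keys in this patch rather than quoted from the implementation, and the constructor arguments mirror the test fixtures.

# Hedged sketch: translated_output_def() is assumed to expose the same
# user-facing keys that call() returns (inferred from test_output_keys).
from deepmd.dpmodel.descriptor import DescrptSeA
from deepmd.dpmodel.fitting import InvarFitting
from deepmd.dpmodel.model.ener_model import EnergyModel

ds = DescrptSeA(4.0, 0.5, [8, 6])
ft = InvarFitting("energy", 2, ds.get_dim_out(), 1, mixed_types=ds.mixed_types())
md = EnergyModel(ds, ft, type_map=["foo", "bar"])
assert {"energy", "atom_energy", "force", "virial"} <= set(md.translated_output_def())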
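
On the mechanism behind test_forward_lower_exportable: make_fx executes the function under proxy tensors, so the torch.autograd.grad call inside forward_lower actually runs during tracing and its backward pass is recorded as ordinary forward ops in the returned GraphModule, which torch.export.export can then consume. Below is a minimal, self-contained sketch of that trick; the quadratic toy energy and the name energy_and_force are illustrative assumptions, not deepmd code, and only the make_fx/export pattern mirrors the test.

# Minimal sketch: trace through torch.autograd.grad with make_fx, then export.
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def energy_and_force(coord: torch.Tensor):
    # differentiate a toy quadratic "energy" with respect to the coordinates
    with torch.enable_grad():
        coord = coord.detach().requires_grad_(True)
        energy = (coord**2).sum()
        (grad_coord,) = torch.autograd.grad(energy, coord)
    return energy.detach(), -grad_coord  # force = -dE/dx


example = torch.rand(5, 3, dtype=torch.float64)
# make_fx runs the autograd engine while tracing; the resulting GraphModule
# contains only forward primitives (mul, sum, ...), no autograd calls
gm = make_fx(energy_and_force)(example)
ep = torch.export.export(gm, (example,), strict=False)

e_ref, f_ref = energy_and_force(example)
e_exp, f_exp = ep.module()(example)
torch.testing.assert_close(e_ref, e_exp)
torch.testing.assert_close(f_ref, f_exp)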