2121 PRECISION_DICT ,
2222 RESERVED_PRECISION_DICT ,
2323 NativeOP ,
24+ get_xp_precision ,
2425)
2526from deepmd .dpmodel .model .base_model import (
2627 BaseModel ,
@@ -103,7 +104,8 @@ def model_call_from_call_lower(
103104 bb .reshape (nframes , 3 , 3 ),
104105 )
105106 else :
106- coord_normalized = cc .copy ()
107+ xp = array_api_compat .array_namespace (cc )
108+ coord_normalized = xp .reshape (cc , (nframes , nloc , 3 ))
107109 extended_coord , extended_atype , mapping = extend_coord_with_ghosts (
108110 coord_normalized , atype , bb , rcut
109111 )
@@ -255,7 +257,7 @@ def call(
255257 The keys are defined by the `ModelOutputDef`.
256258
257259 """
258- cc , bb , fp , ap , input_prec = self .input_type_cast (
260+ cc , bb , fp , ap , input_prec = self ._input_type_cast (
259261 coord , box = box , fparam = fparam , aparam = aparam
260262 )
261263 del coord , box , fparam , aparam
@@ -272,7 +274,7 @@ def call(
272274 aparam = ap ,
273275 do_atomic_virial = do_atomic_virial ,
274276 )
275- model_predict = self .output_type_cast (model_predict , input_prec )
277+ model_predict = self ._output_type_cast (model_predict , input_prec )
276278 return model_predict
277279
278280 def call_lower (
@@ -321,7 +323,7 @@ def call_lower(
321323 nlist ,
322324 extra_nlist_sort = self .need_sorted_nlist_for_lower (),
323325 )
324- cc_ext , _ , fp , ap , input_prec = self .input_type_cast (
326+ cc_ext , _ , fp , ap , input_prec = self ._input_type_cast (
325327 extended_coord , fparam = fparam , aparam = aparam
326328 )
327329 del extended_coord , fparam , aparam
@@ -334,7 +336,7 @@ def call_lower(
334336 aparam = ap ,
335337 do_atomic_virial = do_atomic_virial ,
336338 )
337- model_predict = self .output_type_cast (model_predict , input_prec )
339+ model_predict = self ._output_type_cast (model_predict , input_prec )
338340 return model_predict
339341
340342 def forward_common_atomic (
@@ -364,60 +366,107 @@ def forward_common_atomic(
364366 )
365367
366368 forward_lower = call_lower
369+ forward_common = call
370+ forward_common_lower = call_lower
367371
368- def input_type_cast (
372+ def get_out_bias (self ) -> Array :
373+ """Get the output bias."""
374+ return self .atomic_model .out_bias
375+
376+ def set_out_bias (self , out_bias : Array ) -> None :
377+ """Set the output bias."""
378+ self .atomic_model .out_bias = out_bias
379+
380+ def change_out_bias (
381+ self ,
382+ merged : Any ,
383+ bias_adjust_mode : str = "change-by-statistic" ,
384+ ) -> None :
385+ """Change the output bias according to the input data and the pretrained model.
386+
387+ Parameters
388+ ----------
389+ merged
390+ The merged data samples.
391+ bias_adjust_mode : str
392+ The mode for changing output bias:
393+ 'change-by-statistic' or 'set-by-statistic'.
394+ """
395+ self .atomic_model .change_out_bias (merged , bias_adjust_mode = bias_adjust_mode )
396+
397+ def _input_type_cast (
369398 self ,
370399 coord : Array ,
371400 box : Array | None = None ,
372401 fparam : Array | None = None ,
373402 aparam : Array | None = None ,
374- ) -> tuple [Array , Array , np . ndarray | None , np . ndarray | None , str ]:
403+ ) -> tuple [Array , Array | None , Array | None , Array | None , Any ]:
375404 """Cast the input data to global float type."""
376- input_prec = RESERVED_PRECISION_DICT [self .precision_dict [coord .dtype .name ]]
405+ xp = array_api_compat .array_namespace (coord )
406+ input_dtype = coord .dtype
407+ global_dtype = get_xp_precision (
408+ xp , RESERVED_PRECISION_DICT [self .global_np_float_precision ]
409+ )
377410 ###
378411 ### type checking would not pass jit, convert to coord prec anyway
379412 ###
380- _lst : list [np . ndarray | None ] = [
381- vv .astype (coord . dtype ) if vv is not None else None
413+ _lst : list [Array | None ] = [
414+ xp .astype (vv , input_dtype ) if vv is not None else None
382415 for vv in [box , fparam , aparam ]
383416 ]
384417 box , fparam , aparam = _lst
385- if input_prec == RESERVED_PRECISION_DICT [ self . global_np_float_precision ] :
386- return coord , box , fparam , aparam , input_prec
418+ if input_dtype == global_dtype :
419+ return coord , box , fparam , aparam , input_dtype
387420 else :
388- pp = self .global_np_float_precision
389421 return (
390- coord .astype (pp ),
391- box .astype (pp ) if box is not None else None ,
392- fparam .astype (pp ) if fparam is not None else None ,
393- aparam .astype (pp ) if aparam is not None else None ,
394- input_prec ,
422+ xp .astype (coord , global_dtype ),
423+ xp .astype (box , global_dtype ) if box is not None else None ,
424+ xp .astype (fparam , global_dtype ) if fparam is not None else None ,
425+ xp .astype (aparam , global_dtype ) if aparam is not None else None ,
426+ input_dtype ,
395427 )
396428
397- def output_type_cast (
429+ def _output_type_cast (
398430 self ,
399431 model_ret : dict [str , Array ],
400- input_prec : str ,
432+ input_prec : Any ,
401433 ) -> dict [str , Array ]:
402- """Convert the model output to the input prec."""
403- do_cast = (
404- input_prec != RESERVED_PRECISION_DICT [self .global_np_float_precision ]
434+ """Convert the model output to the input prec.
435+
436+ Parameters
437+ ----------
438+ model_ret
439+ The model output.
440+ input_prec
441+ The input dtype returned by ``_input_type_cast``.
442+ """
443+ model_ret_not_none = [vv for vv in model_ret .values () if vv is not None ]
444+ if not model_ret_not_none :
445+ return model_ret
446+ xp = array_api_compat .array_namespace (model_ret_not_none [0 ])
447+ global_dtype = get_xp_precision (
448+ xp , RESERVED_PRECISION_DICT [self .global_np_float_precision ]
449+ )
450+ ener_dtype = get_xp_precision (
451+ xp , RESERVED_PRECISION_DICT [self .global_ener_float_precision ]
405452 )
406- pp = self . precision_dict [ input_prec ]
453+ do_cast = input_prec != global_dtype
407454 odef = self .model_output_def ()
408455 for kk in odef .keys ():
409456 if kk not in model_ret .keys ():
410457 # do not return energy_derv_c if not do_atomic_virial
411458 continue
412459 if check_operation_applied (odef [kk ], OutputVariableOperation .REDU ):
413460 model_ret [kk ] = (
414- model_ret [kk ]. astype ( self . global_ener_float_precision )
461+ xp . astype ( model_ret [kk ], ener_dtype )
415462 if model_ret [kk ] is not None
416463 else None
417464 )
418465 elif do_cast :
419466 model_ret [kk ] = (
420- model_ret [kk ].astype (pp ) if model_ret [kk ] is not None else None
467+ xp .astype (model_ret [kk ], input_prec )
468+ if model_ret [kk ] is not None
469+ else None
421470 )
422471 return model_ret
423472
0 commit comments