@@ -298,10 +298,10 @@ def call_common(
298298 The keys are defined by the `ModelOutputDef`.
299299
300300 """
301- cc , bb , fp , ap , input_prec = self ._input_type_cast (
302- coord , box = box , fparam = fparam , aparam = aparam
301+ cc , bb , fp , ap , cs , input_prec = self ._input_type_cast (
302+ coord , box = box , fparam = fparam , aparam = aparam , charge_spin = charge_spin
303303 )
304- del coord , box , fparam , aparam
304+ del coord , box , fparam , aparam , charge_spin
305305 model_predict = model_call_from_call_lower (
306306 call_lower = self .call_common_lower ,
307307 rcut = self .get_rcut (),
@@ -315,7 +315,7 @@ def call_common(
315315 aparam = ap ,
316316 do_atomic_virial = do_atomic_virial ,
317317 coord_corr_for_virial = coord_corr_for_virial ,
318- charge_spin = charge_spin ,
318+ charge_spin = cs ,
319319 )
320320 model_predict = self ._output_type_cast (model_predict , input_prec )
321321 return model_predict
@@ -377,10 +377,10 @@ def call_common_lower(
377377 nlist ,
378378 extra_nlist_sort = self .need_sorted_nlist_for_lower (),
379379 )
380- cc_ext , _ , fp , ap , input_prec = self ._input_type_cast (
381- extended_coord , fparam = fparam , aparam = aparam
380+ cc_ext , _ , fp , ap , cs , input_prec = self ._input_type_cast (
381+ extended_coord , fparam = fparam , aparam = aparam , charge_spin = charge_spin
382382 )
383- del extended_coord , fparam , aparam
383+ del extended_coord , fparam , aparam , charge_spin
384384 model_predict = self .forward_common_atomic (
385385 cc_ext ,
386386 extended_atype ,
@@ -391,7 +391,7 @@ def call_common_lower(
391391 do_atomic_virial = do_atomic_virial ,
392392 extended_coord_corr = extended_coord_corr ,
393393 comm_dict = comm_dict ,
394- charge_spin = charge_spin ,
394+ charge_spin = cs ,
395395 )
396396 model_predict = self ._output_type_cast (model_predict , input_prec )
397397 return model_predict
@@ -482,7 +482,8 @@ def _input_type_cast(
482482 box : Array | None = None ,
483483 fparam : Array | None = None ,
484484 aparam : Array | None = None ,
485- ) -> tuple [Array , Array | None , Array | None , Array | None , Any ]:
485+ charge_spin : Array | None = None ,
486+ ) -> tuple [Array , Array | None , Array | None , Array | None , Array | None , Any ]:
486487 """Cast the input data to global float type."""
487488 xp = array_api_compat .array_namespace (coord )
488489 input_dtype = coord .dtype
@@ -494,17 +495,20 @@ def _input_type_cast(
494495 ###
495496 _lst : list [Array | None ] = [
496497 xp .astype (vv , input_dtype ) if vv is not None else None
497- for vv in [box , fparam , aparam ]
498+ for vv in [box , fparam , aparam , charge_spin ]
498499 ]
499- box , fparam , aparam = _lst
500+ box , fparam , aparam , charge_spin = _lst
500501 if input_dtype == global_dtype :
501- return coord , box , fparam , aparam , input_dtype
502+ return coord , box , fparam , aparam , charge_spin , input_dtype
502503 else :
503504 return (
504505 xp .astype (coord , global_dtype ),
505506 xp .astype (box , global_dtype ) if box is not None else None ,
506507 xp .astype (fparam , global_dtype ) if fparam is not None else None ,
507508 xp .astype (aparam , global_dtype ) if aparam is not None else None ,
509+ xp .astype (charge_spin , global_dtype )
510+ if charge_spin is not None
511+ else None ,
508512 input_dtype ,
509513 )
510514
0 commit comments