@@ -579,6 +579,7 @@ def call_common(
579579 box : Array | None = None ,
580580 fparam : Array | None = None ,
581581 aparam : Array | None = None ,
582+ charge_spin : Array | None = None ,
582583 do_atomic_virial : bool = False ,
583584 ) -> dict [str , Array ]:
584585 """Return model prediction with raw internal keys.
@@ -624,6 +625,7 @@ def call_common(
624625 box ,
625626 fparam = fparam ,
626627 aparam = aparam ,
628+ charge_spin = charge_spin ,
627629 do_atomic_virial = do_atomic_virial ,
628630 coord_corr_for_virial = coord_corr_for_virial ,
629631 )
@@ -673,6 +675,7 @@ def call(
673675 box : Array | None = None ,
674676 fparam : Array | None = None ,
675677 aparam : Array | None = None ,
678+ charge_spin : Array | None = None ,
676679 do_atomic_virial : bool = False ,
677680 ) -> dict [str , Array ]:
678681 """Return model prediction with translated user-facing keys.
@@ -710,6 +713,7 @@ def call(
710713 box ,
711714 fparam = fparam ,
712715 aparam = aparam ,
716+ charge_spin = charge_spin ,
713717 do_atomic_virial = do_atomic_virial ,
714718 )
715719 model_output_type = self .backbone_model .model_output_type ()
@@ -747,6 +751,7 @@ def call_common_lower(
747751 mapping : Array | None = None ,
748752 fparam : Array | None = None ,
749753 aparam : Array | None = None ,
754+ charge_spin : Array | None = None ,
750755 do_atomic_virial : bool = False ,
751756 ) -> dict [str , Array ]:
752757 """Return model prediction with raw internal keys. Lower interface that takes
@@ -798,6 +803,7 @@ def call_common_lower(
798803 mapping = mapping_updated ,
799804 fparam = fparam ,
800805 aparam = aparam ,
806+ charge_spin = charge_spin ,
801807 do_atomic_virial = do_atomic_virial ,
802808 extended_coord_corr = extended_coord_corr ,
803809 )
@@ -851,6 +857,7 @@ def call_lower(
851857 mapping : Array | None = None ,
852858 fparam : Array | None = None ,
853859 aparam : Array | None = None ,
860+ charge_spin : Array | None = None ,
854861 do_atomic_virial : bool = False ,
855862 ) -> dict [str , Array ]:
856863 """Return model prediction with translated user-facing keys. Lower interface.
@@ -889,6 +896,7 @@ def call_lower(
889896 mapping = mapping ,
890897 fparam = fparam ,
891898 aparam = aparam ,
899+ charge_spin = charge_spin ,
892900 do_atomic_virial = do_atomic_virial ,
893901 )
894902 model_output_type = self .backbone_model .model_output_type ()
0 commit comments