@@ -526,8 +526,8 @@ def forward_common(
526526 box : torch .Tensor | None = None ,
527527 fparam : torch .Tensor | None = None ,
528528 aparam : torch .Tensor | None = None ,
529- charge_spin : torch .Tensor | None = None ,
530529 do_atomic_virial : bool = False ,
530+ charge_spin : torch .Tensor | None = None ,
531531 ) -> dict [str , torch .Tensor ]:
532532 nframes , nloc = atype .shape
533533 coord_updated , atype_updated , coord_corr_for_virial = self .process_spin_input (
@@ -580,10 +580,10 @@ def forward_common_lower(
580580 mapping : torch .Tensor | None = None ,
581581 fparam : torch .Tensor | None = None ,
582582 aparam : torch .Tensor | None = None ,
583- charge_spin : torch .Tensor | None = None ,
584583 do_atomic_virial : bool = False ,
585584 comm_dict : dict [str , torch .Tensor ] | None = None ,
586585 extra_nlist_sort : bool = False ,
586+ charge_spin : torch .Tensor | None = None ,
587587 ) -> dict [str , torch .Tensor ]:
588588 nframes , nloc = nlist .shape [:2 ]
589589 (
@@ -699,8 +699,8 @@ def forward(
699699 box : torch .Tensor | None = None ,
700700 fparam : torch .Tensor | None = None ,
701701 aparam : torch .Tensor | None = None ,
702- charge_spin : torch .Tensor | None = None ,
703702 do_atomic_virial : bool = False ,
703+ charge_spin : torch .Tensor | None = None ,
704704 ) -> dict [str , torch .Tensor ]:
705705 model_ret = self .forward_common (
706706 coord ,
@@ -735,9 +735,9 @@ def forward_lower(
735735 mapping : torch .Tensor | None = None ,
736736 fparam : torch .Tensor | None = None ,
737737 aparam : torch .Tensor | None = None ,
738- charge_spin : torch .Tensor | None = None ,
739738 do_atomic_virial : bool = False ,
740739 comm_dict : dict [str , torch .Tensor ] | None = None ,
740+ charge_spin : torch .Tensor | None = None ,
741741 ) -> dict [str , torch .Tensor ]:
742742 model_ret = self .forward_common_lower (
743743 extended_coord ,
0 commit comments