File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -290,6 +290,7 @@ def forward_atomic(
290290 mapping ,
291291 fparam ,
292292 aparam ,
293+ comm_dict = comm_dict ,
293294 )["energy" ]
294295 )
295296 weights = self ._compute_weight (extended_coord , extended_atype , nlists_ )
Original file line number Diff line number Diff line change @@ -98,6 +98,7 @@ def forward_lower(
9898 fparam : Optional [torch .Tensor ] = None ,
9999 aparam : Optional [torch .Tensor ] = None ,
100100 do_atomic_virial : bool = False ,
101+ comm_dict : Optional [dict [str , torch .Tensor ]] = None ,
101102 ):
102103 model_ret = self .forward_common_lower (
103104 extended_coord ,
@@ -107,6 +108,7 @@ def forward_lower(
107108 fparam = fparam ,
108109 aparam = aparam ,
109110 do_atomic_virial = do_atomic_virial ,
111+ comm_dict = comm_dict ,
110112 extra_nlist_sort = self .need_sorted_nlist_for_lower (),
111113 )
112114 if self .get_fitting_net () is not None :
Original file line number Diff line number Diff line change @@ -88,6 +88,7 @@ def forward_lower(
8888 fparam : Optional [torch .Tensor ] = None ,
8989 aparam : Optional [torch .Tensor ] = None ,
9090 do_atomic_virial : bool = False ,
91+ comm_dict : Optional [dict [str , torch .Tensor ]] = None ,
9192 ):
9293 model_ret = self .forward_common_lower (
9394 extended_coord ,
@@ -97,6 +98,7 @@ def forward_lower(
9798 fparam = fparam ,
9899 aparam = aparam ,
99100 do_atomic_virial = do_atomic_virial ,
101+ comm_dict = comm_dict ,
100102 extra_nlist_sort = self .need_sorted_nlist_for_lower (),
101103 )
102104 if self .get_fitting_net () is not None :
Original file line number Diff line number Diff line change @@ -97,6 +97,7 @@ def forward_lower(
9797 fparam : Optional [torch .Tensor ] = None ,
9898 aparam : Optional [torch .Tensor ] = None ,
9999 do_atomic_virial : bool = False ,
100+ comm_dict : Optional [dict [str , torch .Tensor ]] = None ,
100101 ):
101102 model_ret = self .forward_common_lower (
102103 extended_coord ,
@@ -106,6 +107,7 @@ def forward_lower(
106107 fparam = fparam ,
107108 aparam = aparam ,
108109 do_atomic_virial = do_atomic_virial ,
110+ comm_dict = comm_dict ,
109111 extra_nlist_sort = self .need_sorted_nlist_for_lower (),
110112 )
111113
Original file line number Diff line number Diff line change @@ -97,6 +97,7 @@ def forward_lower(
9797 fparam : Optional [torch .Tensor ] = None ,
9898 aparam : Optional [torch .Tensor ] = None ,
9999 do_atomic_virial : bool = False ,
100+ comm_dict : Optional [dict [str , torch .Tensor ]] = None ,
100101 ):
101102 model_ret = self .forward_common_lower (
102103 extended_coord ,
@@ -106,6 +107,7 @@ def forward_lower(
106107 fparam = fparam ,
107108 aparam = aparam ,
108109 do_atomic_virial = do_atomic_virial ,
110+ comm_dict = comm_dict ,
109111 extra_nlist_sort = self .need_sorted_nlist_for_lower (),
110112 )
111113
Original file line number Diff line number Diff line change @@ -82,6 +82,7 @@ def forward_lower(
8282 fparam : Optional [torch .Tensor ] = None ,
8383 aparam : Optional [torch .Tensor ] = None ,
8484 do_atomic_virial : bool = False ,
85+ comm_dict : Optional [dict [str , torch .Tensor ]] = None ,
8586 ):
8687 model_ret = self .forward_common_lower (
8788 extended_coord ,
@@ -91,6 +92,7 @@ def forward_lower(
9192 fparam = fparam ,
9293 aparam = aparam ,
9394 do_atomic_virial = do_atomic_virial ,
95+ comm_dict = comm_dict ,
9496 extra_nlist_sort = self .need_sorted_nlist_for_lower (),
9597 )
9698 if self .get_fitting_net () is not None :
You can’t perform that action at this time.
0 commit comments