@@ -155,7 +155,7 @@ def get_model_sels(self) -> list[int | list[int]]:
155155 def _sort_rcuts_sels (self ) -> tuple [tuple [Array , Array ], list [int ]]:
156156 # sort the pair of rcut and sels in ascending order, first based on sel, then on rcut.
157157 zipped = sorted (
158- zip (self .get_model_rcuts (), self .get_model_nsels ()),
158+ zip (self .get_model_rcuts (), self .get_model_nsels (), strict = False ),
159159 key = lambda x : (x [1 ], x [0 ]),
160160 )
161161 return [p [0 ] for p in zipped ], [p [1 ] for p in zipped ]
@@ -235,12 +235,12 @@ def forward_atomic(
235235 )
236236 raw_nlists = [
237237 nlists [get_multiple_nlist_key (rcut , sel )]
238- for rcut , sel in zip (self .get_model_rcuts (), self .get_model_nsels ())
238+ for rcut , sel in zip (self .get_model_rcuts (), self .get_model_nsels (), strict = False )
239239 ]
240240 nlists_ = [
241241 nl if mt else nlist_distinguish_types (nl , extended_atype , sel )
242242 for mt , nl , sel in zip (
243- self .mixed_types_list , raw_nlists , self .get_model_sels ()
243+ self .mixed_types_list , raw_nlists , self .get_model_sels (), strict = False
244244 )
245245 ]
246246 ener_list = []
0 commit comments