@@ -121,7 +121,9 @@ def test_export_pipeline(self, descriptor_type, with_fparam) -> None:
121121 inputs_trace = _make_sample_inputs (model2 , nframes = 5 , nloc = 7 )
122122 finally :
123123 _env .DEVICE = orig_device
124- ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam = inputs_trace
124+ ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam , charge_spin = (
125+ inputs_trace
126+ )
125127
126128 # 4. Eager reference
127129 eager_out = model2 .forward_common_lower (
@@ -132,6 +134,7 @@ def test_export_pipeline(self, descriptor_type, with_fparam) -> None:
132134 fparam = fparam ,
133135 aparam = aparam ,
134136 do_atomic_virial = True ,
137+ charge_spin = charge_spin ,
135138 )
136139
137140 # 5. Trace with symbolic mode (same as dp freeze)
@@ -142,6 +145,7 @@ def test_export_pipeline(self, descriptor_type, with_fparam) -> None:
142145 mapping_t ,
143146 fparam = fparam ,
144147 aparam = aparam ,
148+ charge_spin = charge_spin ,
145149 do_atomic_virial = True ,
146150 tracing_mode = "symbolic" ,
147151 _allow_non_fake_inputs = True ,
@@ -155,11 +159,12 @@ def test_export_pipeline(self, descriptor_type, with_fparam) -> None:
155159 mapping_t ,
156160 fparam ,
157161 aparam ,
162+ charge_spin ,
158163 model_nnei = sum (model2 .get_sel ()),
159164 )
160165 exported = torch .export .export (
161166 traced ,
162- (ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam ),
167+ (ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam , charge_spin ),
163168 dynamic_shapes = dynamic_shapes ,
164169 strict = False ,
165170 prefer_deferred_runtime_asserts_over_guards = True ,
@@ -171,7 +176,9 @@ def test_export_pipeline(self, descriptor_type, with_fparam) -> None:
171176 loaded = torch .export .load (tmp .name ).module ()
172177
173178 # 8. Verify: traced output matches eager (same shapes as trace)
174- traced_out = traced (ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam )
179+ traced_out = traced (
180+ ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam , charge_spin
181+ )
175182 for key in eager_out :
176183 np .testing .assert_allclose (
177184 eager_out [key ].detach ().cpu ().numpy (),
@@ -182,7 +189,9 @@ def test_export_pipeline(self, descriptor_type, with_fparam) -> None:
182189 )
183190
184191 # 9. Verify: loaded (.pte) output matches eager (same shapes)
185- loaded_out = loaded (ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam )
192+ loaded_out = loaded (
193+ ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam , charge_spin
194+ )
186195 for key in eager_out :
187196 np .testing .assert_allclose (
188197 eager_out [key ].detach ().cpu ().numpy (),
@@ -206,6 +215,7 @@ def test_export_pipeline(self, descriptor_type, with_fparam) -> None:
206215 mapping_t2 ,
207216 fparam2 ,
208217 aparam2 ,
218+ charge_spin2 ,
209219 ) = inputs_infer
210220
211221 eager_out2 = model2 .forward_common_lower (
@@ -216,9 +226,16 @@ def test_export_pipeline(self, descriptor_type, with_fparam) -> None:
216226 fparam = fparam2 ,
217227 aparam = aparam2 ,
218228 do_atomic_virial = True ,
229+ charge_spin = charge_spin2 ,
219230 )
220231 loaded_out2 = loaded (
221- ext_coord2 , ext_atype2 , nlist_t2 , mapping_t2 , fparam2 , aparam2
232+ ext_coord2 ,
233+ ext_atype2 ,
234+ nlist_t2 ,
235+ mapping_t2 ,
236+ fparam2 ,
237+ aparam2 ,
238+ charge_spin2 ,
222239 )
223240 for key in eager_out2 :
224241 np .testing .assert_allclose (
@@ -248,9 +265,16 @@ def test_export_pipeline(self, descriptor_type, with_fparam) -> None:
248265 fparam = fparam_ones ,
249266 aparam = aparam ,
250267 do_atomic_virial = True ,
268+ charge_spin = charge_spin ,
251269 )
252270 loaded_out_fp1 = loaded (
253- ext_coord , ext_atype , nlist_t , mapping_t , fparam_ones , aparam
271+ ext_coord ,
272+ ext_atype ,
273+ nlist_t ,
274+ mapping_t ,
275+ fparam_ones ,
276+ aparam ,
277+ charge_spin ,
254278 )
255279 # Loaded with fparam=1 should match eager with fparam=1
256280 for key in eager_out_fp1 :
0 commit comments