@@ -119,8 +119,9 @@ def _make_sample_inputs(
119119 Returns
120120 -------
121121 tuple
122- (ext_coord, ext_atype, nlist, mapping, fparam, aparam) or
123- (ext_coord, ext_atype, ext_spin, nlist, mapping, fparam, aparam) when has_spin.
122+ (ext_coord, ext_atype, nlist, mapping, fparam, aparam, charge_spin) or
123+ (ext_coord, ext_atype, ext_spin, nlist, mapping, fparam, aparam,
124+ charge_spin) when has_spin.
124125 """
125126 rcut = model .get_rcut ()
126127 sel = model .get_sel ()
@@ -187,14 +188,31 @@ def _make_sample_inputs(
187188 else :
188189 aparam = None
189190
191+ dim_chg_spin = model .get_dim_chg_spin () if hasattr (model , "get_dim_chg_spin" ) else 0
192+ if dim_chg_spin > 0 :
193+ charge_spin = torch .zeros (
194+ nframes , dim_chg_spin , dtype = torch .float64 , device = _env .DEVICE
195+ )
196+ else :
197+ charge_spin = None
198+
190199 if has_spin :
191200 nall = extended_coord .shape [1 ]
192201 ext_spin = torch .zeros (
193202 nframes , nall , 3 , dtype = torch .float64 , device = _env .DEVICE
194203 )
195- return ext_coord , ext_atype , ext_spin , nlist_t , mapping_t , fparam , aparam
204+ return (
205+ ext_coord ,
206+ ext_atype ,
207+ ext_spin ,
208+ nlist_t ,
209+ mapping_t ,
210+ fparam ,
211+ aparam ,
212+ charge_spin ,
213+ )
196214
197- return ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam
215+ return ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam , charge_spin
198216
199217
200218def _build_dynamic_shapes (
@@ -224,9 +242,10 @@ def _build_dynamic_shapes(
224242 nnei_dim = torch .export .Dim ("nnei" , min = max (1 , model_nnei ))
225243
226244 if has_spin :
227- # (ext_coord, ext_atype, ext_spin, nlist, mapping, fparam, aparam)
245+ # (ext_coord, ext_atype, ext_spin, nlist, mapping, fparam, aparam, charge_spin )
228246 fparam = sample_inputs [5 ]
229247 aparam = sample_inputs [6 ]
248+ charge_spin = sample_inputs [7 ]
230249 return (
231250 {0 : nframes_dim , 1 : nall_dim }, # extended_coord: (nframes, nall, 3)
232251 {0 : nframes_dim , 1 : nall_dim }, # extended_atype: (nframes, nall)
@@ -239,11 +258,13 @@ def _build_dynamic_shapes(
239258 {0 : nframes_dim , 1 : nall_dim }, # mapping: (nframes, nall)
240259 {0 : nframes_dim } if fparam is not None else None , # fparam
241260 {0 : nframes_dim , 1 : nloc_dim } if aparam is not None else None , # aparam
261+ {0 : nframes_dim } if charge_spin is not None else None , # charge_spin
242262 )
243263 else :
244- # (ext_coord, ext_atype, nlist, mapping, fparam, aparam)
264+ # (ext_coord, ext_atype, nlist, mapping, fparam, aparam, charge_spin )
245265 fparam = sample_inputs [4 ]
246266 aparam = sample_inputs [5 ]
267+ charge_spin = sample_inputs [6 ]
247268 return (
248269 {0 : nframes_dim , 1 : nall_dim }, # extended_coord: (nframes, nall, 3)
249270 {0 : nframes_dim , 1 : nall_dim }, # extended_atype: (nframes, nall)
@@ -255,6 +276,7 @@ def _build_dynamic_shapes(
255276 {0 : nframes_dim , 1 : nall_dim }, # mapping: (nframes, nall)
256277 {0 : nframes_dim } if fparam is not None else None , # fparam
257278 {0 : nframes_dim , 1 : nloc_dim } if aparam is not None else None , # aparam
279+ {0 : nframes_dim } if charge_spin is not None else None , # charge_spin
258280 )
259281
260282
@@ -487,11 +509,26 @@ def _trace_and_export(
487509 _env .DEVICE = _orig_device
488510
489511 if is_spin :
490- ext_coord , ext_atype , ext_spin , nlist_t , mapping_t , fparam , aparam = (
491- sample_inputs
492- )
512+ (
513+ ext_coord ,
514+ ext_atype ,
515+ ext_spin ,
516+ nlist_t ,
517+ mapping_t ,
518+ fparam ,
519+ aparam ,
520+ charge_spin ,
521+ ) = sample_inputs
493522 else :
494- ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam = sample_inputs
523+ (
524+ ext_coord ,
525+ ext_atype ,
526+ nlist_t ,
527+ mapping_t ,
528+ fparam ,
529+ aparam ,
530+ charge_spin ,
531+ ) = sample_inputs
495532
496533 # 4. Trace via make_fx on CPU.
497534 # This decomposes torch.autograd.grad into aten ops so the resulting
@@ -505,13 +542,21 @@ def _trace_and_export(
505542 mapping_t ,
506543 fparam = fparam ,
507544 aparam = aparam ,
545+ charge_spin = charge_spin ,
508546 do_atomic_virial = do_atomic_virial ,
509547 tracing_mode = "symbolic" ,
510548 _allow_non_fake_inputs = True ,
511549 )
512550 # 5. Extract output keys from the CPU-traced module.
513551 sample_out = traced (
514- ext_coord , ext_atype , ext_spin , nlist_t , mapping_t , fparam , aparam
552+ ext_coord ,
553+ ext_atype ,
554+ ext_spin ,
555+ nlist_t ,
556+ mapping_t ,
557+ fparam ,
558+ aparam ,
559+ charge_spin ,
515560 )
516561 else :
517562 traced = model .forward_common_lower_exportable (
@@ -521,12 +566,15 @@ def _trace_and_export(
521566 mapping_t ,
522567 fparam = fparam ,
523568 aparam = aparam ,
569+ charge_spin = charge_spin ,
524570 do_atomic_virial = do_atomic_virial ,
525571 tracing_mode = "symbolic" ,
526572 _allow_non_fake_inputs = True ,
527573 )
528574 # 5. Extract output keys from the CPU-traced module.
529- sample_out = traced (ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam )
575+ sample_out = traced (
576+ ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam , charge_spin
577+ )
530578
531579 output_keys = list (sample_out .keys ())
532580
0 commit comments