@@ -354,34 +354,21 @@ def _make_sample_inputs(
354354 )
355355 charge_spin = None
356356 if dim_chg_spin > 0 :
357- default_chg_spin = model .get_default_chg_spin ()
358- if default_chg_spin is None :
359- raise ValueError (
360- "SeZM .pt2 freeze requires default_chg_spin when charge/spin "
361- "conditioning is enabled; runtime charge_spin input is not exposed."
362- )
363- charge_spin = (
364- default_chg_spin .to (device = device , dtype = torch .float64 )
365- .view (1 , dim_chg_spin )
366- .expand (nframes , - 1 )
367- .contiguous ()
357+ charge_spin = torch .zeros (
358+ nframes , dim_chg_spin , dtype = torch .float64 , device = device
368359 )
369360 if has_spin :
370- if charge_spin is not None :
371- return (
372- ext_coord ,
373- ext_atype ,
374- ext_spin ,
375- nlist_t ,
376- mapping_t ,
377- fparam ,
378- aparam ,
379- charge_spin ,
380- )
381- return ext_coord , ext_atype , ext_spin , nlist_t , mapping_t , fparam , aparam
382- if charge_spin is not None :
383- return ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam , charge_spin
384- return ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam
361+ return (
362+ ext_coord ,
363+ ext_atype ,
364+ ext_spin ,
365+ nlist_t ,
366+ mapping_t ,
367+ fparam ,
368+ aparam ,
369+ charge_spin ,
370+ )
371+ return ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam , charge_spin
385372
386373
387374def _resolve_nframes (
@@ -446,6 +433,9 @@ def _build_dynamic_shapes(
446433 nloc_dim = torch .export .Dim ("nloc" , min = 1 )
447434 fparam = sample_inputs [5 ] if has_spin else sample_inputs [4 ]
448435 aparam = sample_inputs [6 ] if has_spin else sample_inputs [5 ]
436+ charge_spin = None
437+ if has_charge_spin :
438+ charge_spin = sample_inputs [7 ] if has_spin else sample_inputs [6 ]
449439 if has_spin :
450440 shapes = (
451441 {0 : nframes_dim , 1 : nall_dim }, # extended_coord
@@ -457,7 +447,7 @@ def _build_dynamic_shapes(
457447 {0 : nframes_dim , 1 : nloc_dim } if aparam is not None else None ,
458448 )
459449 if has_charge_spin :
460- shapes = (* shapes , {0 : nframes_dim })
450+ shapes = (* shapes , {0 : nframes_dim } if charge_spin is not None else None )
461451 return shapes
462452 shapes = (
463453 {0 : nframes_dim , 1 : nall_dim }, # extended_coord: (nframes, nall, 3)
@@ -468,7 +458,7 @@ def _build_dynamic_shapes(
468458 {0 : nframes_dim , 1 : nloc_dim } if aparam is not None else None ,
469459 )
470460 if has_charge_spin :
471- shapes = (* shapes , {0 : nframes_dim })
461+ shapes = (* shapes , {0 : nframes_dim } if charge_spin is not None else None )
472462 return shapes
473463
474464
@@ -527,10 +517,48 @@ def freeze_sezm_to_pt2(
527517 # do_atomic_virial=True pulls every key that DeepPotPTExpt may read
528518 # (energy, energy_redu, energy_derv_r, energy_derv_c, energy_derv_c_redu)
529519 # into the traced graph.
530- traced = model .forward_common_lower_exportable (
531- * sample_inputs_cpu ,
532- do_atomic_virial = True ,
533- )
520+ if is_spin :
521+ (
522+ ext_coord ,
523+ ext_atype ,
524+ ext_spin ,
525+ nlist_t ,
526+ mapping_t ,
527+ fparam ,
528+ aparam ,
529+ charge_spin ,
530+ ) = sample_inputs_cpu
531+ traced = model .forward_common_lower_exportable (
532+ ext_coord ,
533+ ext_atype ,
534+ ext_spin ,
535+ nlist_t ,
536+ mapping_t ,
537+ fparam = fparam ,
538+ aparam = aparam ,
539+ charge_spin = charge_spin ,
540+ do_atomic_virial = True ,
541+ )
542+ else :
543+ (
544+ ext_coord ,
545+ ext_atype ,
546+ nlist_t ,
547+ mapping_t ,
548+ fparam ,
549+ aparam ,
550+ charge_spin ,
551+ ) = sample_inputs_cpu
552+ traced = model .forward_common_lower_exportable (
553+ ext_coord ,
554+ ext_atype ,
555+ nlist_t ,
556+ mapping_t ,
557+ fparam = fparam ,
558+ aparam = aparam ,
559+ charge_spin = charge_spin ,
560+ do_atomic_virial = True ,
561+ )
534562
535563 # Output key order is taken from a concrete run; Python dict order
536564 # is stable and matches what DeepPotPTExpt::extract_outputs zips
0 commit comments