@@ -54,11 +54,14 @@ def process_spin_input(self, coord, atype, spin):
5454 coord = coord .reshape (nframes , nloc , 3 )
5555 spin = spin .reshape (nframes , nloc , 3 )
5656 atype_spin = torch .concat ([atype , atype + self .ntypes_real ], dim = - 1 )
57- virtual_coord = coord + spin * (self .virtual_scale_mask .to (atype .device ))[
58- atype
59- ].reshape ([nframes , nloc , 1 ])
57+ spin_dist = spin * (self .virtual_scale_mask .to (atype .device ))[atype ].reshape (
58+ [nframes , nloc , 1 ]
59+ )
60+ virtual_coord = coord + spin_dist
6061 coord_spin = torch .concat ([coord , virtual_coord ], dim = - 2 )
61- return coord_spin , atype_spin
62+ # for spin virial corr
63+ coord_corr = torch .concat ([torch .zeros_like (coord ), - spin_dist ], dim = - 2 )
64+ return coord_spin , atype_spin , coord_corr
6265
6366 def process_spin_input_lower (
6467 self ,
@@ -78,13 +81,18 @@ def process_spin_input_lower(
7881 """
7982 nframes , nall = extended_coord .shape [:2 ]
8083 nloc = nlist .shape [1 ]
81- virtual_extended_coord = extended_coord + extended_spin * (
84+ extended_spin_dist = extended_spin * (
8285 self .virtual_scale_mask .to (extended_atype .device )
8386 )[extended_atype ].reshape ([nframes , nall , 1 ])
87+ virtual_extended_coord = extended_coord + extended_spin_dist
8488 virtual_extended_atype = extended_atype + self .ntypes_real
8589 extended_coord_updated = concat_switch_virtual (
8690 extended_coord , virtual_extended_coord , nloc
8791 )
92+ # for spin virial corr
93+ extended_coord_corr = concat_switch_virtual (
94+ torch .zeros_like (extended_coord ), - extended_spin_dist , nloc
95+ )
8896 extended_atype_updated = concat_switch_virtual (
8997 extended_atype , virtual_extended_atype , nloc
9098 )
@@ -100,6 +108,7 @@ def process_spin_input_lower(
100108 extended_atype_updated ,
101109 nlist_updated ,
102110 mapping_updated ,
111+ extended_coord_corr ,
103112 )
104113
105114 def process_spin_output (
@@ -367,7 +376,7 @@ def spin_sampled_func():
367376 sampled = sampled_func ()
368377 spin_sampled = []
369378 for sys in sampled :
370- coord_updated , atype_updated = self .process_spin_input (
379+ coord_updated , atype_updated , _ = self .process_spin_input (
371380 sys ["coord" ], sys ["atype" ], sys ["spin" ]
372381 )
373382 tmp_dict = {
@@ -398,7 +407,9 @@ def forward_common(
398407 do_atomic_virial : bool = False ,
399408 ) -> dict [str , torch .Tensor ]:
400409 nframes , nloc = atype .shape
401- coord_updated , atype_updated = self .process_spin_input (coord , atype , spin )
410+ coord_updated , atype_updated , coord_corr_for_virial = self .process_spin_input (
411+ coord , atype , spin
412+ )
402413 if aparam is not None :
403414 aparam = self .expand_aparam (aparam , nloc * 2 )
404415 model_ret = self .backbone_model .forward_common (
@@ -408,6 +419,7 @@ def forward_common(
408419 fparam = fparam ,
409420 aparam = aparam ,
410421 do_atomic_virial = do_atomic_virial ,
422+ coord_corr_for_virial = coord_corr_for_virial ,
411423 )
412424 model_output_type = self .backbone_model .model_output_type ()
413425 if "mask" in model_output_type :
@@ -454,6 +466,7 @@ def forward_common_lower(
454466 extended_atype_updated ,
455467 nlist_updated ,
456468 mapping_updated ,
469+ extended_coord_corr_for_virial ,
457470 ) = self .process_spin_input_lower (
458471 extended_coord , extended_atype , extended_spin , nlist , mapping = mapping
459472 )
@@ -469,6 +482,7 @@ def forward_common_lower(
469482 do_atomic_virial = do_atomic_virial ,
470483 comm_dict = comm_dict ,
471484 extra_nlist_sort = extra_nlist_sort ,
485+ extended_coord_corr = extended_coord_corr_for_virial ,
472486 )
473487 model_output_type = self .backbone_model .model_output_type ()
474488 if "mask" in model_output_type :
@@ -541,6 +555,11 @@ def translated_output_def(self):
541555 output_def ["force" ].squeeze (- 2 )
542556 output_def ["force_mag" ] = deepcopy (out_def_data ["energy_derv_r_mag" ])
543557 output_def ["force_mag" ].squeeze (- 2 )
558+ if self .do_grad_c ("energy" ):
559+ output_def ["virial" ] = deepcopy (out_def_data ["energy_derv_c_redu" ])
560+ output_def ["virial" ].squeeze (- 2 )
561+ output_def ["atom_virial" ] = deepcopy (out_def_data ["energy_derv_c" ])
562+ output_def ["atom_virial" ].squeeze (- 3 )
544563 return output_def
545564
546565 def forward (
@@ -569,7 +588,10 @@ def forward(
569588 if self .backbone_model .do_grad_r ("energy" ):
570589 model_predict ["force" ] = model_ret ["energy_derv_r" ].squeeze (- 2 )
571590 model_predict ["force_mag" ] = model_ret ["energy_derv_r_mag" ].squeeze (- 2 )
572- # not support virial by far
591+ if self .backbone_model .do_grad_c ("energy" ):
592+ model_predict ["virial" ] = model_ret ["energy_derv_c_redu" ].squeeze (- 2 )
593+ if do_atomic_virial :
594+ model_predict ["atom_virial" ] = model_ret ["energy_derv_c" ].squeeze (- 3 )
573595 return model_predict
574596
575597 @torch .jit .export
@@ -606,5 +628,10 @@ def forward_lower(
606628 model_predict ["extended_force_mag" ] = model_ret [
607629 "energy_derv_r_mag"
608630 ].squeeze (- 2 )
609- # not support virial by far
631+ if self .backbone_model .do_grad_c ("energy" ):
632+ model_predict ["virial" ] = model_ret ["energy_derv_c_redu" ].squeeze (- 2 )
633+ if do_atomic_virial :
634+ model_predict ["extended_virial" ] = model_ret ["energy_derv_c" ].squeeze (
635+ - 3
636+ )
610637 return model_predict
0 commit comments