@@ -101,6 +101,26 @@ def model_output_type(self) -> list[str]:
101101 vars .append (kk )
102102 return vars
103103
104+ @torch .jit .export
105+ def has_chg_spin_ebd (self ) -> bool :
106+ """Check if the model has charge spin embedding."""
107+ return self .atomic_model .has_chg_spin_ebd ()
108+
109+ @torch .jit .export
110+ def get_dim_chg_spin (self ) -> int :
111+ """Get the dimension of charge_spin input."""
112+ return self .atomic_model .get_dim_chg_spin ()
113+
114+ @torch .jit .export
115+ def has_default_chg_spin (self ) -> bool :
116+ """Check if the model has default charge_spin values."""
117+ return self .atomic_model .has_default_chg_spin ()
118+
119+ @torch .jit .export
120+ def get_default_chg_spin (self ) -> torch .Tensor | None :
121+ """Get the default charge_spin values."""
122+ return self .atomic_model .get_default_chg_spin ()
123+
104124 # cannot use the name forward. torch script does not work
105125 def forward_common (
106126 self ,
@@ -111,6 +131,7 @@ def forward_common(
111131 fparam : Optional [torch .Tensor ] = None ,
112132 aparam : Optional [torch .Tensor ] = None ,
113133 do_atomic_virial : bool = False ,
134+ charge_spin : Optional [torch .Tensor ] = None ,
114135 ) -> dict [str , torch .Tensor ]:
115136 """Return model prediction.
116137
@@ -145,6 +166,7 @@ def forward_common(
145166 coord , grid , box = box , fparam = fparam , aparam = aparam
146167 )
147168 del coord , grid , box , fparam , aparam
169+ gg = gg .view (gg .shape [0 ], - 1 , 3 )
148170 (
149171 extended_coord ,
150172 extended_atype ,
@@ -158,7 +180,7 @@ def forward_common(
158180 mixed_types = self .mixed_types (),
159181 box = bb ,
160182 )
161- grid_type = torch .zeros (gg .shape [: - 1 ], device = gg .device , dtype = atype .dtype )
183+ grid_type = torch .zeros (gg .shape [0 ], gg . shape [ 1 ], device = gg .device , dtype = atype .dtype )
162184 grid_nlist = build_directional_neighbor_list (
163185 gg ,
164186 grid_type ,
@@ -233,6 +255,7 @@ def forward_common_lower(
233255 mapping : Optional [torch .Tensor ] = None ,
234256 fparam : Optional [torch .Tensor ] = None ,
235257 aparam : Optional [torch .Tensor ] = None ,
258+ charge_spin : Optional [torch .Tensor ] = None ,
236259 do_atomic_virial : bool = False ,
237260 comm_dict : Optional [dict [str , torch .Tensor ]] = None ,
238261 extra_nlist_sort : bool = False ,
@@ -572,9 +595,14 @@ def compute_or_load_stat(
572595 self ,
573596 sampled_func ,
574597 stat_file_path : Optional [DPPath ] = None ,
598+ preset_observed_type : list [str ] | None = None ,
575599 ):
576600 """Compute or load the statistics."""
577- return self .atomic_model .compute_or_load_stat (sampled_func , stat_file_path )
601+ return self .atomic_model .compute_or_load_stat (
602+ sampled_func ,
603+ stat_file_path ,
604+ preset_observed_type = preset_observed_type ,
605+ )
578606
579607 def get_sel (self ) -> list [int ]:
580608 """Returns the number of selected atoms for each type."""
0 commit comments