3232 UpdateSel ,
3333)
3434from deepmd .pd .utils .utils import (
35+ ActivationFn ,
3536 to_numpy_array ,
3637)
3738from deepmd .utils .data_system import (
@@ -120,6 +121,7 @@ def __init__(
120121 use_tebd_bias : bool = False ,
121122 use_loc_mapping : bool = True ,
122123 type_map : list [str ] | None = None ,
124+ add_chg_spin_ebd : bool = False ,
123125 ) -> None :
124126 super ().__init__ ()
125127
@@ -174,6 +176,7 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:
174176 )
175177
176178 self .use_econf_tebd = use_econf_tebd
179+ self .add_chg_spin_ebd = add_chg_spin_ebd
177180 self .use_loc_mapping = use_loc_mapping
178181 self .use_tebd_bias = use_tebd_bias
179182 self .type_map = type_map
@@ -196,6 +199,34 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:
196199 self .concat_output_tebd = concat_output_tebd
197200 self .precision = precision
198201 self .prec = PRECISION_DICT [self .precision ]
202+
203+ if self .add_chg_spin_ebd :
204+ self .act = ActivationFn (activation_function )
205+ # -100 ~ 100 is a conservative bound
206+ self .chg_embedding = TypeEmbedNet (
207+ 200 ,
208+ self .tebd_dim ,
209+ precision = precision ,
210+ seed = child_seed (seed , 3 ),
211+ )
212+ # 100 is a conservative upper bound
213+ self .spin_embedding = TypeEmbedNet (
214+ 100 ,
215+ self .tebd_dim ,
216+ precision = precision ,
217+ seed = child_seed (seed , 4 ),
218+ )
219+ self .mix_cs_mlp = MLPLayer (
220+ 2 * self .tebd_dim ,
221+ self .tebd_dim ,
222+ precision = precision ,
223+ seed = child_seed (seed , 5 ),
224+ )
225+ else :
226+ self .chg_embedding = None
227+ self .spin_embedding = None
228+ self .mix_cs_mlp = None
229+
199230 self .exclude_types = exclude_types
200231 self .env_protection = env_protection
201232 self .trainable = trainable
@@ -433,9 +464,14 @@ def serialize(self) -> dict:
433464 "use_econf_tebd" : self .use_econf_tebd ,
434465 "use_tebd_bias" : self .use_tebd_bias ,
435466 "use_loc_mapping" : self .use_loc_mapping ,
467+ "add_chg_spin_ebd" : self .add_chg_spin_ebd ,
436468 "type_map" : self .type_map ,
437469 "type_embedding" : self .type_embedding .embedding .serialize (),
438470 }
471+ if self .add_chg_spin_ebd :
472+ data ["chg_embedding" ] = self .chg_embedding .embedding .serialize ()
473+ data ["spin_embedding" ] = self .spin_embedding .embedding .serialize ()
474+ data ["mix_cs_mlp" ] = self .mix_cs_mlp .serialize ()
439475 repflow_variable = {
440476 "edge_embd" : repflows .edge_embd .serialize (),
441477 "angle_embd" : repflows .angle_embd .serialize (),
@@ -462,12 +498,24 @@ def deserialize(cls, data: dict) -> "DescrptDPA3":
462498 data .pop ("type" )
463499 repflow_variable = data .pop ("repflow_variable" ).copy ()
464500 type_embedding = data .pop ("type_embedding" )
501+ chg_embedding = data .pop ("chg_embedding" , None )
502+ spin_embedding = data .pop ("spin_embedding" , None )
503+ mix_cs_mlp = data .pop ("mix_cs_mlp" , None )
465504 data ["repflow" ] = RepFlowArgs (** data .pop ("repflow_args" ))
466505 obj = cls (** data )
467506 obj .type_embedding .embedding = TypeEmbedNetConsistent .deserialize (
468507 type_embedding
469508 )
470509
510+ if obj .add_chg_spin_ebd and chg_embedding is not None :
511+ obj .chg_embedding .embedding = TypeEmbedNetConsistent .deserialize (
512+ chg_embedding
513+ )
514+ obj .spin_embedding .embedding = TypeEmbedNetConsistent .deserialize (
515+ spin_embedding
516+ )
517+ obj .mix_cs_mlp = MLPLayer .deserialize (mix_cs_mlp )
518+
471519 def t_cvt (xx : Any ) -> paddle .Tensor :
472520 return paddle .to_tensor (xx , dtype = obj .repflows .prec , place = env .DEVICE )
473521
@@ -493,7 +541,14 @@ def forward(
493541 nlist : paddle .Tensor ,
494542 mapping : paddle .Tensor | None = None ,
495543 comm_dict : list [paddle .Tensor ] | None = None ,
496- ) -> paddle .Tensor :
544+ fparam : paddle .Tensor | None = None ,
545+ ) -> tuple [
546+ paddle .Tensor ,
547+ paddle .Tensor | None ,
548+ paddle .Tensor | None ,
549+ paddle .Tensor | None ,
550+ paddle .Tensor | None ,
551+ ]:
497552 """Compute the descriptor.
498553
499554 Parameters
@@ -536,6 +591,20 @@ def forward(
536591 node_ebd_ext = self .type_embedding (extended_atype [:, :nloc ])
537592 else :
538593 node_ebd_ext = self .type_embedding (extended_atype )
594+
595+ if self .add_chg_spin_ebd :
596+ assert fparam is not None
597+ assert self .chg_embedding is not None
598+ assert self .spin_embedding is not None
599+ charge = fparam [:, 0 ].to (dtype = paddle .int64 ) + 100
600+ spin = fparam [:, 1 ].to (dtype = paddle .int64 )
601+ chg_ebd = self .chg_embedding (charge )
602+ spin_ebd = self .spin_embedding (spin )
603+ sys_cs_embd = self .act (
604+ self .mix_cs_mlp (paddle .concat ([chg_ebd , spin_ebd ], axis = - 1 ))
605+ )
606+ node_ebd_ext = node_ebd_ext + sys_cs_embd .unsqueeze (1 )
607+
539608 node_ebd_inp = node_ebd_ext [:, :nloc , :]
540609 # repflows
541610 node_ebd , edge_ebd , h2 , rot_mat , sw = self .repflows (
@@ -550,10 +619,14 @@ def forward(
550619 node_ebd = paddle .concat ([node_ebd , node_ebd_inp ], axis = - 1 )
551620 return (
552621 node_ebd .to (dtype = env .GLOBAL_PD_FLOAT_PRECISION ),
553- rot_mat .to (dtype = env .GLOBAL_PD_FLOAT_PRECISION ),
554- edge_ebd .to (dtype = env .GLOBAL_PD_FLOAT_PRECISION ),
555- h2 .to (dtype = env .GLOBAL_PD_FLOAT_PRECISION ),
556- sw .to (dtype = env .GLOBAL_PD_FLOAT_PRECISION ),
622+ rot_mat .to (dtype = env .GLOBAL_PD_FLOAT_PRECISION )
623+ if rot_mat is not None
624+ else None ,
625+ edge_ebd .to (dtype = env .GLOBAL_PD_FLOAT_PRECISION )
626+ if edge_ebd is not None
627+ else None ,
628+ h2 .to (dtype = env .GLOBAL_PD_FLOAT_PRECISION ) if h2 is not None else None ,
629+ sw .to (dtype = env .GLOBAL_PD_FLOAT_PRECISION ) if sw is not None else None ,
557630 )
558631
559632 @classmethod
0 commit comments