2020)
2121from deepmd .dpmodel .utils .network import (
2222 NativeLayer ,
23+ get_activation_fn ,
2324)
2425from deepmd .dpmodel .utils .seed import (
2526 child_seed ,
@@ -356,6 +357,7 @@ def __init__(
356357 use_tebd_bias : bool = False ,
357358 use_loc_mapping : bool = True ,
358359 type_map : list [str ] | None = None ,
360+ add_chg_spin_ebd : bool = False ,
359361 ) -> None :
360362 super ().__init__ ()
361363
@@ -410,6 +412,7 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:
410412 )
411413
412414 self .use_econf_tebd = use_econf_tebd
415+ self .add_chg_spin_ebd = add_chg_spin_ebd
413416 self .use_tebd_bias = use_tebd_bias
414417 self .use_loc_mapping = use_loc_mapping
415418 self .type_map = type_map
@@ -428,6 +431,38 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:
428431 )
429432 self .concat_output_tebd = concat_output_tebd
430433 self .precision = precision
434+
435+ if self .add_chg_spin_ebd :
436+ self .cs_activation_fn = get_activation_fn (activation_function )
437+ # -100 ~ 100 is a conservative bound
438+ self .chg_embedding = TypeEmbedNet (
439+ ntypes = 200 ,
440+ neuron = [self .tebd_dim ],
441+ padding = True ,
442+ activation_function = "Linear" ,
443+ precision = precision ,
444+ seed = child_seed (seed , 3 ),
445+ )
446+ # 100 is a conservative upper bound
447+ self .spin_embedding = TypeEmbedNet (
448+ ntypes = 100 ,
449+ neuron = [self .tebd_dim ],
450+ padding = True ,
451+ activation_function = "Linear" ,
452+ precision = precision ,
453+ seed = child_seed (seed , 4 ),
454+ )
455+ self .mix_cs_mlp = NativeLayer (
456+ 2 * self .tebd_dim ,
457+ self .tebd_dim ,
458+ precision = precision ,
459+ seed = child_seed (seed , 5 ),
460+ )
461+ else :
462+ self .chg_embedding = None
463+ self .spin_embedding = None
464+ self .mix_cs_mlp = None
465+
431466 self .exclude_types = exclude_types
432467 self .env_protection = env_protection
433468 self .trainable = trainable
@@ -579,6 +614,7 @@ def call(
579614 atype_ext : Array ,
580615 nlist : Array ,
581616 mapping : Array | None = None ,
617+ fparam : Array | None = None ,
582618 ) -> tuple [Array , Array , Array , Array , Array ]:
583619 """Compute the descriptor.
584620
@@ -625,6 +661,27 @@ def call(
625661 xp .take (type_embedding , xp .reshape (atype_ext , (- 1 ,)), axis = 0 ),
626662 (nframes , nall , self .tebd_dim ),
627663 )
664+
665+ if self .add_chg_spin_ebd :
666+ assert fparam is not None
667+ assert self .chg_embedding is not None
668+ assert self .spin_embedding is not None
669+ chg_tebd = self .chg_embedding .call ()
670+ spin_tebd = self .spin_embedding .call ()
671+ charge = xp .astype (fparam [:, 0 ], xp .int64 ) + 100
672+ spin = xp .astype (fparam [:, 1 ], xp .int64 )
673+ chg_ebd = xp .reshape (
674+ xp .take (chg_tebd , xp .reshape (charge , (- 1 ,)), axis = 0 ),
675+ (nframes , self .tebd_dim ),
676+ )
677+ spin_ebd = xp .reshape (
678+ xp .take (spin_tebd , xp .reshape (spin , (- 1 ,)), axis = 0 ),
679+ (nframes , self .tebd_dim ),
680+ )
681+ cs_cat = xp .concat ([chg_ebd , spin_ebd ], axis = - 1 )
682+ sys_cs_embd = self .cs_activation_fn (self .mix_cs_mlp .call (cs_cat ))
683+ node_ebd_ext = node_ebd_ext + xp .expand_dims (sys_cs_embd , axis = 1 )
684+
628685 node_ebd_inp = node_ebd_ext [:, :nloc , :]
629686 # repflows
630687 node_ebd , edge_ebd , h2 , rot_mat , sw = self .repflows (
@@ -655,9 +712,14 @@ def serialize(self) -> dict:
655712 "use_econf_tebd" : self .use_econf_tebd ,
656713 "use_tebd_bias" : self .use_tebd_bias ,
657714 "use_loc_mapping" : self .use_loc_mapping ,
715+ "add_chg_spin_ebd" : self .add_chg_spin_ebd ,
658716 "type_map" : self .type_map ,
659717 "type_embedding" : self .type_embedding .serialize (),
660718 }
719+ if self .add_chg_spin_ebd :
720+ data ["chg_embedding" ] = self .chg_embedding .serialize ()
721+ data ["spin_embedding" ] = self .spin_embedding .serialize ()
722+ data ["mix_cs_mlp" ] = self .mix_cs_mlp .serialize ()
661723 repflow_variable = {
662724 "edge_embd" : repflows .edge_embd .serialize (),
663725 "angle_embd" : repflows .angle_embd .serialize (),
@@ -684,10 +746,18 @@ def deserialize(cls, data: dict) -> "DescrptDPA3":
684746 data .pop ("type" )
685747 repflow_variable = data .pop ("repflow_variable" ).copy ()
686748 type_embedding = data .pop ("type_embedding" )
749+ chg_embedding = data .pop ("chg_embedding" , None )
750+ spin_embedding = data .pop ("spin_embedding" , None )
751+ mix_cs_mlp = data .pop ("mix_cs_mlp" , None )
687752 data ["repflow" ] = RepFlowArgs (** data .pop ("repflow_args" ))
688753 obj = cls (** data )
689754 obj .type_embedding = TypeEmbedNet .deserialize (type_embedding )
690755
756+ if obj .add_chg_spin_ebd and chg_embedding is not None :
757+ obj .chg_embedding = TypeEmbedNet .deserialize (chg_embedding )
758+ obj .spin_embedding = TypeEmbedNet .deserialize (spin_embedding )
759+ obj .mix_cs_mlp = NativeLayer .deserialize (mix_cs_mlp )
760+
691761 # deserialize repflow
692762 statistic_repflows = repflow_variable .pop ("@variables" )
693763 env_mat = repflow_variable .pop ("env_mat" )
0 commit comments