@@ -122,6 +122,7 @@ def __init__(
122122 use_loc_mapping : bool = True ,
123123 type_map : list [str ] | None = None ,
124124 add_chg_spin_ebd : bool = False ,
125+ default_chg_spin : list [float ] | None = None ,
125126 ) -> None :
126127 super ().__init__ ()
127128
@@ -177,6 +178,11 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:
177178
178179 self .use_econf_tebd = use_econf_tebd
179180 self .add_chg_spin_ebd = add_chg_spin_ebd
181+ if default_chg_spin is not None and len (default_chg_spin ) != 2 :
182+ raise ValueError (
183+ "default_chg_spin must have exactly 2 values [charge, spin]"
184+ )
185+ self .default_chg_spin = default_chg_spin
180186 self .use_loc_mapping = use_loc_mapping
181187 self .use_tebd_bias = use_tebd_bias
182188 self .type_map = type_map
@@ -447,6 +453,18 @@ def get_stat_mean_and_stddev(
447453 stddev_list = [self .repflows .stddev ]
448454 return mean_list , stddev_list
449455
456+ def get_dim_chg_spin (self ) -> int :
457+ """Returns the dimension of charge_spin input."""
458+ return 2 if self .add_chg_spin_ebd else 0
459+
460+ def has_default_chg_spin (self ) -> bool :
461+ """Returns whether default charge_spin values are set."""
462+ return self .default_chg_spin is not None
463+
464+ def get_default_chg_spin (self ) -> list [float ] | None :
465+ """Returns the default charge_spin values."""
466+ return self .default_chg_spin
467+
450468 def serialize (self ) -> dict :
451469 repflows = self .repflows
452470 data = {
@@ -465,6 +483,7 @@ def serialize(self) -> dict:
465483 "use_tebd_bias" : self .use_tebd_bias ,
466484 "use_loc_mapping" : self .use_loc_mapping ,
467485 "add_chg_spin_ebd" : self .add_chg_spin_ebd ,
486+ "default_chg_spin" : self .default_chg_spin ,
468487 "type_map" : self .type_map ,
469488 "type_embedding" : self .type_embedding .embedding .serialize (),
470489 }
@@ -541,7 +560,7 @@ def forward(
541560 nlist : paddle .Tensor ,
542561 mapping : paddle .Tensor | None = None ,
543562 comm_dict : list [paddle .Tensor ] | None = None ,
544- fparam : paddle .Tensor | None = None ,
563+ charge_spin : paddle .Tensor | None = None ,
545564 ) -> tuple [
546565 paddle .Tensor ,
547566 paddle .Tensor | None ,
@@ -593,11 +612,11 @@ def forward(
593612 node_ebd_ext = self .type_embedding (extended_atype )
594613
595614 if self .add_chg_spin_ebd :
596- assert fparam is not None
615+ assert charge_spin is not None
597616 assert self .chg_embedding is not None
598617 assert self .spin_embedding is not None
599- charge = fparam [:, 0 ].to (dtype = paddle .int64 ) + 100
600- spin = fparam [:, 1 ].to (dtype = paddle .int64 )
618+ charge = charge_spin [:, 0 ].to (dtype = paddle .int64 ) + 100
619+ spin = charge_spin [:, 1 ].to (dtype = paddle .int64 )
601620 chg_ebd = self .chg_embedding (charge )
602621 spin_ebd = self .spin_embedding (spin )
603622 sys_cs_embd = self .act (
0 commit comments