@@ -89,6 +89,9 @@ class DescrptDPA3(BaseDescriptor, paddle.nn.Layer):
8989 Whether to use electronic configuration type embedding.
9090 use_tebd_bias : bool, Optional
9191 Whether to use bias in the type embedding layer.
92+ use_loc_mapping : bool, Optional
93+ Whether to use local atom index mapping in training or non-parallel inference.
94+ Not supported yet in Paddle.
9295 type_map : list[str], Optional
9396 A list of strings. Give the name to each type of atoms.
9497 """
@@ -108,6 +111,7 @@ def __init__(
108111 seed : Optional [Union [int , list [int ]]] = None ,
109112 use_econf_tebd : bool = False ,
110113 use_tebd_bias : bool = False ,
114+ use_loc_mapping : bool = False ,
111115 type_map : Optional [list [str ]] = None ,
112116 ) -> None :
113117 super ().__init__ ()
@@ -152,6 +156,7 @@ def init_subclass_params(sub_data, sub_class):
152156 smooth_edge_update = self .repflow_args .smooth_edge_update ,
153157 use_dynamic_sel = self .repflow_args .use_dynamic_sel ,
154158 sel_reduce_factor = self .repflow_args .sel_reduce_factor ,
159+ use_loc_mapping = use_loc_mapping ,
155160 exclude_types = exclude_types ,
156161 env_protection = env_protection ,
157162 precision = precision ,
@@ -160,6 +165,7 @@ def init_subclass_params(sub_data, sub_class):
160165
161166 self .use_econf_tebd = use_econf_tebd
162167 self .use_tebd_bias = use_tebd_bias
168+ self .use_loc_mapping = use_loc_mapping
163169 self .type_map = type_map
164170 self .tebd_dim = self .repflow_args .n_dim
165171 self .type_embedding = TypeEmbedNet (
@@ -370,7 +376,7 @@ def serialize(self) -> dict:
370376 data = {
371377 "@class" : "Descriptor" ,
372378 "type" : "dpa3" ,
373- "@version" : 1 ,
379+ "@version" : 2 ,
374380 "ntypes" : self .ntypes ,
375381 "repflow_args" : self .repflow_args .serialize (),
376382 "concat_output_tebd" : self .concat_output_tebd ,
@@ -381,6 +387,7 @@ def serialize(self) -> dict:
381387 "trainable" : self .trainable ,
382388 "use_econf_tebd" : self .use_econf_tebd ,
383389 "use_tebd_bias" : self .use_tebd_bias ,
390+ "use_loc_mapping" : self .use_loc_mapping ,
384391 "type_map" : self .type_map ,
385392 "type_embedding" : self .type_embedding .embedding .serialize (),
386393 }
@@ -405,7 +412,7 @@ def serialize(self) -> dict:
405412 def deserialize (cls , data : dict ) -> "DescrptDPA3" :
406413 data = data .copy ()
407414 version = data .pop ("@version" )
408- check_version_compatibility (version , 1 , 1 )
415+ check_version_compatibility (version , 2 , 1 )
409416 data .pop ("@class" )
410417 data .pop ("type" )
411418 repflow_variable = data .pop ("repflow_variable" ).copy ()
0 commit comments