@@ -262,6 +262,7 @@ def __init__(
262262 use_econf_tebd : bool = False ,
263263 use_tebd_bias : bool = False ,
264264 type_map : Optional [list [str ]] = None ,
265+ use_ext_ebd : bool = False ,
265266 ) -> None :
266267 super ().__init__ ()
267268
@@ -275,6 +276,7 @@ def init_subclass_params(sub_data, sub_class):
275276 f"Input args must be a { sub_class .__name__ } class or a dict!"
276277 )
277278
279+ self .use_ext_ebd = use_ext_ebd
278280 self .repflow_args = init_subclass_params (repflow , RepFlowArgs )
279281 self .activation_function = activation_function
280282
@@ -307,6 +309,7 @@ def init_subclass_params(sub_data, sub_class):
307309 env_protection = env_protection ,
308310 precision = precision ,
309311 seed = child_seed (seed , 1 ),
312+ use_ext_ebd = use_ext_ebd ,
310313 )
311314
312315 self .use_econf_tebd = use_econf_tebd
@@ -544,6 +547,7 @@ def serialize(self) -> dict:
544547 "use_tebd_bias" : self .use_tebd_bias ,
545548 "type_map" : self .type_map ,
546549 "type_embedding" : self .type_embedding .serialize (),
550+ "use_ext_ebd" : self .use_ext_ebd ,
547551 }
548552 repflow_variable = {
549553 "edge_embd" : repflows .edge_embd .serialize (),
0 commit comments