@@ -319,6 +319,7 @@ def __init__(
319319 trainable_ln = trainable_ln ,
320320 ln_eps = ln_eps ,
321321 seed = child_seed (seed , 0 ),
322+ trainable = trainable ,
322323 )
323324 self .use_econf_tebd = use_econf_tebd
324325 self .use_tebd_bias = use_tebd_bias
@@ -333,6 +334,7 @@ def __init__(
333334 use_tebd_bias = use_tebd_bias ,
334335 type_map = type_map ,
335336 seed = child_seed (seed , 1 ),
337+ trainable = trainable ,
336338 )
337339 self .tebd_dim = tebd_dim
338340 self .concat_output_tebd = concat_output_tebd
@@ -520,7 +522,7 @@ def call(
520522 type_embedding = self .type_embedding .call ()
521523 # nf x nall x tebd_dim
522524 atype_embd_ext = xp .reshape (
523- xp .take (type_embedding , xp .reshape (atype_ext , [ - 1 ] ), axis = 0 ),
525+ xp .take (type_embedding , xp .reshape (atype_ext , ( - 1 ,) ), axis = 0 ),
524526 (nf , nall , self .tebd_dim ),
525527 )
526528 # nfnl x tebd_dim
@@ -691,6 +693,7 @@ def __init__(
691693 ln_eps : Optional [float ] = 1e-5 ,
692694 smooth : bool = True ,
693695 seed : Optional [Union [int , list [int ]]] = None ,
696+ trainable : bool = True ,
694697 ) -> None :
695698 self .rcut = rcut
696699 self .rcut_smth = rcut_smth
@@ -741,6 +744,7 @@ def __init__(
741744 self .resnet_dt ,
742745 self .precision ,
743746 seed = child_seed (seed , 0 ),
747+ trainable = trainable ,
744748 )
745749 self .embeddings = embeddings
746750 if self .tebd_input_mode in ["strip" ]:
@@ -756,6 +760,7 @@ def __init__(
756760 self .resnet_dt ,
757761 self .precision ,
758762 seed = child_seed (seed , 1 ),
763+ trainable = trainable ,
759764 )
760765 self .embeddings_strip = embeddings_strip
761766 else :
@@ -774,6 +779,7 @@ def __init__(
774779 smooth = self .smooth ,
775780 precision = self .precision ,
776781 seed = child_seed (seed , 2 ),
782+ trainable = trainable ,
777783 )
778784
779785 wanted_shape = (self .ntypes , self .nnei , 4 )
@@ -1027,7 +1033,7 @@ def call(
10271033 xp .tile (
10281034 (xp .reshape (atype , (- 1 , 1 )) * ntypes_with_padding ), (1 , nnei )
10291035 ),
1030- (- 1 ),
1036+ (- 1 , ),
10311037 )
10321038 idx_j = xp .reshape (nei_type , (- 1 ,))
10331039 # (nf x nl x nnei) x ng
@@ -1186,6 +1192,7 @@ def __init__(
11861192 smooth : bool = True ,
11871193 precision : str = DEFAULT_PRECISION ,
11881194 seed : Optional [Union [int , list [int ]]] = None ,
1195+ trainable : bool = True ,
11891196 ) -> None :
11901197 """Construct a neighbor-wise attention net."""
11911198 super ().__init__ ()
@@ -1219,6 +1226,7 @@ def __init__(
12191226 smooth = smooth ,
12201227 precision = precision ,
12211228 seed = child_seed (seed , ii ),
1229+ trainable = trainable ,
12221230 )
12231231 for ii in range (layer_num )
12241232 ]
@@ -1314,6 +1322,7 @@ def __init__(
13141322 smooth : bool = True ,
13151323 precision : str = DEFAULT_PRECISION ,
13161324 seed : Optional [Union [int , list [int ]]] = None ,
1325+ trainable : bool = True ,
13171326 ) -> None :
13181327 """Construct a neighbor-wise attention layer."""
13191328 super ().__init__ ()
@@ -1340,6 +1349,7 @@ def __init__(
13401349 smooth = smooth ,
13411350 precision = precision ,
13421351 seed = child_seed (seed , 0 ),
1352+ trainable = trainable ,
13431353 )
13441354 self .attn_layer_norm = LayerNorm (
13451355 self .embed_dim ,
@@ -1420,6 +1430,7 @@ def __init__(
14201430 smooth : bool = True ,
14211431 precision : str = DEFAULT_PRECISION ,
14221432 seed : Optional [Union [int , list [int ]]] = None ,
1433+ trainable : bool = True ,
14231434 ) -> None :
14241435 """Construct a multi-head neighbor-wise attention net."""
14251436 super ().__init__ ()
@@ -1449,6 +1460,7 @@ def __init__(
14491460 use_timestep = False ,
14501461 precision = precision ,
14511462 seed = child_seed (seed , 0 ),
1463+ trainable = trainable ,
14521464 )
14531465 self .out_proj = NativeLayer (
14541466 hidden_dim ,
@@ -1457,6 +1469,7 @@ def __init__(
14571469 use_timestep = False ,
14581470 precision = precision ,
14591471 seed = child_seed (seed , 1 ),
1472+ trainable = trainable ,
14601473 )
14611474
14621475 def call (self , query , nei_mask , input_r = None , sw = None , attnw_shift = 20.0 ):
0 commit comments