@@ -150,13 +150,48 @@ def init_subclass_params(sub_data, sub_class):
150150 fix_stat_std = self .repflow_args .fix_stat_std ,
151151 optim_update = self .repflow_args .optim_update ,
152152 smooth_edge_update = self .repflow_args .smooth_edge_update ,
153+ angle_multi_freq = self .repflow_args .angle_multi_freq ,
154+ use_dynamic_sel = self .repflow_args .use_dynamic_sel ,
155+ sel_reduce_factor = self .repflow_args .sel_reduce_factor ,
156+ use_env_envelope = self .repflow_args .use_env_envelope ,
157+ use_new_sw = self .repflow_args .use_new_sw ,
158+ update_dihedral = self .repflow_args .update_dihedral ,
159+ d_dim = self .repflow_args .d_dim ,
160+ d_sel = self .repflow_args .d_sel ,
161+ d_rcut = self .repflow_args .d_rcut ,
162+ d_rcut_smth = self .repflow_args .d_rcut_smth ,
163+ use_ffn_node_edge_message = self .repflow_args .use_ffn_node_edge_message ,
164+ use_ffn_edge_edge_message = self .repflow_args .use_ffn_edge_edge_message ,
165+ use_ffn_edge_angle_message = self .repflow_args .use_ffn_edge_angle_message ,
166+ use_ffn_angle_angle_message = self .repflow_args .use_ffn_angle_angle_message ,
167+ ffn_hidden_dim = self .repflow_args .ffn_hidden_dim ,
168+ edge_use_concat_rbf = self .repflow_args .edge_use_concat_rbf ,
169+ edge_use_rbf = self .repflow_args .edge_use_rbf ,
170+ edge_use_dist = self .repflow_args .edge_use_dist ,
171+ embed_use_bias = self .repflow_args .embed_use_bias ,
172+ edge_use_attn = self .repflow_args .edge_use_attn ,
173+ edge_attn_hidden = self .repflow_args .edge_attn_hidden ,
174+ edge_attn_head = self .repflow_args .edge_attn_head ,
175+ edge_attn_use_ln = self .repflow_args .edge_attn_use_ln ,
176+ edge_rbf_dot_self = self .repflow_args .edge_rbf_dot_self ,
177+ edge_rbf_dot_message = self .repflow_args .edge_rbf_dot_message ,
178+ edge_use_esen_rbf = self .repflow_args .edge_use_esen_rbf ,
179+ edge_use_esen_atom_ebd = self .repflow_args .edge_use_esen_atom_ebd ,
180+ edge_use_esen_env = self .repflow_args .edge_use_esen_env ,
181+ residual_pref = self .repflow_args .residual_pref ,
182+ tebd_use_act = self .repflow_args .tebd_use_act ,
183+ message_use_self_concat = self .repflow_args .message_use_self_concat ,
184+ use_slim_message = self .repflow_args .use_slim_message ,
185+ use_combined_output = self .repflow_args .use_combined_output ,
186+ use_loc_mapping = use_loc_mapping ,
153187 exclude_types = exclude_types ,
154188 env_protection = env_protection ,
155189 precision = precision ,
156190 seed = child_seed (seed , 1 ),
157191 )
158192
159193 self .use_econf_tebd = use_econf_tebd
194+ self .use_loc_mapping = use_loc_mapping
160195 self .use_tebd_bias = use_tebd_bias
161196 self .type_map = type_map
162197 self .tebd_dim = self .repflow_args .n_dim
@@ -466,12 +501,16 @@ def forward(
466501 The smooth switch function. shape: nf x nloc x nnei
467502
468503 """
504+ parrallel_mode = comm_dict is not None
469505 # cast the input to internal precsion
470506 extended_coord = extended_coord .to (dtype = self .prec )
471507 nframes , nloc , nnei = nlist .shape
472508 nall = extended_coord .view (nframes , - 1 ).shape [1 ] // 3
473509
474- node_ebd_ext = self .type_embedding (extended_atype )
510+ if not parrallel_mode and self .use_loc_mapping :
511+ node_ebd_ext = self .type_embedding (extended_atype [:, :nloc ])
512+ else :
513+ node_ebd_ext = self .type_embedding (extended_atype )
475514 node_ebd_inp = node_ebd_ext [:, :nloc , :]
476515 # repflows
477516 node_ebd , edge_ebd , h2 , rot_mat , sw = self .repflows (
0 commit comments