Skip to content

Commit 552acd2

Browse files
committed
fix dos
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
1 parent ba1204e commit 552acd2

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

deepmd/tf/model/model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -932,10 +932,11 @@ def _apply_out_bias_std(self, output, atype, natoms, coord, selected_atype=None)
932932
# Apply bias and std: output = output * std + bias
933933
adjusted_output = output_reshaped * std_per_atom + bias_per_atom
934934

935+
# expand axis 2 of valid_mask to nout
936+
valid_mask = tf.tile(tf.expand_dims(valid_mask, -1), [1, 1, nout])
937+
935938
# Only apply bias/std to valid atoms, keep original values for invalid atoms
936-
output_reshaped = tf.where(
937-
tf.expand_dims(valid_mask, -1), adjusted_output, output_reshaped
938-
)
939+
output_reshaped = tf.where(valid_mask, adjusted_output, output_reshaped)
939940

940941
return tf.reshape(output_reshaped, tf.shape(output))
941942

0 commit comments

Comments
 (0)