Skip to content

Commit a100841

Browse files
Fix UT
1 parent f88ab82 commit a100841

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

deepmd/pd/utils/stat.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def _post_process_stat(
144144
"""Post process the statistics.
145145
146146
For global statistics, we do not have the std for each type of atoms,
147-
thus fake the output std by ones for all the types.
147+
thus broadcast the global std to all the types.
148148
If the shape of out_std is already the same as out_bias,
149149
we do not need to do anything.
150150
@@ -154,7 +154,9 @@ def _post_process_stat(
154154
if vv.shape == out_std[kk].shape:
155155
new_std[kk] = out_std[kk]
156156
else:
157-
new_std[kk] = np.ones_like(vv)
157+
ntypes = vv.shape[0]
158+
reps = [ntypes] + [1] * (vv.ndim - 1)
159+
new_std[kk] = np.tile(out_std[kk], reps)
158160
return out_bias, new_std
159161

160162

0 commit comments

Comments
 (0)