Skip to content

Commit da895d0

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent f2d37ed commit da895d0

1 file changed

Lines changed: 2 additions & 6 deletions

File tree

deepmd/pt/loss/xas.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -284,14 +284,10 @@ def compute_output_stats(
284284
if populated.any():
285285
e_std_global = e_std[populated].mean(dim=0) # [2]
286286
else:
287-
e_std_global = torch.ones(
288-
2, dtype=e_std.dtype, device=e_std.device
289-
)
287+
e_std_global = torch.ones(2, dtype=e_std.dtype, device=e_std.device)
290288
with torch.no_grad():
291289
am.out_bias[key_idx, :, :2] = 0.0
292-
am.out_std[key_idx, :, :2] = e_std_global.to(
293-
am.out_std.dtype
294-
)
290+
am.out_std[key_idx, :, :2] = e_std_global.to(am.out_std.dtype)
295291
log.info(
296292
f"XASLoss: set out_bias[:,:2]=0, out_std[:,:2]={e_std_global.tolist()} eV "
297293
"(NN output ±1 ≈ ±e_std eV chemical shift)."

0 commit comments

Comments
 (0)