We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent f2d37ed commit da895d0Copy full SHA for da895d0
1 file changed
deepmd/pt/loss/xas.py
@@ -284,14 +284,10 @@ def compute_output_stats(
284
if populated.any():
285
e_std_global = e_std[populated].mean(dim=0) # [2]
286
else:
287
- e_std_global = torch.ones(
288
- 2, dtype=e_std.dtype, device=e_std.device
289
- )
+ e_std_global = torch.ones(2, dtype=e_std.dtype, device=e_std.device)
290
with torch.no_grad():
291
am.out_bias[key_idx, :, :2] = 0.0
292
- am.out_std[key_idx, :, :2] = e_std_global.to(
293
- am.out_std.dtype
294
+ am.out_std[key_idx, :, :2] = e_std_global.to(am.out_std.dtype)
295
log.info(
296
f"XASLoss: set out_bias[:,:2]=0, out_std[:,:2]={e_std_global.tolist()} eV "
297
"(NN output ±1 ≈ ±e_std eV chemical shift)."
0 commit comments