Skip to content

Commit 4cee0bf

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 07ce025 commit 4cee0bf

1 file changed

Lines changed: 18 additions & 2 deletions

File tree

deepmd/pt_expt/train/training.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,16 @@ def _expand(t: torch.Tensor | None) -> torch.Tensor | None:
378378
tracing_mode="symbolic",
379379
_allow_non_fake_inputs=True,
380380
decomposition_table=decomp_table,
381-
)(ext_coord, ext_atype, nlist, mapping, fparam, aparam, charge_spin, *task_buf_vals_trace)
381+
)(
382+
ext_coord,
383+
ext_atype,
384+
nlist,
385+
mapping,
386+
fparam,
387+
aparam,
388+
charge_spin,
389+
*task_buf_vals_trace,
390+
)
382391

383392
# make_fx inserts aten.detach.default for saved tensors used in the
384393
# decomposed autograd.grad backward ops. These detach nodes break
@@ -484,7 +493,13 @@ def forward(
484493
getattr(self, f"_task_{name}") for name in self._task_buf_order
485494
)
486495
result = self.compiled_forward_lower(
487-
ext_coord, ext_atype, nlist, mapping, fparam, aparam, charge_spin,
496+
ext_coord,
497+
ext_atype,
498+
nlist,
499+
mapping,
500+
fparam,
501+
aparam,
502+
charge_spin,
488503
*task_buf_vals,
489504
)
490505

@@ -1312,6 +1327,7 @@ def run(self) -> None:
13121327
self.wrapper.eval()
13131328

13141329
if self.rank == 0:
1330+
13151331
def _to_float(v: Any) -> float:
13161332
return v.detach().item() if torch.is_tensor(v) else float(v)
13171333

0 commit comments

Comments
 (0)