Skip to content

Commit 07ce025

Browse files
committed
fix: detach graph in log
1 parent bca1cd7 commit 07ce025

1 file changed

Lines changed: 18 additions & 5 deletions

File tree

deepmd/pt_expt/train/training.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,9 +1312,14 @@ def run(self) -> None:
13121312
self.wrapper.eval()
13131313

13141314
if self.rank == 0:
1315+
def _to_float(v: Any) -> float:
1316+
return v.detach().item() if torch.is_tensor(v) else float(v)
1317+
13151318
if not self.multi_task:
13161319
train_results = {
1317-
k: v for k, v in more_loss.items() if "l2_" not in k
1320+
k: _to_float(v)
1321+
for k, v in more_loss.items()
1322+
if "l2_" not in k
13181323
}
13191324

13201325
# validation
@@ -1335,7 +1340,8 @@ def run(self) -> None:
13351340
for k, v in _vmore.items():
13361341
if "l2_" not in k:
13371342
valid_results[k] = (
1338-
valid_results.get(k, 0.0) + v * natoms
1343+
valid_results.get(k, 0.0)
1344+
+ _to_float(v) * natoms
13391345
)
13401346
if sum_natoms > 0:
13411347
valid_results = {
@@ -1348,7 +1354,9 @@ def run(self) -> None:
13481354

13491355
# current task already has loss
13501356
train_results[task_key] = {
1351-
k: v for k, v in more_loss.items() if "l2_" not in k
1357+
k: _to_float(v)
1358+
for k, v in more_loss.items()
1359+
if "l2_" not in k
13521360
}
13531361

13541362
# compute loss for other tasks
@@ -1363,7 +1371,9 @@ def run(self) -> None:
13631371
task_key=_key,
13641372
)
13651373
train_results[_key] = {
1366-
k: v for k, v in _more.items() if "l2_" not in k
1374+
k: _to_float(v)
1375+
for k, v in _more.items()
1376+
if "l2_" not in k
13671377
}
13681378

13691379
# validation for each task
@@ -1387,7 +1397,10 @@ def run(self) -> None:
13871397
_sum_natoms += natoms
13881398
for k, v in _vmore.items():
13891399
if "l2_" not in k:
1390-
_vres[k] = _vres.get(k, 0.0) + v * natoms
1400+
_vres[k] = (
1401+
_vres.get(k, 0.0)
1402+
+ _to_float(v) * natoms
1403+
)
13911404
if _sum_natoms > 0:
13921405
_vres = {
13931406
k: v / _sum_natoms for k, v in _vres.items()

0 commit comments

Comments
 (0)