@@ -1312,9 +1312,14 @@ def run(self) -> None:
13121312 self .wrapper .eval ()
13131313
13141314 if self .rank == 0 :
1315+ def _to_float (v : Any ) -> float :
1316+ return v .detach ().item () if torch .is_tensor (v ) else float (v )
1317+
13151318 if not self .multi_task :
13161319 train_results = {
1317- k : v for k , v in more_loss .items () if "l2_" not in k
1320+ k : _to_float (v )
1321+ for k , v in more_loss .items ()
1322+ if "l2_" not in k
13181323 }
13191324
13201325 # validation
@@ -1335,7 +1340,8 @@ def run(self) -> None:
13351340 for k , v in _vmore .items ():
13361341 if "l2_" not in k :
13371342 valid_results [k ] = (
1338- valid_results .get (k , 0.0 ) + v * natoms
1343+ valid_results .get (k , 0.0 )
1344+ + _to_float (v ) * natoms
13391345 )
13401346 if sum_natoms > 0 :
13411347 valid_results = {
@@ -1348,7 +1354,9 @@ def run(self) -> None:
13481354
13491355 # current task already has loss
13501356 train_results [task_key ] = {
1351- k : v for k , v in more_loss .items () if "l2_" not in k
1357+ k : _to_float (v )
1358+ for k , v in more_loss .items ()
1359+ if "l2_" not in k
13521360 }
13531361
13541362 # compute loss for other tasks
@@ -1363,7 +1371,9 @@ def run(self) -> None:
13631371 task_key = _key ,
13641372 )
13651373 train_results [_key ] = {
1366- k : v for k , v in _more .items () if "l2_" not in k
1374+ k : _to_float (v )
1375+ for k , v in _more .items ()
1376+ if "l2_" not in k
13671377 }
13681378
13691379 # validation for each task
@@ -1387,7 +1397,10 @@ def run(self) -> None:
13871397 _sum_natoms += natoms
13881398 for k , v in _vmore .items ():
13891399 if "l2_" not in k :
1390- _vres [k ] = _vres .get (k , 0.0 ) + v * natoms
1400+ _vres [k ] = (
1401+ _vres .get (k , 0.0 )
1402+ + _to_float (v ) * natoms
1403+ )
13911404 if _sum_natoms > 0 :
13921405 _vres = {
13931406 k : v / _sum_natoms for k , v in _vres .items ()
0 commit comments