Skip to content

Commit 739bd1a

Browse files
committed
fix: missing log in multitask training
1 parent adf5b83 commit 739bd1a

1 file changed

Lines changed: 14 additions & 14 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,25 +1285,25 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
12851285
train_results[_key] = log_loss_train(
12861286
loss, more_loss, _task_key=_key
12871287
)
1288-
valid_results[_key] = log_loss_valid(_task_key=_key)
1289-
if self.rank == 0:
1290-
log.info(
1291-
format_training_message_per_task(
1292-
batch=display_step_id,
1293-
task_name=_key + "_trn",
1294-
rmse=train_results[_key],
1295-
learning_rate=cur_lr,
1296-
)
1297-
)
1298-
if valid_results[_key]:
1288+
valid_results[_key] = log_loss_valid(_task_key=_key)
1289+
if self.rank == 0:
12991290
log.info(
13001291
format_training_message_per_task(
13011292
batch=display_step_id,
1302-
task_name=_key + "_val",
1303-
rmse=valid_results[_key],
1304-
learning_rate=None,
1293+
task_name=_key + "_trn",
1294+
rmse=train_results[_key],
1295+
learning_rate=cur_lr,
13051296
)
13061297
)
1298+
if valid_results[_key]:
1299+
log.info(
1300+
format_training_message_per_task(
1301+
batch=display_step_id,
1302+
task_name=_key + "_val",
1303+
rmse=valid_results[_key],
1304+
learning_rate=None,
1305+
)
1306+
)
13071307
self.wrapper.train()
13081308

13091309
if self.disp_avg:

0 commit comments

Comments
 (0)