Skip to content

Commit 813bbc8

Browse files
OutisLi, pre-commit-ci[bot], and njzjz
authored
refactor(training): Average training loss for smoother and more representative logging (#4850)
This pull request modifies the training loop to improve the quality and readability of the reported training loss.

## Summary of Changes

Previously, the training loss and associated metrics (e.g., `rmse_e_trn`) reported in `lcurve.out` and the console log at each `disp_freq` step represented the instantaneous value from that single training batch. This could be quite noisy and subject to high variance depending on the specific batch sampled.

This PR introduces an accumulator for the training loss. The key changes are:

- During each training step, the loss values are accumulated.
- When a display step is reached, the accumulated values are averaged over the number of steps in that interval.
- This averaged loss is then reported in the log and `lcurve.out`.
- The accumulators are reset for the next interval.

The validation logic remains unchanged, continuing to provide a periodic snapshot of model performance, which is the standard and efficient approach.

## Significance and Benefits

Reporting the averaged training loss provides a much smoother and more representative training curve. The benefits include:

- Reduced Noise: Eliminates high-frequency fluctuations, making it easier to see the true learning trend.
- Improved Readability: Plotted learning curves from `lcurve.out` are cleaner and more interpretable.
- Better Comparability: Simplifies the comparison of model performance across different training runs, as the impact of single-batch anomalies is minimized.

## A Note on Formatting

Please note that due to automatic code formatters (e.g., black, isort), some minor, purely stylistic changes may appear in the diff that are not directly related to the core logic.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Training loss values displayed during training are now averaged over the display interval, providing more stable and representative loss metrics for both single-task and multi-task modes.
  * Added an option to enable or disable averaging of training loss display via a new configuration setting.
* **Improvements**
  * Enhanced training loss reporting for improved monitoring and analysis during model training.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: LI TIANCHENG <137472077+OutisLi@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <njzjz@qq.com>
1 parent 1dc1248 commit 813bbc8

File tree

2 files changed

+121
-25
lines changed

2 files changed

+121
-25
lines changed

deepmd/pt/train/training.py

Lines changed: 111 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def __init__(
140140
self.num_steps = training_params["numb_steps"]
141141
self.disp_file = training_params.get("disp_file", "lcurve.out")
142142
self.disp_freq = training_params.get("disp_freq", 1000)
143+
self.disp_avg = training_params.get("disp_avg", False)
143144
self.save_ckpt = training_params.get("save_ckpt", "model.ckpt")
144145
self.save_freq = training_params.get("save_freq", 1000)
145146
self.max_ckpt_keep = training_params.get("max_ckpt_keep", 5)
@@ -808,23 +809,75 @@ def fake_model():
808809
else:
809810
raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
810811

812+
if self.disp_avg:
813+
# Accumulate loss for averaging over display interval
814+
self.step_count_in_interval += 1
815+
if not self.multi_task:
816+
# Accumulate loss for single task
817+
if not self.train_loss_accu:
818+
# Initialize accumulator with current loss structure
819+
for item in more_loss:
820+
if "l2_" not in item:
821+
self.train_loss_accu[item] = 0.0
822+
for item in more_loss:
823+
if "l2_" not in item:
824+
self.train_loss_accu[item] += more_loss[item]
825+
else:
826+
# Accumulate loss for multi-task
827+
if task_key not in self.train_loss_accu:
828+
self.train_loss_accu[task_key] = {}
829+
if task_key not in self.step_count_per_task:
830+
self.step_count_per_task[task_key] = 0
831+
self.step_count_per_task[task_key] += 1
832+
833+
for item in more_loss:
834+
if "l2_" not in item:
835+
if item not in self.train_loss_accu[task_key]:
836+
self.train_loss_accu[task_key][item] = 0.0
837+
self.train_loss_accu[task_key][item] += more_loss[item]
838+
811839
# Log and persist
812840
display_step_id = _step_id + 1
813841
if self.display_in_training and (
814842
display_step_id % self.disp_freq == 0 or display_step_id == 1
815843
):
816844
self.wrapper.eval() # Will set to train mode before fininshing validation
817845

818-
def log_loss_train(_loss, _more_loss, _task_key="Default"):
819-
results = {}
820-
rmse_val = {
821-
item: _more_loss[item]
822-
for item in _more_loss
823-
if "l2_" not in item
824-
}
825-
for item in sorted(rmse_val.keys()):
826-
results[item] = rmse_val[item]
827-
return results
846+
if self.disp_avg:
847+
848+
def log_loss_train(_loss, _more_loss, _task_key="Default"):
849+
results = {}
850+
if not self.multi_task:
851+
# Use accumulated average loss for single task
852+
for item in self.train_loss_accu:
853+
results[item] = (
854+
self.train_loss_accu[item]
855+
/ self.step_count_in_interval
856+
)
857+
else:
858+
# Use accumulated average loss for multi-task
859+
if (
860+
_task_key in self.train_loss_accu
861+
and _task_key in self.step_count_per_task
862+
):
863+
for item in self.train_loss_accu[_task_key]:
864+
results[item] = (
865+
self.train_loss_accu[_task_key][item]
866+
/ self.step_count_per_task[_task_key]
867+
)
868+
return results
869+
else:
870+
871+
def log_loss_train(_loss, _more_loss, _task_key="Default"):
872+
results = {}
873+
rmse_val = {
874+
item: _more_loss[item]
875+
for item in _more_loss
876+
if "l2_" not in item
877+
}
878+
for item in sorted(rmse_val.keys()):
879+
results[item] = rmse_val[item]
880+
return results
828881

829882
def log_loss_valid(_task_key="Default"):
830883
single_results = {}
@@ -882,24 +935,31 @@ def log_loss_valid(_task_key="Default"):
882935
else:
883936
train_results = {_key: {} for _key in self.model_keys}
884937
valid_results = {_key: {} for _key in self.model_keys}
885-
train_results[task_key] = log_loss_train(
886-
loss, more_loss, _task_key=task_key
887-
)
888-
for _key in self.model_keys:
889-
if _key != task_key:
890-
self.optimizer.zero_grad()
891-
input_dict, label_dict, _ = self.get_data(
892-
is_train=True, task_key=_key
893-
)
894-
_, loss, more_loss = self.wrapper(
895-
**input_dict,
896-
cur_lr=pref_lr,
897-
label=label_dict,
898-
task_key=_key,
899-
)
938+
if self.disp_avg:
939+
# For multi-task, use accumulated average loss for all tasks
940+
for _key in self.model_keys:
900941
train_results[_key] = log_loss_train(
901942
loss, more_loss, _task_key=_key
902943
)
944+
else:
945+
train_results[task_key] = log_loss_train(
946+
loss, more_loss, _task_key=task_key
947+
)
948+
for _key in self.model_keys:
949+
if _key != task_key:
950+
self.optimizer.zero_grad()
951+
input_dict, label_dict, _ = self.get_data(
952+
is_train=True, task_key=_key
953+
)
954+
_, loss, more_loss = self.wrapper(
955+
**input_dict,
956+
cur_lr=pref_lr,
957+
label=label_dict,
958+
task_key=_key,
959+
)
960+
train_results[_key] = log_loss_train(
961+
loss, more_loss, _task_key=_key
962+
)
903963
valid_results[_key] = log_loss_valid(_task_key=_key)
904964
if self.rank == 0:
905965
log.info(
@@ -921,6 +981,21 @@ def log_loss_valid(_task_key="Default"):
921981
)
922982
self.wrapper.train()
923983

984+
if self.disp_avg:
985+
# Reset loss accumulators after display
986+
if not self.multi_task:
987+
for item in self.train_loss_accu:
988+
self.train_loss_accu[item] = 0.0
989+
else:
990+
for task_key in self.model_keys:
991+
if task_key in self.train_loss_accu:
992+
for item in self.train_loss_accu[task_key]:
993+
self.train_loss_accu[task_key][item] = 0.0
994+
if task_key in self.step_count_per_task:
995+
self.step_count_per_task[task_key] = 0
996+
self.step_count_in_interval = 0
997+
self.last_display_step = display_step_id
998+
924999
current_time = time.time()
9251000
train_time = current_time - self.t0
9261001
self.t0 = current_time
@@ -993,6 +1068,17 @@ def log_loss_valid(_task_key="Default"):
9931068
self.t0 = time.time()
9941069
self.total_train_time = 0.0
9951070
self.timed_steps = 0
1071+
1072+
if self.disp_avg:
1073+
# Initialize loss accumulators
1074+
if not self.multi_task:
1075+
self.train_loss_accu = {}
1076+
else:
1077+
self.train_loss_accu = {key: {} for key in self.model_keys}
1078+
self.step_count_per_task = dict.fromkeys(self.model_keys, 0)
1079+
self.step_count_in_interval = 0
1080+
self.last_display_step = 0
1081+
9961082
for step_id in range(self.start_step, self.num_steps):
9971083
step(step_id)
9981084
if JIT:

deepmd/utils/argcheck.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3137,6 +3137,9 @@ def training_args(
31373137
)
31383138
doc_disp_training = "Displaying verbose information during training."
31393139
doc_time_training = "Timing during training."
3140+
doc_disp_avg = (
3141+
"Display the average loss over the display interval for training sets."
3142+
)
31403143
doc_profiling = "Export the profiling results to the Chrome JSON file for performance analysis, driven by the legacy TensorFlow profiling API or PyTorch Profiler. The output file will be saved to `profiling_file`."
31413144
doc_profiling_file = "Output file for profiling."
31423145
doc_enable_profiler = "Export the profiling results to the TensorBoard log for performance analysis, driven by TensorFlow Profiler (available in TensorFlow 2.3) or PyTorch Profiler. The log will be saved to `tensorboard_log_dir`."
@@ -3213,6 +3216,13 @@ def training_args(
32133216
Argument(
32143217
"time_training", bool, optional=True, default=True, doc=doc_time_training
32153218
),
3219+
Argument(
3220+
"disp_avg",
3221+
bool,
3222+
optional=True,
3223+
default=False,
3224+
doc=doc_only_pt_supported + doc_disp_avg,
3225+
),
32163226
Argument(
32173227
"profiling",
32183228
bool,

0 commit comments

Comments (0)