diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 193dcd8cb9..de387641c6 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -140,6 +140,7 @@ def __init__( self.num_steps = training_params["numb_steps"] self.disp_file = training_params.get("disp_file", "lcurve.out") self.disp_freq = training_params.get("disp_freq", 1000) + self.disp_avg = training_params.get("disp_avg", False) self.save_ckpt = training_params.get("save_ckpt", "model.ckpt") self.save_freq = training_params.get("save_freq", 1000) self.max_ckpt_keep = training_params.get("max_ckpt_keep", 5) @@ -808,6 +809,33 @@ def fake_model(): else: raise ValueError(f"Not supported optimizer type '{self.opt_type}'") + if self.disp_avg: + # Accumulate loss for averaging over display interval + self.step_count_in_interval += 1 + if not self.multi_task: + # Accumulate loss for single task + if not self.train_loss_accu: + # Initialize accumulator with current loss structure + for item in more_loss: + if "l2_" not in item: + self.train_loss_accu[item] = 0.0 + for item in more_loss: + if "l2_" not in item: + self.train_loss_accu[item] += more_loss[item] + else: + # Accumulate loss for multi-task + if task_key not in self.train_loss_accu: + self.train_loss_accu[task_key] = {} + if task_key not in self.step_count_per_task: + self.step_count_per_task[task_key] = 0 + self.step_count_per_task[task_key] += 1 + + for item in more_loss: + if "l2_" not in item: + if item not in self.train_loss_accu[task_key]: + self.train_loss_accu[task_key][item] = 0.0 + self.train_loss_accu[task_key][item] += more_loss[item] + # Log and persist display_step_id = _step_id + 1 if self.display_in_training and ( @@ -815,16 +843,41 @@ def fake_model(): ): self.wrapper.eval() # Will set to train mode before fininshing validation - def log_loss_train(_loss, _more_loss, _task_key="Default"): - results = {} - rmse_val = { - item: _more_loss[item] - for item in _more_loss - if "l2_" not in item - } - 
for item in sorted(rmse_val.keys()): - results[item] = rmse_val[item] - return results + if self.disp_avg: + + def log_loss_train(_loss, _more_loss, _task_key="Default"): + results = {} + if not self.multi_task: + # Use accumulated average loss for single task + for item in self.train_loss_accu: + results[item] = ( + self.train_loss_accu[item] + / self.step_count_in_interval + ) + else: + # Use accumulated average loss for multi-task + if ( + _task_key in self.train_loss_accu + and _task_key in self.step_count_per_task + ): + for item in self.train_loss_accu[_task_key]: + results[item] = ( + self.train_loss_accu[_task_key][item] + / self.step_count_per_task[_task_key] + ) + return results + else: + + def log_loss_train(_loss, _more_loss, _task_key="Default"): + results = {} + rmse_val = { + item: _more_loss[item] + for item in _more_loss + if "l2_" not in item + } + for item in sorted(rmse_val.keys()): + results[item] = rmse_val[item] + return results def log_loss_valid(_task_key="Default"): single_results = {} @@ -882,24 +935,31 @@ def log_loss_valid(_task_key="Default"): else: train_results = {_key: {} for _key in self.model_keys} valid_results = {_key: {} for _key in self.model_keys} - train_results[task_key] = log_loss_train( - loss, more_loss, _task_key=task_key - ) - for _key in self.model_keys: - if _key != task_key: - self.optimizer.zero_grad() - input_dict, label_dict, _ = self.get_data( - is_train=True, task_key=_key - ) - _, loss, more_loss = self.wrapper( - **input_dict, - cur_lr=pref_lr, - label=label_dict, - task_key=_key, - ) + if self.disp_avg: + # For multi-task, use accumulated average loss for all tasks + for _key in self.model_keys: train_results[_key] = log_loss_train( loss, more_loss, _task_key=_key ) + else: + train_results[task_key] = log_loss_train( + loss, more_loss, _task_key=task_key + ) + for _key in self.model_keys: + if _key != task_key: + self.optimizer.zero_grad() + input_dict, label_dict, _ = self.get_data( + is_train=True, 
task_key=_key + ) + _, loss, more_loss = self.wrapper( + **input_dict, + cur_lr=pref_lr, + label=label_dict, + task_key=_key, + ) + train_results[_key] = log_loss_train( + loss, more_loss, _task_key=_key + ) valid_results[_key] = log_loss_valid(_task_key=_key) if self.rank == 0: log.info( @@ -921,6 +981,21 @@ def log_loss_valid(_task_key="Default"): ) self.wrapper.train() + if self.disp_avg: + # Reset loss accumulators after display + if not self.multi_task: + for item in self.train_loss_accu: + self.train_loss_accu[item] = 0.0 + else: + for _tkey in self.model_keys: + if _tkey in self.train_loss_accu: + for item in self.train_loss_accu[_tkey]: + self.train_loss_accu[_tkey][item] = 0.0 + if _tkey in self.step_count_per_task: + self.step_count_per_task[_tkey] = 0 + self.step_count_in_interval = 0 + self.last_display_step = display_step_id + current_time = time.time() train_time = current_time - self.t0 self.t0 = current_time @@ -993,6 +1068,17 @@ def log_loss_valid(_task_key="Default"): self.t0 = time.time() self.total_train_time = 0.0 self.timed_steps = 0 + + if self.disp_avg: + # Initialize loss accumulators + if not self.multi_task: + self.train_loss_accu = {} + else: + self.train_loss_accu = {key: {} for key in self.model_keys} + self.step_count_per_task = dict.fromkeys(self.model_keys, 0) + self.step_count_in_interval = 0 + self.last_display_step = 0 + for step_id in range(self.start_step, self.num_steps): step(step_id) if JIT: diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index fb911550dd..e4c15ebd21 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3137,6 +3137,9 @@ def training_args( ) doc_disp_training = "Displaying verbose information during training." doc_time_training = "Timing during training." + doc_disp_avg = ( + "Display the average loss over the display interval for training sets." 
+ ) doc_profiling = "Export the profiling results to the Chrome JSON file for performance analysis, driven by the legacy TensorFlow profiling API or PyTorch Profiler. The output file will be saved to `profiling_file`." doc_profiling_file = "Output file for profiling." doc_enable_profiler = "Export the profiling results to the TensorBoard log for performance analysis, driven by TensorFlow Profiler (available in TensorFlow 2.3) or PyTorch Profiler. The log will be saved to `tensorboard_log_dir`." @@ -3213,6 +3216,13 @@ def training_args( Argument( "time_training", bool, optional=True, default=True, doc=doc_time_training ), + Argument( + "disp_avg", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_disp_avg, + ), Argument( "profiling", bool,