From 1269ed8c9a40df78e6d28a148145b6760a7d2a4a Mon Sep 17 00:00:00 2001
From: OutisLi
Date: Mon, 2 Mar 2026 10:54:55 +0800
Subject: [PATCH] refactor(training): remove unused learning_rate_dict
 multitask handling

Remove the unused learning_rate_dict configuration option that allowed
per-task learning rate settings in multitask training. This simplifies
the code by using a single learning_rate configuration for all tasks.

Changes:
- Remove learning_rate_dict branch in loss initialization
- Remove per-task lr_schedule dictionary creation
- Remove isinstance(dict) check in training loop
- Unify single-task and multi-task code paths

Both PyTorch and Paddle backends are updated consistently.
---
 deepmd/pd/train/training.py | 20 +++-----------------
 deepmd/pt/train/training.py | 20 +++-----------------
 2 files changed, 6 insertions(+), 34 deletions(-)

diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py
index b3702c8255..649e816e0b 100644
--- a/deepmd/pd/train/training.py
+++ b/deepmd/pd/train/training.py
@@ -302,10 +302,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
             self.loss = {}
             for model_key in self.model_keys:
                 loss_param = config["loss_dict"][model_key]
-                if config.get("learning_rate_dict", None) is not None:
-                    lr_param = config["learning_rate_dict"][model_key]["start_lr"]
-                else:
-                    lr_param = config["learning_rate"]["start_lr"]
+                lr_param = config["learning_rate"]["start_lr"]
                 ntypes = len(model_params["model_dict"][model_key]["type_map"])
                 self.loss[model_key] = get_loss(
                     loss_param, lr_param, ntypes, self.model[model_key]
@@ -476,14 +473,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
 
         # Learning rate
         self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
-        if self.multi_task and config.get("learning_rate_dict", None) is not None:
-            self.lr_schedule = {}
-            for model_key in self.model_keys:
-                self.lr_schedule[model_key] = get_lr(
-                    config["learning_rate_dict"][model_key]
-                )
-        else:
-            self.lr_schedule = get_lr(config["learning_rate"])
+        self.lr_schedule = get_lr(config["learning_rate"])
 
         # JIT
         if JIT:
@@ -806,11 +796,7 @@ def step(_step_id: int, task_key: str = "Default") -> None:
             # Paddle Profiler
             if enable_profiling:
                 core.nvprof_nvtx_push(f"Training step {_step_id}")
-            if isinstance(self.lr_schedule, dict):
-                _lr = self.lr_schedule[task_key]
-            else:
-                _lr = self.lr_schedule
-            cur_lr = _lr.value(_step_id)
+            cur_lr = self.lr_schedule.value(_step_id)
             pref_lr = cur_lr
 
             with nvprof_context(enable_profiling, "Fetching data"):
diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index f28edb1430..dcf4ea3e13 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -365,10 +365,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
             self.loss = {}
             for model_key in self.model_keys:
                 loss_param = config["loss_dict"][model_key]
-                if config.get("learning_rate_dict", None) is not None:
-                    lr_param = config["learning_rate_dict"][model_key]["start_lr"]
-                else:
-                    lr_param = config["learning_rate"]["start_lr"]
+                lr_param = config["learning_rate"]["start_lr"]
                 ntypes = len(model_params["model_dict"][model_key]["type_map"])
                 self.loss[model_key] = get_loss(
                     loss_param, lr_param, ntypes, self.model[model_key]
@@ -548,14 +545,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
 
         # Learning rate
         self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
-        if self.multi_task and config.get("learning_rate_dict", None) is not None:
-            self.lr_schedule = {}
-            for model_key in self.model_keys:
-                self.lr_schedule[model_key] = get_lr(
-                    config["learning_rate_dict"][model_key]
-                )
-        else:
-            self.lr_schedule = get_lr(config["learning_rate"])
+        self.lr_schedule = get_lr(config["learning_rate"])
 
         # JIT
         if JIT:
@@ -1027,11 +1017,7 @@ def step(_step_id: int, task_key: str = "Default") -> None:
             # PyTorch Profiler
             if self.enable_profiler or self.profiling:
                 prof.step()
-            if isinstance(self.lr_schedule, dict):
-                _lr = self.lr_schedule[task_key]
-            else:
-                _lr = self.lr_schedule
-            cur_lr = _lr.value(_step_id)
+            cur_lr = self.lr_schedule.value(_step_id)
            pref_lr = cur_lr
             self.optimizer.zero_grad(set_to_none=True)
             input_dict, label_dict, log_dict = self.get_data(
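Note for reviewers (illustrative, not part of the patch): after this change a
multitask input keeps per-task sections only for losses, while a single
top-level "learning_rate" is read for every task; "learning_rate_dict" is no
longer consulted anywhere. The sketch below shows the relevant slice of such a
config as a Python dict and how the unified lookup behaves. The model keys
"water" and "metal" and the numeric values are made up for the example; only
the key names "learning_rate", "loss_dict", and "start_lr" come from the code
touched by this patch.

    # Minimal sketch of the config shape assumed by the unified code path.
    config = {
        "learning_rate": {"type": "exp", "start_lr": 1e-3, "stop_lr": 1e-8},
        "loss_dict": {
            "water": {"type": "ener", "start_pref_e": 0.02, "limit_pref_e": 1.0},
            "metal": {"type": "ener", "start_pref_e": 0.02, "limit_pref_e": 1.0},
        },
    }

    # Every task now resolves its starting learning rate from the same
    # top-level block; there is no per-task learning_rate_dict lookup.
    for model_key in config["loss_dict"]:
        lr_param = config["learning_rate"]["start_lr"]
        print(model_key, lr_param)  # both tasks see 1e-3

Design note: because self.lr_schedule is now always a single schedule object
rather than sometimes a dict keyed by task, the per-step lookup in both
backends reduces to one call, self.lr_schedule.value(_step_id), regardless of
whether training is single-task or multitask.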