Commit 41757f2

refactor(training): remove unused learning_rate_dict multitask handling (#5278)
Remove the unused learning_rate_dict configuration option that allowed per-task learning rate settings in multitask training. This simplifies the code by using a single learning_rate configuration for all tasks.

Changes:
- Remove the learning_rate_dict branch in loss initialization
- Remove the per-task lr_schedule dictionary creation
- Remove the isinstance(dict) check in the training loop
- Unify the single-task and multi-task code paths

Both the PyTorch and Paddle backends are updated consistently.

Summary by CodeRabbit:
- Refactor: streamlined learning rate scheduling in multi-task training scenarios to ensure consistent initialization and computation of learning rates across all models.
1 parent a3db25a commit 41757f2
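For context, a minimal sketch of the configuration shape this refactor assumes (the task keys and learning-rate value below are illustrative, not taken from this commit): the per-task learning_rate_dict block is no longer consulted, and a single learning_rate block drives every task.

# Minimal sketch, not the actual DeePMD-kit input schema; task keys and
# values are illustrative assumptions.
config = {
    "learning_rate": {"start_lr": 1.0e-3},  # one shared schedule for all tasks
    "loss_dict": {
        "task_a": {},  # hypothetical task keys
        "task_b": {},
    },
    # "learning_rate_dict": {...}  # per-task schedules: removed by this commit
}

# After this change, every task's loss is built from the same start_lr,
# mirroring the unified branch in the diffs below.
for model_key in config["loss_dict"]:
    lr_param = config["learning_rate"]["start_lr"]
    print(model_key, lr_param)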

2 files changed

Lines changed: 6 additions & 34 deletions


deepmd/pd/train/training.py

Lines changed: 3 additions & 17 deletions
@@ -302,10 +302,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
             self.loss = {}
             for model_key in self.model_keys:
                 loss_param = config["loss_dict"][model_key]
-                if config.get("learning_rate_dict", None) is not None:
-                    lr_param = config["learning_rate_dict"][model_key]["start_lr"]
-                else:
-                    lr_param = config["learning_rate"]["start_lr"]
+                lr_param = config["learning_rate"]["start_lr"]
                 ntypes = len(model_params["model_dict"][model_key]["type_map"])
                 self.loss[model_key] = get_loss(
                     loss_param, lr_param, ntypes, self.model[model_key]
@@ -476,14 +473,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
 
         # Learning rate
         self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
-        if self.multi_task and config.get("learning_rate_dict", None) is not None:
-            self.lr_schedule = {}
-            for model_key in self.model_keys:
-                self.lr_schedule[model_key] = get_lr(
-                    config["learning_rate_dict"][model_key]
-                )
-        else:
-            self.lr_schedule = get_lr(config["learning_rate"])
+        self.lr_schedule = get_lr(config["learning_rate"])
 
         # JIT
         if JIT:
@@ -806,11 +796,7 @@ def step(_step_id: int, task_key: str = "Default") -> None:
             # Paddle Profiler
             if enable_profiling:
                 core.nvprof_nvtx_push(f"Training step {_step_id}")
-            if isinstance(self.lr_schedule, dict):
-                _lr = self.lr_schedule[task_key]
-            else:
-                _lr = self.lr_schedule
-            cur_lr = _lr.value(_step_id)
+            cur_lr = self.lr_schedule.value(_step_id)
             pref_lr = cur_lr
 
             with nvprof_context(enable_profiling, "Fetching data"):

deepmd/pt/train/training.py

Lines changed: 3 additions & 17 deletions
@@ -365,10 +365,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
             self.loss = {}
             for model_key in self.model_keys:
                 loss_param = config["loss_dict"][model_key]
-                if config.get("learning_rate_dict", None) is not None:
-                    lr_param = config["learning_rate_dict"][model_key]["start_lr"]
-                else:
-                    lr_param = config["learning_rate"]["start_lr"]
+                lr_param = config["learning_rate"]["start_lr"]
                 ntypes = len(model_params["model_dict"][model_key]["type_map"])
                 self.loss[model_key] = get_loss(
                     loss_param, lr_param, ntypes, self.model[model_key]
@@ -548,14 +545,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
 
         # Learning rate
         self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
-        if self.multi_task and config.get("learning_rate_dict", None) is not None:
-            self.lr_schedule = {}
-            for model_key in self.model_keys:
-                self.lr_schedule[model_key] = get_lr(
-                    config["learning_rate_dict"][model_key]
-                )
-        else:
-            self.lr_schedule = get_lr(config["learning_rate"])
+        self.lr_schedule = get_lr(config["learning_rate"])
 
         # JIT
         if JIT:
@@ -1027,11 +1017,7 @@ def step(_step_id: int, task_key: str = "Default") -> None:
             # PyTorch Profiler
             if self.enable_profiler or self.profiling:
                 prof.step()
-            if isinstance(self.lr_schedule, dict):
-                _lr = self.lr_schedule[task_key]
-            else:
-                _lr = self.lr_schedule
-            cur_lr = _lr.value(_step_id)
+            cur_lr = self.lr_schedule.value(_step_id)
             pref_lr = cur_lr
             self.optimizer.zero_grad(set_to_none=True)
             input_dict, label_dict, log_dict = self.get_data(
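With the per-task dictionary gone, both backends query a single schedule object on every training step. A minimal sketch of that call pattern follows; the ToyExpLR class, its decay formula, and all numeric values are illustrative assumptions, not DeePMD-kit's BaseLR implementation.

class ToyExpLR:
    # Toy schedule that only mimics the value(step) call used in the unified code path.
    def __init__(self, start_lr: float, decay_rate: float, decay_steps: int) -> None:
        self.start_lr = start_lr
        self.decay_rate = decay_rate
        self.decay_steps = decay_steps

    def value(self, step: int) -> float:
        # Exponential decay: multiply by decay_rate once per decay_steps steps.
        return self.start_lr * self.decay_rate ** (step // self.decay_steps)

lr_schedule = ToyExpLR(start_lr=1.0e-3, decay_rate=0.95, decay_steps=5000)
cur_lr = lr_schedule.value(10000)  # same learning rate for every task_key after this refactor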
