deepmd/pd/train/training.py (20 changes: 3 additions & 17 deletions)
@@ -302,10 +302,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
             self.loss = {}
             for model_key in self.model_keys:
                 loss_param = config["loss_dict"][model_key]
-                if config.get("learning_rate_dict", None) is not None:
-                    lr_param = config["learning_rate_dict"][model_key]["start_lr"]
-                else:
-                    lr_param = config["learning_rate"]["start_lr"]
+                lr_param = config["learning_rate"]["start_lr"]
                 ntypes = len(model_params["model_dict"][model_key]["type_map"])
                 self.loss[model_key] = get_loss(
                     loss_param, lr_param, ntypes, self.model[model_key]
@@ -476,14 +473,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:

         # Learning rate
         self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
-        if self.multi_task and config.get("learning_rate_dict", None) is not None:
-            self.lr_schedule = {}
-            for model_key in self.model_keys:
-                self.lr_schedule[model_key] = get_lr(
-                    config["learning_rate_dict"][model_key]
-                )
-        else:
-            self.lr_schedule = get_lr(config["learning_rate"])
+        self.lr_schedule = get_lr(config["learning_rate"])

         # JIT
         if JIT:
@@ -806,11 +796,7 @@ def step(_step_id: int, task_key: str = "Default") -> None:
             # Paddle Profiler
             if enable_profiling:
                 core.nvprof_nvtx_push(f"Training step {_step_id}")
-            if isinstance(self.lr_schedule, dict):
-                _lr = self.lr_schedule[task_key]
-            else:
-                _lr = self.lr_schedule
-            cur_lr = _lr.value(_step_id)
+            cur_lr = self.lr_schedule.value(_step_id)
             pref_lr = cur_lr

             with nvprof_context(enable_profiling, "Fetching data"):
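Taken together, the three hunks above collapse the per-task learning-rate machinery into one shared schedule: construction no longer branches on learning_rate_dict, and the step loop reads self.lr_schedule.value(_step_id) directly. Below is a minimal self-contained sketch of the resulting behavior; ExpLR and this get_lr are simplified stand-ins (assumptions, not the trainer's real classes) whose only promise is the .value(step) interface the diff relies on.

# Toy stand-in for the trainer's BaseLR subclasses; assumption: the real
# schedule exposes .value(step) the same way.
class ExpLR:
    def __init__(self, start_lr: float, stop_lr: float,
                 decay_steps: int, stop_steps: int) -> None:
        self.start_lr = start_lr
        self.decay_steps = decay_steps
        # Choose a decay rate so the schedule reaches stop_lr at stop_steps.
        self.decay_rate = (stop_lr / start_lr) ** (decay_steps / stop_steps)

    def value(self, step: int) -> float:
        return self.start_lr * self.decay_rate ** (step // self.decay_steps)

# Simplified stand-in for the trainer's get_lr factory.
def get_lr(lr_params: dict) -> ExpLR:
    return ExpLR(
        start_lr=lr_params["start_lr"],
        stop_lr=lr_params["stop_lr"],
        decay_steps=lr_params["decay_steps"],
        stop_steps=lr_params["stop_steps"],
    )

config = {
    "learning_rate": {
        "start_lr": 1.0e-3,
        "stop_lr": 1.0e-8,
        "decay_steps": 5000,
        "stop_steps": 100000,
    }
}

# One schedule for every task key; there is no dict of schedules to
# dispatch on, so the old isinstance(self.lr_schedule, dict) branch is gone.
lr_schedule = get_lr(config["learning_rate"])
print(f"lr at step 7500: {lr_schedule.value(7500):.3e}")  # 5.623e-04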
deepmd/pt/train/training.py (20 changes: 3 additions & 17 deletions)
@@ -365,10 +365,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
             self.loss = {}
             for model_key in self.model_keys:
                 loss_param = config["loss_dict"][model_key]
-                if config.get("learning_rate_dict", None) is not None:
-                    lr_param = config["learning_rate_dict"][model_key]["start_lr"]
-                else:
-                    lr_param = config["learning_rate"]["start_lr"]
+                lr_param = config["learning_rate"]["start_lr"]
                 ntypes = len(model_params["model_dict"][model_key]["type_map"])
                 self.loss[model_key] = get_loss(
                     loss_param, lr_param, ntypes, self.model[model_key]
@@ -548,14 +545,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:

         # Learning rate
         self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
-        if self.multi_task and config.get("learning_rate_dict", None) is not None:
-            self.lr_schedule = {}
-            for model_key in self.model_keys:
-                self.lr_schedule[model_key] = get_lr(
-                    config["learning_rate_dict"][model_key]
-                )
-        else:
-            self.lr_schedule = get_lr(config["learning_rate"])
+        self.lr_schedule = get_lr(config["learning_rate"])

         # JIT
         if JIT:
@@ -1027,11 +1017,7 @@ def step(_step_id: int, task_key: str = "Default") -> None:
             # PyTorch Profiler
             if self.enable_profiler or self.profiling:
                 prof.step()
-            if isinstance(self.lr_schedule, dict):
-                _lr = self.lr_schedule[task_key]
-            else:
-                _lr = self.lr_schedule
-            cur_lr = _lr.value(_step_id)
+            cur_lr = self.lr_schedule.value(_step_id)
             pref_lr = cur_lr
             self.optimizer.zero_grad(set_to_none=True)
             input_dict, label_dict, log_dict = self.get_data(
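The deepmd/pt hunks mirror the deepmd/pd ones line for line, so after this change a multi-task config is expected to carry exactly one top-level learning_rate block. A hedged sketch of that shape follows, with placeholder task names and elided loss contents; only the learning_rate keys are meant to match the documented exp-schedule options.

# Assumed post-change shape of a multi-task config: no per-key
# "learning_rate_dict"; one schedule drives every task.
config = {
    "loss_dict": {"water": {}, "ice": {}},  # per-task losses (contents elided)
    "learning_rate": {  # single shared schedule
        "type": "exp",
        "start_lr": 1.0e-3,
        "stop_lr": 1.0e-8,
        "decay_steps": 5000,
    },
}
assert "learning_rate_dict" not in config  # the removed branches keyed off this key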