Skip to content

Commit 9a7e2a8

Browse files
committed
remove unused optim_dict
1 parent 42f9eec commit 9a7e2a8

2 files changed

Lines changed: 2 additions & 32 deletions

File tree

deepmd/pd/train/training.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -260,22 +260,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
260260
return lr_schedule
261261

262262
# Optimizer
263-
if self.multi_task and training_params.get("optim_dict", None) is not None:
264-
self.optim_dict = training_params.get("optim_dict")
265-
missing_keys = [
266-
key for key in self.model_keys if key not in self.optim_dict
267-
]
268-
assert not missing_keys, (
269-
f"These keys are not in optim_dict: {missing_keys}!"
270-
)
271-
self.opt_type = {}
272-
self.opt_param = {}
273-
for model_key in self.model_keys:
274-
self.opt_type[model_key], self.opt_param[model_key] = get_opt_param(
275-
self.optim_dict[model_key]
276-
)
277-
else:
278-
self.opt_type, self.opt_param = get_opt_param(optimizer_params)
263+
self.opt_type, self.opt_param = get_opt_param(optimizer_params)
279264

280265
# loss_param_tmp for Hessian activation
281266
loss_param_tmp = None

deepmd/pt/train/training.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -305,22 +305,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
305305
return lr_schedule
306306

307307
# Optimizer
308-
if self.multi_task and training_params.get("optim_dict", None) is not None:
309-
self.optim_dict = training_params.get("optim_dict")
310-
missing_keys = [
311-
key for key in self.model_keys if key not in self.optim_dict
312-
]
313-
assert not missing_keys, (
314-
f"These keys are not in optim_dict: {missing_keys}!"
315-
)
316-
self.opt_type = {}
317-
self.opt_param = {}
318-
for model_key in self.model_keys:
319-
self.opt_type[model_key], self.opt_param[model_key] = get_opt_param(
320-
self.optim_dict[model_key]
321-
)
322-
else:
323-
self.opt_type, self.opt_param = get_opt_param(optimizer_params)
308+
self.opt_type, self.opt_param = get_opt_param(optimizer_params)
324309
if self.zero_stage > 0 and self.multi_task:
325310
raise ValueError(
326311
"training.zero_stage is currently only supported in single-task training."

0 commit comments

Comments
 (0)