Skip to content

Commit fe6cc8c

Browse files
committed
fix pd
1 parent 56af3e5 commit fe6cc8c

1 file changed

Lines changed: 9 additions & 0 deletions

File tree

deepmd/pd/train/training.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def __init__(
130130
else 1
131131
)
132132
self.num_model = len(self.model_keys)
133+
self.model_prob = None
133134

134135
# Iteration config
135136
self.num_steps = training_params.get("numb_steps")
@@ -749,6 +750,14 @@ def single_model_finetune(
749750
frz_model = paddle.jit.load(init_frz_model)
750751
self.model.set_state_dict(frz_model.state_dict())
751752

753+
# Get model prob for multi-task
754+
if self.multi_task and self.model_prob is None:
755+
self.model_prob = resolve_model_prob(
756+
self.model_keys,
757+
training_params.get("model_prob"),
758+
training_data,
759+
)
760+
752761
# Multi-task share params
753762
if shared_links is not None:
754763
self.wrapper.share_params(

0 commit comments

Comments
 (0)