We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 56af3e5 commit fe6cc8cCopy full SHA for fe6cc8c
1 file changed
deepmd/pd/train/training.py
@@ -130,6 +130,7 @@ def __init__(
130
else 1
131
)
132
self.num_model = len(self.model_keys)
133
+ self.model_prob = None
134
135
# Iteration config
136
self.num_steps = training_params.get("numb_steps")
@@ -749,6 +750,14 @@ def single_model_finetune(
749
750
frz_model = paddle.jit.load(init_frz_model)
751
self.model.set_state_dict(frz_model.state_dict())
752
753
+ # Get model prob for multi-task
754
+ if self.multi_task and self.model_prob is None:
755
+ self.model_prob = resolve_model_prob(
756
+ self.model_keys,
757
+ training_params.get("model_prob"),
758
+ training_data,
759
+ )
760
+
761
# Multi-task share params
762
if shared_links is not None:
763
self.wrapper.share_params(
0 commit comments