You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: deepmd/utils/argcheck.py
+3Lines changed: 3 additions & 0 deletions
Original file line number
Diff line number
Diff line change
@@ -3389,6 +3389,9 @@ def training_args(
3389
3389
Argument("use_pcgrad", bool, optional=True, default=False, doc="Apply PCGrad gradient surgery on the shared descriptor parameters in multi-task training."),
3390
3390
Argument("use_dual_batch", bool, optional=True, default=False, doc="Sample all tasks every step and sum gradients without projection. Use as control group to isolate PCGrad effect from dual-batch effect."),
3391
3391
Argument("alternating_tasks", bool, optional=True, default=False, doc="Cycle through tasks deterministically (A→B→A→B) each step instead of random sampling. Ablation control to isolate balanced-sampling effect from combined-gradient effect."),
3392
+
Argument("grad_norm_reweight", bool, optional=True, default=False, doc="(dual-batch only) Reweight per-task gradients inversely proportional to their EMA gradient norm before combining, equalizing each task's directional contribution to shared parameters."),
3393
+
Argument("loss_ratio_reweight", bool, optional=True, default=False, doc="(dual-batch only) Reweight per-task gradients proportional to their EMA loss value, giving more weight to the higher-loss task to prevent it from being sacrificed."),
3394
+
Argument("reweight_ema_decay", float, optional=True, default=0.99, doc="EMA decay factor for grad_norm_reweight and loss_ratio_reweight tracking. Higher values give smoother but slower-adapting estimates."),
0 commit comments