Skip to content

Commit 043586d

Browse files
committed
feat: ema norm scaling
1 parent 9e0f0bf commit 043586d

2 files changed

Lines changed: 94 additions & 12 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 91 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -634,10 +634,32 @@ def single_model_finetune(
634634
not self.use_dual_batch
635635
and training_params.get("alternating_tasks", False)
636636
)
637+
self.use_grad_norm_reweight = self.use_dual_batch and training_params.get(
638+
"grad_norm_reweight", False
639+
)
640+
self.use_loss_ratio_reweight = self.use_dual_batch and training_params.get(
641+
"loss_ratio_reweight", False
642+
)
643+
self.reweight_ema_decay = training_params.get("reweight_ema_decay", 0.99)
644+
if self.use_grad_norm_reweight or self.use_loss_ratio_reweight:
645+
self.grad_norm_ema = {k: None for k in self.model_keys}
646+
self.loss_val_ema = {k: None for k in self.model_keys}
637647
if self.use_pcgrad:
638648
log.info("PCGrad enabled: descriptor gradients will be projected each step.")
639649
elif self.use_dual_batch:
640-
log.info("Dual-batch enabled: all tasks sampled per step, gradients summed without projection.")
650+
reweight_modes = []
651+
if self.use_grad_norm_reweight:
652+
reweight_modes.append("grad-norm-EMA")
653+
if self.use_loss_ratio_reweight:
654+
reweight_modes.append("loss-ratio")
655+
if reweight_modes:
656+
log.info(
657+
"Dual-batch enabled with reweighting: %s (ema_decay=%.3f).",
658+
"+".join(reweight_modes),
659+
self.reweight_ema_decay,
660+
)
661+
else:
662+
log.info("Dual-batch enabled: all tasks sampled per step, gradients averaged.")
641663
elif self.use_alternating:
642664
log.info("Alternating-tasks enabled: tasks cycled deterministically A→B→A→B each step.")
643665

@@ -769,13 +791,14 @@ def step(_step_id, task_key="Default") -> None:
769791
cur_lr = _lr.value(_step_id)
770792
pref_lr = cur_lr
771793
self.optimizer.zero_grad(set_to_none=True)
772-
input_dict, label_dict, log_dict = self.get_data(
773-
is_train=True, task_key=task_key
774-
)
775-
if SAMPLER_RECORD:
776-
print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n"
777-
fout1.write(print_str)
778-
fout1.flush()
794+
if not self.use_dual_batch:
795+
input_dict, label_dict, log_dict = self.get_data(
796+
is_train=True, task_key=task_key
797+
)
798+
if SAMPLER_RECORD:
799+
print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n"
800+
fout1.write(print_str)
801+
fout1.flush()
779802
if self.opt_type in ["Adam", "AdamW"]:
780803
cur_lr = self.scheduler.get_last_lr()[0]
781804
if _step_id < self.warmup_steps:
@@ -786,9 +809,14 @@ def step(_step_id, task_key="Default") -> None:
786809
all_params = list(self.wrapper.parameters())
787810
task_grads = {}
788811
more_loss_per_task = {}
812+
task_loss_val = {}
789813
for tk in self.model_keys:
790814
self.optimizer.zero_grad(set_to_none=True)
791-
in_d, lbl_d, _ = self.get_data(is_train=True, task_key=tk)
815+
in_d, lbl_d, log_d = self.get_data(is_train=True, task_key=tk)
816+
if SAMPLER_RECORD and tk == self.model_keys[0]:
817+
print_str = f"Step {_step_id}: sample system{log_d['sid']} frame{log_d['fid']}\n"
818+
fout1.write(print_str)
819+
fout1.flush()
792820
_, loss, more_loss = self.wrapper(
793821
**in_d, cur_lr=pref_lr, label=lbl_d, task_key=tk
794822
)
@@ -798,6 +826,7 @@ def step(_step_id, task_key="Default") -> None:
798826
for p in all_params
799827
}
800828
more_loss_per_task[tk] = more_loss
829+
task_loss_val[tk] = loss.item()
801830

802831
k0, k1 = self.model_keys[0], self.model_keys[1]
803832

@@ -836,14 +865,64 @@ def step(_step_id, task_key="Default") -> None:
836865
_step_id + 1, dot.item(), cos_sim.item(), projected,
837866
)
838867

839-
# Set final grads: sum of (projected) per-task grads
868+
# Compute per-task weights for gradient combination
869+
task_weights = {k: 1.0 for k in self.model_keys}
870+
d = self.reweight_ema_decay
871+
872+
if self.use_grad_norm_reweight:
873+
for k in self.model_keys:
874+
grads = [
875+
task_grads[k][id(p)]
876+
for p in all_params
877+
if task_grads[k][id(p)] is not None
878+
]
879+
if grads:
880+
cur_norm = torch.stack(
881+
[g.norm() for g in grads]
882+
).norm().item()
883+
if self.grad_norm_ema[k] is None:
884+
self.grad_norm_ema[k] = cur_norm
885+
else:
886+
self.grad_norm_ema[k] = (
887+
d * self.grad_norm_ema[k]
888+
+ (1 - d) * cur_norm
889+
)
890+
task_weights[k] /= self.grad_norm_ema[k] + 1e-8
891+
892+
if self.use_loss_ratio_reweight:
893+
for k in self.model_keys:
894+
cur_loss_val = task_loss_val[k]
895+
if self.loss_val_ema[k] is None:
896+
self.loss_val_ema[k] = cur_loss_val
897+
else:
898+
self.loss_val_ema[k] = (
899+
d * self.loss_val_ema[k]
900+
+ (1 - d) * cur_loss_val
901+
)
902+
task_weights[k] *= self.loss_val_ema[k]
903+
904+
# Normalize weights to sum to 1 (keeps gradient scale on par with baseline)
905+
w_sum = sum(task_weights.values())
906+
for k in task_weights:
907+
task_weights[k] /= w_sum
908+
909+
if self.rank == 0 and (
910+
self.use_grad_norm_reweight or self.use_loss_ratio_reweight
911+
) and (_step_id + 1) % self.disp_freq == 0:
912+
weight_str = ", ".join(
913+
f"{k}={task_weights[k]:.4f}" for k in self.model_keys
914+
)
915+
log.info(
916+
"Reweight step %d: [%s]", _step_id + 1, weight_str
917+
)
918+
919+
# Set final grads: weighted average for shared params, full grad for exclusive params
840920
self.optimizer.zero_grad(set_to_none=True)
841-
num_tasks = len(self.model_keys)
842921
for p in all_params:
843922
pid = id(p)
844923
g0p, g1p = task_grads[k0][pid], task_grads[k1][pid]
845924
if g0p is not None and g1p is not None:
846-
p.grad = (g0p + g1p) / num_tasks
925+
p.grad = task_weights[k0] * g0p + task_weights[k1] * g1p
847926
elif g0p is not None:
848927
p.grad = g0p
849928
elif g1p is not None:

deepmd/utils/argcheck.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3389,6 +3389,9 @@ def training_args(
33893389
Argument("use_pcgrad", bool, optional=True, default=False, doc="Apply PCGrad gradient surgery on the shared descriptor parameters in multi-task training."),
33903390
Argument("use_dual_batch", bool, optional=True, default=False, doc="Sample all tasks every step and sum gradients without projection. Use as control group to isolate PCGrad effect from dual-batch effect."),
33913391
Argument("alternating_tasks", bool, optional=True, default=False, doc="Cycle through tasks deterministically (A→B→A→B) each step instead of random sampling. Ablation control to isolate balanced-sampling effect from combined-gradient effect."),
3392+
Argument("grad_norm_reweight", bool, optional=True, default=False, doc="(dual-batch only) Reweight per-task gradients inversely proportional to their EMA gradient norm before combining, equalizing each task's directional contribution to shared parameters."),
3393+
Argument("loss_ratio_reweight", bool, optional=True, default=False, doc="(dual-batch only) Reweight per-task gradients proportional to their EMA loss value, giving more weight to the higher-loss task to prevent it from being sacrificed."),
3394+
Argument("reweight_ema_decay", float, optional=True, default=0.99, doc="EMA decay factor for grad_norm_reweight and loss_ratio_reweight tracking. Higher values give smoother but slower-adapting estimates."),
33923395
Argument("data_dict", dict, data_args, repeat=True, doc=doc_data_dict),
33933396
]
33943397
)

0 commit comments

Comments
 (0)