Skip to content

Commit 412d4b7

Browse files
committed
fix: norm only use descript
1 parent 043586d commit 412d4b7

1 file changed

Lines changed: 110 additions & 61 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 110 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -642,8 +642,12 @@ def single_model_finetune(
642642
)
643643
self.reweight_ema_decay = training_params.get("reweight_ema_decay", 0.99)
644644
if self.use_grad_norm_reweight or self.use_loss_ratio_reweight:
645-
self.grad_norm_ema = {k: None for k in self.model_keys}
646-
self.loss_val_ema = {k: None for k in self.model_keys}
645+
# Per-component EMA norms: descriptor and shared fitting_net tracked separately
646+
self.grad_norm_ema_desc = {k: None for k in self.model_keys}
647+
self.grad_norm_ema_fit = {k: None for k in self.model_keys}
648+
# Two-speed EMA for loss: relative rate = fast/slow, scale-invariant
649+
self.loss_val_ema_fast = {k: None for k in self.model_keys}
650+
self.loss_val_ema_slow = {k: None for k in self.model_keys}
647651
if self.use_pcgrad:
648652
log.info("PCGrad enabled: descriptor gradients will be projected each step.")
649653
elif self.use_dual_batch:
@@ -865,68 +869,113 @@ def step(_step_id, task_key="Default") -> None:
865869
_step_id + 1, dot.item(), cos_sim.item(), projected,
866870
)
867871

868-
# Compute per-task weights for gradient combination
869-
task_weights = {k: 1.0 for k in self.model_keys}
870-
d = self.reweight_ema_decay
871-
872-
if self.use_grad_norm_reweight:
873-
for k in self.model_keys:
874-
grads = [
875-
task_grads[k][id(p)]
876-
for p in all_params
877-
if task_grads[k][id(p)] is not None
878-
]
879-
if grads:
880-
cur_norm = torch.stack(
881-
[g.norm() for g in grads]
882-
).norm().item()
883-
if self.grad_norm_ema[k] is None:
884-
self.grad_norm_ema[k] = cur_norm
885-
else:
886-
self.grad_norm_ema[k] = (
887-
d * self.grad_norm_ema[k]
888-
+ (1 - d) * cur_norm
889-
)
890-
task_weights[k] /= self.grad_norm_ema[k] + 1e-8
872+
# Gradient combination: per-component reweighting when active,
873+
# otherwise simple equal average for shared params
874+
self.optimizer.zero_grad(set_to_none=True)
875+
num_tasks = len(self.model_keys)
891876

892-
if self.use_loss_ratio_reweight:
893-
for k in self.model_keys:
894-
cur_loss_val = task_loss_val[k]
895-
if self.loss_val_ema[k] is None:
896-
self.loss_val_ema[k] = cur_loss_val
897-
else:
898-
self.loss_val_ema[k] = (
899-
d * self.loss_val_ema[k]
900-
+ (1 - d) * cur_loss_val
901-
)
902-
task_weights[k] *= self.loss_val_ema[k]
903-
904-
# Normalize weights to sum to 1 (keeps gradient scale on par with baseline)
905-
w_sum = sum(task_weights.values())
906-
for k in task_weights:
907-
task_weights[k] /= w_sum
908-
909-
if self.rank == 0 and (
910-
self.use_grad_norm_reweight or self.use_loss_ratio_reweight
911-
) and (_step_id + 1) % self.disp_freq == 0:
912-
weight_str = ", ".join(
913-
f"{k}={task_weights[k]:.4f}" for k in self.model_keys
914-
)
915-
log.info(
916-
"Reweight step %d: [%s]", _step_id + 1, weight_str
877+
if self.use_grad_norm_reweight or self.use_loss_ratio_reweight:
878+
d = self.reweight_ema_decay
879+
d_slow = 1.0 - (1.0 - d) / 10.0
880+
881+
_module = (
882+
self.wrapper.module
883+
if hasattr(self.wrapper, "module")
884+
else self.wrapper
917885
)
886+
_desc_param_ids = {
887+
id(p) for p in _module.model[k0].get_descriptor().parameters()
888+
}
889+
_shared_pids = {
890+
id(p) for p in all_params
891+
if task_grads[k0].get(id(p)) is not None
892+
and task_grads[k1].get(id(p)) is not None
893+
}
894+
_fit_shared_pids = _shared_pids - _desc_param_ids
895+
896+
if self.use_grad_norm_reweight:
897+
for k in self.model_keys:
898+
desc_g = [
899+
task_grads[k][pid]
900+
for pid in (_shared_pids & _desc_param_ids)
901+
if task_grads[k].get(pid) is not None
902+
]
903+
if desc_g:
904+
cur = torch.stack([g.norm() for g in desc_g]).norm().item()
905+
if self.grad_norm_ema_desc[k] is None:
906+
self.grad_norm_ema_desc[k] = cur
907+
else:
908+
self.grad_norm_ema_desc[k] = d * self.grad_norm_ema_desc[k] + (1 - d) * cur
909+
fit_g = [
910+
task_grads[k][pid]
911+
for pid in _fit_shared_pids
912+
if task_grads[k].get(pid) is not None
913+
]
914+
if fit_g:
915+
cur = torch.stack([g.norm() for g in fit_g]).norm().item()
916+
if self.grad_norm_ema_fit[k] is None:
917+
self.grad_norm_ema_fit[k] = cur
918+
else:
919+
self.grad_norm_ema_fit[k] = d * self.grad_norm_ema_fit[k] + (1 - d) * cur
920+
921+
if self.use_loss_ratio_reweight:
922+
for k in self.model_keys:
923+
v = task_loss_val[k]
924+
if self.loss_val_ema_fast[k] is None:
925+
self.loss_val_ema_fast[k] = v
926+
self.loss_val_ema_slow[k] = v
927+
else:
928+
self.loss_val_ema_fast[k] = d * self.loss_val_ema_fast[k] + (1 - d) * v
929+
self.loss_val_ema_slow[k] = d_slow * self.loss_val_ema_slow[k] + (1 - d_slow) * v
930+
931+
# Build per-component normalized weights
932+
def _w(norm_ema):
933+
return {k: 1.0 / (norm_ema[k] + 1e-8) if norm_ema[k] is not None else 1.0
934+
for k in self.model_keys}
935+
936+
dw = _w(self.grad_norm_ema_desc) if self.use_grad_norm_reweight else {k: 1.0 for k in self.model_keys}
937+
fw = _w(self.grad_norm_ema_fit) if self.use_grad_norm_reweight else {k: 1.0 for k in self.model_keys}
938+
939+
if self.use_loss_ratio_reweight:
940+
for k in self.model_keys:
941+
rel = (self.loss_val_ema_fast[k] / (self.loss_val_ema_slow[k] + 1e-8)
942+
if self.loss_val_ema_fast[k] is not None else 1.0)
943+
dw[k] *= rel
944+
fw[k] *= rel
945+
946+
dw_sum = sum(dw.values())
947+
fw_sum = sum(fw.values())
948+
desc_w = {k: dw[k] / dw_sum for k in self.model_keys}
949+
fit_w = {k: fw[k] / fw_sum for k in self.model_keys}
950+
951+
if self.rank == 0 and (_step_id + 1) % self.disp_freq == 0:
952+
log.info(
953+
"Reweight step %d: desc=[%s] fit=[%s]",
954+
_step_id + 1,
955+
", ".join(f"{k}={desc_w[k]:.4f}" for k in self.model_keys),
956+
", ".join(f"{k}={fit_w[k]:.4f}" for k in self.model_keys),
957+
)
918958

919-
# Set final grads: weighted average for shared params, full grad for exclusive params
920-
self.optimizer.zero_grad(set_to_none=True)
921-
for p in all_params:
922-
pid = id(p)
923-
g0p, g1p = task_grads[k0][pid], task_grads[k1][pid]
924-
if g0p is not None and g1p is not None:
925-
p.grad = task_weights[k0] * g0p + task_weights[k1] * g1p
926-
elif g0p is not None:
927-
p.grad = g0p
928-
elif g1p is not None:
929-
p.grad = g1p
959+
for p in all_params:
960+
pid = id(p)
961+
g0p, g1p = task_grads[k0][pid], task_grads[k1][pid]
962+
if g0p is not None and g1p is not None:
963+
w0, w1 = (desc_w[k0], desc_w[k1]) if pid in _desc_param_ids else (fit_w[k0], fit_w[k1])
964+
p.grad = w0 * g0p + w1 * g1p
965+
elif g0p is not None:
966+
p.grad = g0p
967+
elif g1p is not None:
968+
p.grad = g1p
969+
else:
970+
for p in all_params:
971+
pid = id(p)
972+
g0p, g1p = task_grads[k0][pid], task_grads[k1][pid]
973+
if g0p is not None and g1p is not None:
974+
p.grad = (g0p + g1p) / num_tasks
975+
elif g0p is not None:
976+
p.grad = g0p
977+
elif g1p is not None:
978+
p.grad = g1p
930979

931980
more_loss = more_loss_per_task[task_key]
932981
else:

0 commit comments

Comments
 (0)