@@ -642,8 +642,12 @@ def single_model_finetune(
642642 )
643643 self .reweight_ema_decay = training_params .get ("reweight_ema_decay" , 0.99 )
644644 if self .use_grad_norm_reweight or self .use_loss_ratio_reweight :
645- self .grad_norm_ema = {k : None for k in self .model_keys }
646- self .loss_val_ema = {k : None for k in self .model_keys }
645+ # Per-component EMA norms: descriptor and shared fitting_net tracked separately
646+ self .grad_norm_ema_desc = {k : None for k in self .model_keys }
647+ self .grad_norm_ema_fit = {k : None for k in self .model_keys }
648+ # Two-speed EMA for loss: relative rate = fast/slow, scale-invariant
649+ self .loss_val_ema_fast = {k : None for k in self .model_keys }
650+ self .loss_val_ema_slow = {k : None for k in self .model_keys }
647651 if self .use_pcgrad :
648652 log .info ("PCGrad enabled: descriptor gradients will be projected each step." )
649653 elif self .use_dual_batch :
@@ -865,68 +869,113 @@ def step(_step_id, task_key="Default") -> None:
865869 _step_id + 1 , dot .item (), cos_sim .item (), projected ,
866870 )
867871
868- # Compute per-task weights for gradient combination
869- task_weights = {k : 1.0 for k in self .model_keys }
870- d = self .reweight_ema_decay
871-
872- if self .use_grad_norm_reweight :
873- for k in self .model_keys :
874- grads = [
875- task_grads [k ][id (p )]
876- for p in all_params
877- if task_grads [k ][id (p )] is not None
878- ]
879- if grads :
880- cur_norm = torch .stack (
881- [g .norm () for g in grads ]
882- ).norm ().item ()
883- if self .grad_norm_ema [k ] is None :
884- self .grad_norm_ema [k ] = cur_norm
885- else :
886- self .grad_norm_ema [k ] = (
887- d * self .grad_norm_ema [k ]
888- + (1 - d ) * cur_norm
889- )
890- task_weights [k ] /= self .grad_norm_ema [k ] + 1e-8
872+ # Gradient combination: per-component reweighting when active,
873+ # otherwise simple equal average for shared params
874+ self .optimizer .zero_grad (set_to_none = True )
875+ num_tasks = len (self .model_keys )
891876
892- if self .use_loss_ratio_reweight :
893- for k in self .model_keys :
894- cur_loss_val = task_loss_val [k ]
895- if self .loss_val_ema [k ] is None :
896- self .loss_val_ema [k ] = cur_loss_val
897- else :
898- self .loss_val_ema [k ] = (
899- d * self .loss_val_ema [k ]
900- + (1 - d ) * cur_loss_val
901- )
902- task_weights [k ] *= self .loss_val_ema [k ]
903-
904- # Normalize weights to sum to 1 (keeps gradient scale on par with baseline)
905- w_sum = sum (task_weights .values ())
906- for k in task_weights :
907- task_weights [k ] /= w_sum
908-
909- if self .rank == 0 and (
910- self .use_grad_norm_reweight or self .use_loss_ratio_reweight
911- ) and (_step_id + 1 ) % self .disp_freq == 0 :
912- weight_str = ", " .join (
913- f"{ k } ={ task_weights [k ]:.4f} " for k in self .model_keys
914- )
915- log .info (
916- "Reweight step %d: [%s]" , _step_id + 1 , weight_str
877+ if self .use_grad_norm_reweight or self .use_loss_ratio_reweight :
878+ d = self .reweight_ema_decay
879+ d_slow = 1.0 - (1.0 - d ) / 10.0
880+
881+ _module = (
882+ self .wrapper .module
883+ if hasattr (self .wrapper , "module" )
884+ else self .wrapper
917885 )
886+ _desc_param_ids = {
887+ id (p ) for p in _module .model [k0 ].get_descriptor ().parameters ()
888+ }
889+ _shared_pids = {
890+ id (p ) for p in all_params
891+ if task_grads [k0 ].get (id (p )) is not None
892+ and task_grads [k1 ].get (id (p )) is not None
893+ }
894+ _fit_shared_pids = _shared_pids - _desc_param_ids
895+
896+ if self .use_grad_norm_reweight :
897+ for k in self .model_keys :
898+ desc_g = [
899+ task_grads [k ][pid ]
900+ for pid in (_shared_pids & _desc_param_ids )
901+ if task_grads [k ].get (pid ) is not None
902+ ]
903+ if desc_g :
904+ cur = torch .stack ([g .norm () for g in desc_g ]).norm ().item ()
905+ if self .grad_norm_ema_desc [k ] is None :
906+ self .grad_norm_ema_desc [k ] = cur
907+ else :
908+ self .grad_norm_ema_desc [k ] = d * self .grad_norm_ema_desc [k ] + (1 - d ) * cur
909+ fit_g = [
910+ task_grads [k ][pid ]
911+ for pid in _fit_shared_pids
912+ if task_grads [k ].get (pid ) is not None
913+ ]
914+ if fit_g :
915+ cur = torch .stack ([g .norm () for g in fit_g ]).norm ().item ()
916+ if self .grad_norm_ema_fit [k ] is None :
917+ self .grad_norm_ema_fit [k ] = cur
918+ else :
919+ self .grad_norm_ema_fit [k ] = d * self .grad_norm_ema_fit [k ] + (1 - d ) * cur
920+
921+ if self .use_loss_ratio_reweight :
922+ for k in self .model_keys :
923+ v = task_loss_val [k ]
924+ if self .loss_val_ema_fast [k ] is None :
925+ self .loss_val_ema_fast [k ] = v
926+ self .loss_val_ema_slow [k ] = v
927+ else :
928+ self .loss_val_ema_fast [k ] = d * self .loss_val_ema_fast [k ] + (1 - d ) * v
929+ self .loss_val_ema_slow [k ] = d_slow * self .loss_val_ema_slow [k ] + (1 - d_slow ) * v
930+
931+ # Build per-component normalized weights
932+ def _w (norm_ema ):
933+ return {k : 1.0 / (norm_ema [k ] + 1e-8 ) if norm_ema [k ] is not None else 1.0
934+ for k in self .model_keys }
935+
936+ dw = _w (self .grad_norm_ema_desc ) if self .use_grad_norm_reweight else {k : 1.0 for k in self .model_keys }
937+ fw = _w (self .grad_norm_ema_fit ) if self .use_grad_norm_reweight else {k : 1.0 for k in self .model_keys }
938+
939+ if self .use_loss_ratio_reweight :
940+ for k in self .model_keys :
941+ rel = (self .loss_val_ema_fast [k ] / (self .loss_val_ema_slow [k ] + 1e-8 )
942+ if self .loss_val_ema_fast [k ] is not None else 1.0 )
943+ dw [k ] *= rel
944+ fw [k ] *= rel
945+
946+ dw_sum = sum (dw .values ())
947+ fw_sum = sum (fw .values ())
948+ desc_w = {k : dw [k ] / dw_sum for k in self .model_keys }
949+ fit_w = {k : fw [k ] / fw_sum for k in self .model_keys }
950+
951+ if self .rank == 0 and (_step_id + 1 ) % self .disp_freq == 0 :
952+ log .info (
953+ "Reweight step %d: desc=[%s] fit=[%s]" ,
954+ _step_id + 1 ,
955+ ", " .join (f"{ k } ={ desc_w [k ]:.4f} " for k in self .model_keys ),
956+ ", " .join (f"{ k } ={ fit_w [k ]:.4f} " for k in self .model_keys ),
957+ )
918958
919- # Set final grads: weighted average for shared params, full grad for exclusive params
920- self .optimizer .zero_grad (set_to_none = True )
921- for p in all_params :
922- pid = id (p )
923- g0p , g1p = task_grads [k0 ][pid ], task_grads [k1 ][pid ]
924- if g0p is not None and g1p is not None :
925- p .grad = task_weights [k0 ] * g0p + task_weights [k1 ] * g1p
926- elif g0p is not None :
927- p .grad = g0p
928- elif g1p is not None :
929- p .grad = g1p
959+ for p in all_params :
960+ pid = id (p )
961+ g0p , g1p = task_grads [k0 ][pid ], task_grads [k1 ][pid ]
962+ if g0p is not None and g1p is not None :
963+ w0 , w1 = (desc_w [k0 ], desc_w [k1 ]) if pid in _desc_param_ids else (fit_w [k0 ], fit_w [k1 ])
964+ p .grad = w0 * g0p + w1 * g1p
965+ elif g0p is not None :
966+ p .grad = g0p
967+ elif g1p is not None :
968+ p .grad = g1p
969+ else :
970+ for p in all_params :
971+ pid = id (p )
972+ g0p , g1p = task_grads [k0 ][pid ], task_grads [k1 ][pid ]
973+ if g0p is not None and g1p is not None :
974+ p .grad = (g0p + g1p ) / num_tasks
975+ elif g0p is not None :
976+ p .grad = g0p
977+ elif g1p is not None :
978+ p .grad = g1p
930979
931980 more_loss = more_loss_per_task [task_key ]
932981 else :
0 commit comments