@@ -575,26 +575,30 @@ def _setup(self):
             expert.linear_fc2.parallel_state = self.parallel_state
 
     def layer_sync_moe_local_experts_amax(self):
-        """Sync amax across local experts in a SequentialMLP.
+        """Sync input quantizer amax across local experts in a SequentialMLP.
 
-        Synchronize the amax values across local experts in a lyaer such that all local experts will
-        share the same amax. This function operates on a single rank and does not require distributed sync.
+        Ensures all experts have the same input quantizer amax. This function operates
+        on a single rank and does not require distributed sync.
 
         Distributed amax sync across EP and ETP (for RowParallel) happens in model_calib.max_calibrate().
         This function should be called before the distributed sync to ensure the amax values
         are synchronized across the layer first.
 
         Note:
             Because there are logic which calls collective communication based on whether amax is not None,
-            We need to garuantee that all experts must have amax. Otherwise, there will be deadlock
-            when synchroizing over EP since some ranks may have amax None and not calling the collective
+            We need to guarantee that all experts must have amax. Otherwise, there will be deadlock
+            when synchronizing over EP since some ranks may have amax None and not calling the collective
             communication.
         """
         # Collect amax from all local experts
         amax_dict = {}
         for expert in self.local_experts:
             for name, module in expert.named_modules():
-                if isinstance(module, TensorQuantizer) and module.amax is not None:
+                if (
+                    isinstance(module, TensorQuantizer)
+                    and module.amax is not None
+                    and name == "input_quantizer"
+                ):
                     stored_amax = amax_dict.get(name)
                     amax_tensor = module.amax.detach().clone()
                     amax_dict[name] = (
@@ -606,7 +610,11 @@ def layer_sync_moe_local_experts_amax(self):
         # Apply synchronized amax values back to all local experts
         for expert in self.local_experts:
             for name, module in expert.named_modules():
-                if isinstance(module, TensorQuantizer) and name in amax_dict:
+                if (
+                    isinstance(module, TensorQuantizer)
+                    and name in amax_dict
+                    and name == "input_quantizer"
+                ):
                     module.amax = amax_dict[name].detach().clone()
 
     def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
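
For reference, the collect-then-apply pattern in this diff can be illustrated with a minimal standalone sketch. This is not ModelOpt's implementation: FakeQuantizer and the ModuleDict expert layout below are invented stand-ins for TensorQuantizer and the SequentialMLP local experts; only the two-pass max-reduction logic mirrors the code above.

import torch

class FakeQuantizer(torch.nn.Module):
    """Stand-in for TensorQuantizer: holds a per-tensor amax calibration value."""

    def __init__(self, amax):
        super().__init__()
        self.amax = torch.tensor(amax)

def sync_local_experts_amax(local_experts):
    # Pass 1: collect the element-wise max of each named quantizer's amax
    # across all local experts.
    amax_dict = {}
    for expert in local_experts:
        for name, module in expert.named_modules():
            if isinstance(module, FakeQuantizer) and module.amax is not None:
                stored = amax_dict.get(name)
                amax = module.amax.detach().clone()
                amax_dict[name] = amax if stored is None else torch.maximum(stored, amax)
    # Pass 2: write the shared max back so every expert quantizes with the same
    # scale and no expert is left with a None amax (avoiding the EP deadlock
    # described in the docstring).
    for expert in local_experts:
        for name, module in expert.named_modules():
            if isinstance(module, FakeQuantizer) and name in amax_dict:
                module.amax = amax_dict[name].detach().clone()

experts = [torch.nn.ModuleDict({"input_quantizer": FakeQuantizer(a)}) for a in (1.0, 3.0, 2.0)]
sync_local_experts_amax(experts)
print([e["input_quantizer"].amax.item() for e in experts])  # -> [3.0, 3.0, 3.0]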