@@ -59,46 +59,9 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
5959 # 2-3. Split + export each per-expert projection.
6060 fused_dim0 = gate_up .shape [1 ] # 2 * expert_dim
6161
62- def _safe_amax (quantizer_src : nn .Module ) -> torch .Tensor | None :
63- """Return _amax as a clean tensor, surfacing any latent CUDA error first.
64-
65- Layerwise calibration's _save_layer + full_restore can leave the per-expert
66- ``_amax`` as a CUDA tensor reconstructed from a serialized view with non-zero
67- storage offset. Touching it directly (``torch.all`` / ``deepcopy``) then triggers
68- ``cudaErrorIllegalAddress``. We synchronize first to surface any pending error,
69- then return the tensor on its original device. Falling back to CPU only on the
70- error path avoids creating a device mismatch with sibling buffers
71- (``_global_amax``) that stayed on the original device.
72- """
73- amax = getattr (quantizer_src , "_amax" , None )
74- if amax is None or not isinstance (amax , torch .Tensor ):
75- return None
76- try :
77- if amax .is_cuda :
78- torch .cuda .synchronize (amax .device )
79- # Force a no-op read to trigger any latent async error.
80- _ = amax .shape
81- return amax .detach ()
82- except Exception :
83- # CUDA tensor was unreadable. Try to recover a CPU copy; if that
84- # also fails, treat as uncalibrated.
85- try :
86- return amax .detach ().cpu ().float ()
87- except Exception :
88- return None
89-
9062 for idx in range (n ):
9163 expert = nn .Module ()
9264
93- # Pre-extract both per-expert amaxes to CPU *before* the projection loop's
94- # deepcopy. deepcopy calls .clone() on CUDA tensors — if the stored _amax
95- # has corrupt storage (under-calibrated experts after layerwise calib), the
96- # clone triggers an async CUDA illegal-memory-access error. Synchronizing in
97- # _safe_amax surfaces the error here so subsequent operations work on
98- # safe CPU float32 tensors.
99- gu_amax = _safe_amax (module .gate_up_proj_weight_quantizers [idx ])
100- down_amax = _safe_amax (module .down_proj_weight_quantizers [idx ])
101-
10265 # If the gate_up source quantizer was never calibrated (rare expert
10366 # that received no calibration tokens), derive its amax once from the
10467 # FUSED tensor so gate and up share the same weight_scale_2 below.
@@ -109,11 +72,11 @@ def _safe_amax(quantizer_src: nn.Module) -> torch.Tensor | None:
10972 # mismatched weight_scale_2 and garbled MoE output at inference.
11073 gate_up_q = module .gate_up_proj_weight_quantizers [idx ]
11174 if getattr (gate_up_q , "is_enabled" , False ) and (
112- gu_amax is None or bool (torch .all (gu_amax == 0 ))
75+ not hasattr (gate_up_q , "_amax" )
76+ or gate_up_q ._amax is None
77+ or torch .all (gate_up_q ._amax == 0 )
11378 ):
11479 gate_up_q .amax = gate_up [idx ].abs ().amax ().to (torch .float32 )
115- # Refresh the CPU amax we'll inject below.
116- gu_amax = _safe_amax (gate_up_q )
11780 warnings .warn (
11881 f"Expert { idx } gate_up_proj weight quantizer was not calibrated "
11982 f"(amax missing or zero). Using fused-tensor amax as fallback "
@@ -137,23 +100,7 @@ def _safe_amax(quantizer_src: nn.Module) -> torch.Tensor | None:
137100 i_quantizer = gate_up_input_q if is_gate_up else down_input_q
138101
139102 # gate/up share a weight quantizer — clone so each gets independent amax.
140- # Null _amax on source before deepcopy so the (possibly corrupt) CUDA tensor
141- # is never cloned; restore afterwards for the sibling projection. The CPU
142- # amax we pre-extracted gets injected in its place.
143- if is_gate_up :
144- _saved_amax = getattr (w_quantizer_src , "_amax" , None )
145- try :
146- w_quantizer_src ._amax = None
147- w_quantizer = copy .deepcopy (w_quantizer_src )
148- finally :
149- w_quantizer_src ._amax = _saved_amax
150- if gu_amax is not None :
151- w_quantizer ._amax = gu_amax
152- else :
153- w_quantizer = w_quantizer_src
154- if down_amax is not None :
155- # Replace any CUDA-resident _amax with the safe CPU copy.
156- w_quantizer ._amax = down_amax
103+ w_quantizer = copy .deepcopy (w_quantizer_src ) if is_gate_up else w_quantizer_src
157104
158105 # For per-channel amax (dim >= 1), proportionally slice dim-0
159106 # to match the split weight.
@@ -162,7 +109,7 @@ def _safe_amax(quantizer_src: nn.Module) -> torch.Tensor | None:
162109 and w_quantizer ._amax is not None
163110 and w_quantizer ._amax .dim () >= 1
164111 ):
165- amax = w_quantizer ._amax # safe-extracted via _safe_amax (CUDA or CPU, recovered if corrupt)
112+ amax = w_quantizer ._amax
166113 # Per-block _amax (NVFP4 static) collapses the row axis we want
167114 # to slice on; restore it so dim-0 slicing splits gate/up.
168115 if amax .numel () != fused_total and amax .numel () % fused_total == 0 :
@@ -185,14 +132,13 @@ def _safe_amax(quantizer_src: nn.Module) -> torch.Tensor | None:
185132 )
186133
187134 # If the weight quantizer was never calibrated, compute amax from weights.
188- # All amax tests below operate on the safe CPU tensor injected above.
189135 if (
190136 hasattr (w_quantizer , "is_enabled" )
191137 and w_quantizer .is_enabled
192138 and (
193139 not hasattr (w_quantizer , "_amax" )
194140 or w_quantizer ._amax is None
195- or bool ( torch .all (w_quantizer ._amax == 0 ) )
141+ or torch .all (w_quantizer ._amax == 0 )
196142 )
197143 ):
198144 w_quantizer .amax = weight_slice .abs ().amax ().to (torch .float32 )
@@ -203,23 +149,6 @@ def _safe_amax(quantizer_src: nn.Module) -> torch.Tensor | None:
203149 stacklevel = 2 ,
204150 )
205151
206- # Align _amax and global_amax with the weight slice's device. The
207- # export math ``per_block_scale * 448 / per_block_scale_max`` reads
208- # both from the quantizer and would otherwise error if they drifted
209- # apart (e.g., CPU-offloaded big-model layers + CUDA-resident weight
210- # slice, or our CPU-injected _amax + the original CUDA global_amax).
211- # No magnitude floor — that's main's policy for the uncalibrated
212- # fallback below.
213- if (
214- hasattr (w_quantizer , "_amax" )
215- and w_quantizer ._amax is not None
216- ):
217- target_device = weight_slice .device
218- if w_quantizer ._amax .device != target_device :
219- w_quantizer ._amax = w_quantizer ._amax .to (target_device )
220- if hasattr (w_quantizer , "global_amax" ):
221- w_quantizer .global_amax = w_quantizer ._amax .float ().amax ()
222-
223152 wrapper = nn .Module ()
224153 wrapper .weight = nn .Parameter (weight_slice .contiguous (), requires_grad = False )
225154 wrapper .weight_quantizer = w_quantizer
0 commit comments