@@ -133,10 +133,9 @@ def _run_allgather_test(sharding_dims, recipe):
133133)
134134def test_fp8_fsdp2_allgather (sharding_dims , fp_recipe ):
135135 """Verify FSDP2 FP8 all-gather matches a manual dequantize-then-gather reference."""
136- if fp_recipe in ( "Float8BlockScaling" , " NVFP4BlockScaling") :
136+ if fp_recipe == " NVFP4BlockScaling" :
137137 pytest .xfail (
138- f"{ fp_recipe } : block-scaled quantization formats are not supported by the "
139- "FP8 FSDP2 all-gather correctness test."
138+ f"{ fp_recipe } : NVFP4 FSDP2 all-gather hooks need to be implemented."
140139 )
141140
142141 parallel_size = math .prod (x for x in sharding_dims if x != 0 )
@@ -181,20 +180,7 @@ def test_fsdp2_fused_adam_fp8_master_weights(fp_recipe):
181180
182181@pytest .mark .skipif (NUM_PROCS < 2 , reason = "Requires 2+ GPUs" )
183182def test_fsdp2_fused_adam_fp8_master_weights_no_meta (fp_recipe ):
184- """FusedAdam(master_weights=True) + FSDP2 + quantized_model_init (CUDA init, no meta device).
185-
186- Block-scaling QuantizedTensors (MXFP8, Float8Blockwise, NVFP4) are wrapper
187- subclasses with data_ptr() == 0. Without meta-device init, FSDP2's
188- reset_sharded_param() crashes with 'invalid python storage'.
189- Per-tensor FP8 (DelayedScaling, Float8CurrentScaling) works because
190- Float8Tensor's storage is accessible.
191- """
192- if fp_recipe in ("MXFP8BlockScaling" , "Float8BlockScaling" , "NVFP4BlockScaling" ):
193- pytest .xfail (
194- f"{ fp_recipe } : FSDP2 without meta-device init crashes on block-scaling "
195- "QuantizedTensor wrapper subclasses (data_ptr() == 0). "
196- "Use device='meta' + reset_parameters() after sharding."
197- )
183+ """FusedAdam(master_weights=True) + FSDP2 + quantized_model_init (CUDA init, no meta device)."""
198184 _run_fused_adam_test ("fused_adam_fp8_master_weights_no_meta" , fp_recipe )
199185
200186
@@ -232,8 +218,8 @@ def test_fsdp2_dcp_output_parity(fp_recipe):
232218
233219 if fp_recipe == "NVFP4BlockScaling" :
234220 pytest .xfail (
235- "NVFP4BlockScaling: DCP load_state_dict triggers reset_sharded_param() "
236- "which calls data_ptr() on NVFP4Tensor wrapper subclass with invalid storage "
221+ "NVFP4BlockScaling: Failing parity tests with DCP. Snippet: \n "
222+ "Fresh model loaded from DCP checkpoint produces different output. "
237223 )
238224
239225 if fp_recipe == "Float8BlockScaling" and torch .cuda .get_device_capability ()[0 ] == 12 :
@@ -261,8 +247,8 @@ def test_fsdp2_dcp_output_parity_async(fp_recipe):
261247
262248 if fp_recipe == "NVFP4BlockScaling" :
263249 pytest .xfail (
264- "NVFP4BlockScaling: DCP load_state_dict triggers reset_sharded_param() "
265- "which calls data_ptr() on NVFP4Tensor wrapper subclass with invalid storage "
250+ "NVFP4BlockScaling: Failing parity tests with DCP. Snippet: \n "
251+ "Fresh model loaded from DCP checkpoint produces different output. "
266252 )
267253
268254 if fp_recipe == "Float8BlockScaling" :
0 commit comments