@@ -387,17 +387,117 @@ def compare_fn(path, x, y):
387387
388388 jax .tree_util .tree_map_with_path (compare_fn , a , b )
389389
def quantization_config(self, quant, logits_tolerance=2e-1, grad_tolerance=5e-1, **kwargs):
    """Run forward pass and backward pass for quantized model and compare with base model.

    The config built from `quant` and `kwargs` selects one of two code paths via
    `cfg.pure_nnx`: an NNX path (model called directly, gradients from
    `nnx.grad`) or a Linen path (`init` / `apply`, gradients from `jax.grad`).
    In both paths the loss is the mean of the squared logits.

    Base-model gradients and logits are computed once and cached on the test
    class (`_cached_base_results_nnx` / `_cached_base_results`).
    NOTE(review): the cache is not keyed by `kwargs`, so a later call with
    different overrides reuses the first call's base results; the Linen base
    model is also built from `self.cfg`, not from `kwargs` -- confirm intended.

    Args:
      quant: value forwarded as `quantization=` to `self.init_pyconfig`.
      logits_tolerance: exclusive upper bound on
        mean(|quant_logits - base_logits|) / mean(|base_logits|).
      grad_tolerance: NNX path: inclusive upper bound on the per-leaf relative
        gradient error; Linen path: forwarded as `tolerance=` to
        `self.pytree_allclose`.
      **kwargs: extra overrides forwarded to `self.init_pyconfig`.

    Raises:
      AssertionError: if logits or gradients differ by more than the given
        tolerances, or (NNX path) if no gradient leaf could be compared.
    """
    # pylint: disable=protected-access
    cfg = self.init_pyconfig(quantization=quant, **kwargs)
    ids, decoder_segment_ids, decoder_positions = self.get_data()
    test_cls = self.__class__

    def mean_square(values):
        return jnp.mean(values**2)

    def check_logits(quant_logits, base_logits):
        # Computed once and reused for both the log line and the assertion.
        rel_err = jnp.abs(quant_logits - base_logits).mean() / jnp.abs(base_logits).mean()
        print(f"relative error in logits: {rel_err}")
        assert rel_err < logits_tolerance, (
            f"relative logits error {rel_err} is not below tolerance {logits_tolerance}"
        )

    if cfg.pure_nnx:

        def nnx_forward(model):
            return model(
                decoder_input_tokens=ids,
                decoder_positions=decoder_positions,
                decoder_segment_ids=decoder_segment_ids,
                enable_dropout=False,
            )

        def nnx_loss(model):
            return mean_square(nnx_forward(model))

        qt_model = model_creation_utils.create_model(cfg, self.mesh, rngs=nnx.Rngs(0))
        if getattr(test_cls, "_cached_base_results_nnx", None) is None:
            base_cfg = self.init_pyconfig(quantization="", **kwargs)
            base_model = model_creation_utils.create_model(base_cfg, self.mesh, rngs=nnx.Rngs(0))
            test_cls._cached_base_results_nnx = (
                nnx.grad(nnx_loss)(base_model),
                nnx_forward(base_model),
            )
        grads_base, logits = test_cls._cached_base_results_nnx

        grads_quant = nnx.grad(nnx_loss)(qt_model)
        quant_logits = nnx_forward(qt_model)
        check_logits(quant_logits, logits)

        # The gradients are nested mappings; flatten them to {path: leaf} so
        # the two models can be matched leaf by leaf.
        from flax.nnx import traversals  # pylint: disable=import-outside-toplevel

        def comparable_leaves(grads):
            # Keep array-like leaves only and drop quantization statistics,
            # which have no counterpart worth comparing in the base model.
            return {
                path: leaf
                for path, leaf in traversals.flatten_mapping(grads).items()
                if hasattr(leaf, "shape") and "quant_stats" not in str(path)
            }

        base_leaves = comparable_leaves(grads_base)
        quant_leaves = comparable_leaves(grads_quant)
        shared_paths = [path for path in base_leaves if path in quant_leaves]

        # Without this guard the loop below passes vacuously when the filter
        # rejects every leaf or the two trees share no path.
        assert shared_paths, (
            "no gradient leaves in common between base and quantized models; "
            "the gradient comparison would be vacuous"
        )

        for path in shared_paths:
            base_grad = base_leaves[path]
            # The epsilon keeps the ratio finite when the base gradient is all zeros.
            diff = jnp.abs(base_grad - quant_leaves[path]).mean() / (jnp.abs(base_grad).mean() + 1e-8)
            assert diff <= grad_tolerance, f"Gradient mismatch for {path}: rel_error = {diff}"
    else:
        # Positional order expected by the Linen model's init/apply here.
        inputs = (ids, decoder_positions, decoder_segment_ids)

        def linen_init(module):
            return module.init(
                {"params": self.rng, "aqt": self.rng, "dropout": self.rng},
                *inputs,
                enable_dropout=False,
                mutable=True,
            )

        def linen_logits(module, all_vars):
            # mutable=True makes apply return (outputs, mutated_collections).
            out, _ = module.apply(
                all_vars,
                *inputs,
                enable_dropout=False,
                rngs={"params": self.rng},
                mutable=True,
            )
            return out

        def linen_grads(module, all_vars):
            # Differentiate w.r.t. the full variable dict, as the callers index
            # the "params" collection out of the result.
            return jax.grad(lambda variables: mean_square(linen_logits(module, variables)))(all_vars)

        qt_model = model_creation_utils.create_model(cfg, self.mesh)
        if not hasattr(test_cls, "_cached_base_results"):
            model = model_creation_utils.create_model(self.cfg, self.mesh)
            var = linen_init(model)
            test_cls._cached_base_results = (linen_grads(model, var), linen_logits(model, var))
        grads_base, logits = test_cls._cached_base_results

        quantized_vars = linen_init(qt_model)
        grads_quant = linen_grads(qt_model, quantized_vars)
        quant_logits = linen_logits(qt_model, quantized_vars)
        check_logits(quant_logits, logits)

        self.print_grad_diff(grads_base["params"], grads_quant["params"])
        self.assertTrue(
            self.pytree_allclose(
                grads_base["params"],
                grads_quant["params"],
                tolerance=grad_tolerance,
            )
        )
474540
475541 @pytest .mark .tpu_only
476542 def test_int8_quantization (self ):
@@ -489,6 +555,11 @@ def test_fp8_full_quantization(self):
489555 def test_fp8_gpu_quantization (self ):
490556 self .quantization_config ("fp8_gpu" , grad_tolerance = 1.5 )
491557
# Named with an `_nnx` suffix so it does not rebind (and thereby silently
# disable) the Linen `test_fp8_gpu_quantization` defined just above it.
# TODO(review): the `gpu_only` marker is commented out here, unlike the
# fp8_nanoo test that follows -- confirm whether that is intentional.
# @pytest.mark.gpu_only
@pytest.mark.external_serving
def test_fp8_gpu_quantization_nnx(self):
    """Check fp8_gpu quantization against the base model on the pure-NNX path."""
    self.quantization_config(
        "fp8_gpu",
        grad_tolerance=1.5,
        enable_nnx=True,
        pure_nnx_decoder=True,
        pure_nnx=True,
    )
562+
492563 @pytest .mark .gpu_only
493564 @pytest .mark .external_serving
494565 def test_fp8_nanoo_quantization (self ):
0 commit comments