3838from flax .linen import fp8_ops
3939from flax .linen import initializers as flax_initializers
4040import flax .linen as nn
41+ from flax import nnx
42+ # Support different packaging structures across environments even within
43+ # the same Qwix version identifier (imports from _src.utils vs _src).
44+ try :
45+ from qwix ._src .utils import flax_util
46+ except ImportError :
47+ from qwix ._src import flax_util # pytype: disable=import-error
48+ from maxtext .layers import nnx_wrappers
4149
4250from maxtext .common .common_types import DType , Config
4351from maxtext .inference .kvcache import KVQuant
@@ -710,6 +718,32 @@ def configure_kv_quant(config):
710718 return None if not config .quantize_kvcache else KVQuant (config )
711719
712720
721+ def _apply_linen_module_in_nnx (linen_module_cls , op_id , * args , ** kwargs ):
722+ """Applies a Linen module within an NNX context."""
723+ try :
724+ parent = flax_util .get_current_module ()
725+ is_nnx = isinstance (parent , nnx .Module )
726+ except ValueError :
727+ is_nnx = False
728+
729+ if is_nnx :
730+ attr_name = f"_qwix_fp8_gpu_{ op_id } "
731+ if not hasattr (parent , attr_name ):
732+ rngs = getattr (parent , "qwix_rngs" , None )
733+ if rngs is None :
734+ parent_rngs = getattr (parent , "rngs" , None )
735+ if parent_rngs is not None and hasattr (parent_rngs , "fork" ):
736+ rngs = parent_rngs .fork ()
737+ else :
738+ rngs = nnx .Rngs (0 )
739+ wrapper = nnx_wrappers .ToNNX (linen_module_cls (name = op_id ), rngs = rngs )
740+ wrapper .lazy_init (* args , ** kwargs )
741+ setattr (parent , attr_name , wrapper )
742+ return getattr (parent , attr_name )(* args , mutable = ["_overwrite_with_gradient" ], ** kwargs )
743+ else :
744+ return linen_module_cls (name = op_id )(* args , ** kwargs )
745+
746+
713747class NvidaFp8Provider (qwix .QtProvider ):
714748 """Wraps nn.Fp8DirectDotGeneralOp with Qwix's provider interface."""
715749
@@ -718,13 +752,13 @@ def dot_general(self, *args, **kwargs):
718752 rule , op_id = self ._get_current_rule_and_op_id ("dot_general" )
719753 if rule is None :
720754 return jax .lax .dot_general (* args , ** kwargs )
721- return nn .Fp8DirectDotGeneralOp ( name = op_id )( * args , ** kwargs )
755+ return _apply_linen_module_in_nnx ( nn .Fp8DirectDotGeneralOp , op_id , * args , ** kwargs )
722756
723757 def einsum (self , * args , ** kwargs ):
724758 rule , op_id = self ._get_current_rule_and_op_id ("einsum" )
725759 if rule is None :
726760 return jnp .einsum (* args , ** kwargs )
727- return nn .Fp8Einsum ( name = op_id )( * args , ** kwargs )
761+ return _apply_linen_module_in_nnx ( nn .Fp8Einsum , op_id , * args , ** kwargs )
728762
729763
730764class NANOOFp8Provider (qwix .QtProvider ):
@@ -734,7 +768,7 @@ def dot_general(self, *args, **kwargs):
734768 rule , op_id = self ._get_current_rule_and_op_id ("dot_general" )
735769 if rule is None :
736770 return jax .lax .dot_general (* args , ** kwargs )
737- return nn .NANOOFp8DotGeneralOp ( name = op_id )( * args , ** kwargs )
771+ return _apply_linen_module_in_nnx ( nn .NANOOFp8DotGeneralOp , op_id , * args , ** kwargs )
738772
739773
740774def get_fp8_full_qwix_rule_w_sparsity (config : Config ):
@@ -815,7 +849,21 @@ def maybe_quantize_model(model, config):
815849 if config .use_qwix_quantization and not config .use_batch_split_schedule :
816850 quantization_provider = get_qt_provider (config )
817851 if quantization_provider :
818- model = qwix .quantize_model (model , quantization_provider )
852+ if config .pure_nnx :
853+ input_shape = (config .micro_batch_size_to_train_on , config .max_target_length )
854+ dummy_tokens = jnp .ones (input_shape , dtype = jnp .int32 )
855+ dummy_positions = jnp .ones (input_shape , dtype = jnp .int32 )
856+ dummy_segment_ids = jnp .ones (input_shape , dtype = jnp .int32 )
857+ model = qwix .quantize_model (
858+ model ,
859+ quantization_provider ,
860+ dummy_tokens ,
861+ dummy_positions ,
862+ dummy_segment_ids ,
863+ enable_dropout = False ,
864+ )
865+ else :
866+ model = qwix .quantize_model (model , quantization_provider )
819867 return model
820868
821869
0 commit comments