@@ -710,8 +710,10 @@ def configure_kv_quant(config):
710710def _apply_linen_module_in_nnx (linen_module_cls , op_id , * args , ** kwargs ):
711711 try :
712712 from qwix ._src import flax_util
713+
713714 parent = flax_util .get_current_module ()
714715 from flax import nnx
716+
715717 is_nnx = isinstance (parent , nnx .Module )
716718 except Exception :
717719 is_nnx = False
@@ -720,6 +722,7 @@ def _apply_linen_module_in_nnx(linen_module_cls, op_id, *args, **kwargs):
720722 attr_name = f"_qwix_fp8_gpu_{ op_id } "
721723 if not hasattr (parent , attr_name ):
722724 from maxtext .layers import nnx_wrappers
725+
723726 rngs = getattr (parent , "qwix_rngs" , None )
724727 if rngs is None :
725728 rngs = nnx .Rngs (0 )
@@ -839,10 +842,13 @@ def maybe_quantize_model(model, config):
839842 if config .pure_nnx :
840843 input_shape = (config .micro_batch_size_to_train_on , config .max_target_length )
841844 import jax .numpy as jnp
845+
842846 dummy_tokens = jnp .ones (input_shape , dtype = jnp .int32 )
843847 dummy_positions = jnp .ones (input_shape , dtype = jnp .int32 )
844848 dummy_segment_ids = jnp .ones (input_shape , dtype = jnp .int32 )
845- model = qwix .quantize_model (model , quantization_provider , dummy_tokens , dummy_positions , dummy_segment_ids , enable_dropout = False )
849+ model = qwix .quantize_model (
850+ model , quantization_provider , dummy_tokens , dummy_positions , dummy_segment_ids , enable_dropout = False
851+ )
846852 else :
847853 model = qwix .quantize_model (model , quantization_provider )
848854 return model
0 commit comments