|
38 | 38 | from flax.linen import fp8_ops |
39 | 39 | from flax.linen import initializers as flax_initializers |
40 | 40 | import flax.linen as nn |
| 41 | +from flax import nnx |
| 42 | + |
| 43 | +from qwix._src import flax_util |
41 | 44 |
|
42 | 45 | from maxtext.common.common_types import DType, Config |
43 | 46 | from maxtext.inference.kvcache import KVQuant |
| 47 | +from maxtext.layers import nnx_wrappers |
44 | 48 |
|
45 | 49 | # Params used to define mixed precision quantization configs |
46 | 50 | DEFAULT = "__default__" # default config |
@@ -708,18 +712,16 @@ def configure_kv_quant(config): |
708 | 712 |
|
709 | 713 |
|
710 | 714 | def _apply_linen_module_in_nnx(linen_module_cls, op_id, *args, **kwargs): |
| 715 | + """Applies a Linen module within an NNX context.""" |
711 | 716 | try: |
712 | | - from qwix._src import flax_util |
713 | 717 | parent = flax_util.get_current_module() |
714 | | - from flax import nnx |
715 | 718 | is_nnx = isinstance(parent, nnx.Module) |
716 | | - except Exception: |
| 719 | + except Exception: # pylint: disable=broad-exception-caught |
717 | 720 | is_nnx = False |
718 | 721 |
|
719 | 722 | if is_nnx: |
720 | 723 | attr_name = f"_qwix_fp8_gpu_{op_id}" |
721 | 724 | if not hasattr(parent, attr_name): |
722 | | - from maxtext.layers import nnx_wrappers |
723 | 725 | rngs = getattr(parent, "qwix_rngs", None) |
724 | 726 | if rngs is None: |
725 | 727 | rngs = nnx.Rngs(0) |
@@ -838,11 +840,17 @@ def maybe_quantize_model(model, config): |
838 | 840 | if quantization_provider: |
839 | 841 | if config.pure_nnx: |
840 | 842 | input_shape = (config.micro_batch_size_to_train_on, config.max_target_length) |
841 | | - import jax.numpy as jnp |
842 | 843 | dummy_tokens = jnp.ones(input_shape, dtype=jnp.int32) |
843 | 844 | dummy_positions = jnp.ones(input_shape, dtype=jnp.int32) |
844 | 845 | dummy_segment_ids = jnp.ones(input_shape, dtype=jnp.int32) |
845 | | - model = qwix.quantize_model(model, quantization_provider, dummy_tokens, dummy_positions, dummy_segment_ids, enable_dropout=False) |
| 846 | + model = qwix.quantize_model( |
| 847 | + model, |
| 848 | + quantization_provider, |
| 849 | + dummy_tokens, |
| 850 | + dummy_positions, |
| 851 | + dummy_segment_ids, |
| 852 | + enable_dropout=False, |
| 853 | + ) |
846 | 854 | else: |
847 | 855 | model = qwix.quantize_model(model, quantization_provider) |
848 | 856 | return model |
|
0 commit comments