Skip to content

Commit 2993617

Browse files
test(quant): Add NNX versions of Qwix quantization unit tests
1 parent 7c1bf78 commit 2993617

4 files changed

Lines changed: 52 additions & 14 deletions

File tree

src/maxtext/layers/nnx_wrappers.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -499,7 +499,16 @@ def maybe_unbox(x):
499499
for path, _ in unknown_state_flat.items():
500500
paths_str += f"\n - {'/'.join(map(str, path))}"
501501

502-
warnings.warn(f"Found unknown module paths in incoming state:{paths_str}")
502+
# Dynamically reconstruct the unknown variables
503+
curr = module
504+
for p in path[:-1]:
505+
if not hasattr(curr, p):
506+
setattr(curr, p, nnx.Module())
507+
curr = getattr(curr, p)
508+
509+
warnings.warn(
510+
f"Found unknown module paths in incoming state:{paths_str}. Intermediate modules have been reconstructed."
511+
)
503512

504513
nnx.update(module, new_state)
505514
_refresh_variable_trace_state(module)

src/maxtext/layers/quantizations.py

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -38,9 +38,13 @@
3838
from flax.linen import fp8_ops
3939
from flax.linen import initializers as flax_initializers
4040
import flax.linen as nn
41+
from flax import nnx
42+
43+
from qwix._src import flax_util
4144

4245
from maxtext.common.common_types import DType, Config
4346
from maxtext.inference.kvcache import KVQuant
47+
from maxtext.layers import nnx_wrappers
4448

4549
# Params used to define mixed precision quantization configs
4650
DEFAULT = "__default__" # default config
@@ -708,18 +712,16 @@ def configure_kv_quant(config):
708712

709713

710714
def _apply_linen_module_in_nnx(linen_module_cls, op_id, *args, **kwargs):
715+
"""Applies a Linen module within an NNX context."""
711716
try:
712-
from qwix._src import flax_util
713717
parent = flax_util.get_current_module()
714-
from flax import nnx
715718
is_nnx = isinstance(parent, nnx.Module)
716-
except Exception:
719+
except Exception: # pylint: disable=broad-exception-caught
717720
is_nnx = False
718721

719722
if is_nnx:
720723
attr_name = f"_qwix_fp8_gpu_{op_id}"
721724
if not hasattr(parent, attr_name):
722-
from maxtext.layers import nnx_wrappers
723725
rngs = getattr(parent, "qwix_rngs", None)
724726
if rngs is None:
725727
rngs = nnx.Rngs(0)
@@ -838,11 +840,17 @@ def maybe_quantize_model(model, config):
838840
if quantization_provider:
839841
if config.pure_nnx:
840842
input_shape = (config.micro_batch_size_to_train_on, config.max_target_length)
841-
import jax.numpy as jnp
842843
dummy_tokens = jnp.ones(input_shape, dtype=jnp.int32)
843844
dummy_positions = jnp.ones(input_shape, dtype=jnp.int32)
844845
dummy_segment_ids = jnp.ones(input_shape, dtype=jnp.int32)
845-
model = qwix.quantize_model(model, quantization_provider, dummy_tokens, dummy_positions, dummy_segment_ids, enable_dropout=False)
846+
model = qwix.quantize_model(
847+
model,
848+
quantization_provider,
849+
dummy_tokens,
850+
dummy_positions,
851+
dummy_segment_ids,
852+
enable_dropout=False,
853+
)
846854
else:
847855
model = qwix.quantize_model(model, quantization_provider)
848856
return model

src/maxtext/trainers/pre_train/train.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -349,7 +349,9 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
349349
is_train=True,
350350
)
351351
else:
352-
OverwriteWithGradient = nnx.variablelib.variable_type_from_name(maxtext_utils.OVERWRITE_WITH_GRADIENT, allow_register=True)
352+
OverwriteWithGradient = nnx.variablelib.variable_type_from_name(
353+
maxtext_utils.OVERWRITE_WITH_GRADIENT, allow_register=True
354+
)
353355
model_graphdef, curr_params, overwrite_vars, rest = nnx.split(state.model, nnx.Param, OverwriteWithGradient, ...)
354356
if config.parameter_memory_host_offload:
355357
# Params are kept on host (pinned_host) in in_shardings. Move only Param
@@ -379,7 +381,9 @@ def diff_wrapper(param, overwrite_vars, rest, config, data):
379381
return loss, (aux, new_overwrite_vars, new_rest)
380382

381383
grad_func = jax.value_and_grad(diff_wrapper, argnums=(0, 1), has_aux=True)
382-
(loss, (aux, new_overwrite_vars, new_rest)), (raw_grads, overwrite_grads) = grad_func(curr_params, overwrite_vars, rest, config, data)
384+
(loss, (aux, _, new_rest)), (raw_grads, overwrite_grads) = grad_func(
385+
curr_params, overwrite_vars, rest, config, data
386+
)
383387
nnx.update(state.model, new_rest)
384388
nnx.update(state.model, overwrite_grads)
385389

tests/unit/quantizations_test.py

Lines changed: 22 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
from aqt.jax.v2 import aqt_tensor
2323
from aqt.jax.v2.flax import aqt_flax
2424
from flax import nnx
25+
from flax.nnx import traversals
2526
import jax
2627
from jax import lax
2728
from jax import numpy as jnp
@@ -48,7 +49,7 @@ def __init__(
4849
self,
4950
quantization: quantizations.AqtQuantization,
5051
data_type: Any,
51-
rngs: nnx.Rngs,
52+
rngs: nnx.Rngs, # pylint: disable=unused-argument
5253
):
5354
self.quantization = quantization
5455
self.identity = jnp.identity(2, dtype=data_type)
@@ -441,10 +442,9 @@ def loss_quant(model):
441442

442443
# nnx.grad returns a State object which is a mapping of paths to gradients.
443444
# Flatten them to check for tolerance.
444-
from flax.nnx import traversals
445445
grads_base_flat = traversals.flatten_mapping(grads_base)
446446
grads_quant_flat = traversals.flatten_mapping(grads_quant)
447-
447+
448448
# Filter for param collections to compare only parameters and not stats/buffers if any
449449
# Note: NNX grads structure might contain variables like 'kernel', 'bias'.
450450
# For simplicity we compare all matching keys.
@@ -542,29 +542,46 @@ def loss_quant_linen(all_vars, inputs):
542542
def test_int8_quantization(self):
543543
self.quantization_config("int8")
544544

545+
@pytest.mark.tpu_only
546+
def test_int8_quantization_nnx(self):
547+
self.quantization_config("int8", enable_nnx=True, pure_nnx_decoder=True, pure_nnx=True)
548+
545549
@pytest.mark.tpu_only
546550
def test_fp8_quantization(self):
547551
self.quantization_config("fp8")
548552

553+
@pytest.mark.tpu_only
554+
def test_fp8_quantization_nnx(self):
555+
self.quantization_config("fp8", enable_nnx=True, pure_nnx_decoder=True, pure_nnx=True)
556+
549557
@pytest.mark.tpu_only
550558
def test_fp8_full_quantization(self):
551559
self.quantization_config("fp8_full")
552560

561+
@pytest.mark.tpu_only
562+
def test_fp8_full_quantization_nnx(self):
563+
self.quantization_config("fp8_full", enable_nnx=True, pure_nnx_decoder=True, pure_nnx=True)
564+
553565
@pytest.mark.gpu_only
554566
@pytest.mark.external_serving
555567
def test_fp8_gpu_quantization(self):
556568
self.quantization_config("fp8_gpu", grad_tolerance=1.5)
557569

558-
# @pytest.mark.gpu_only
570+
@pytest.mark.gpu_only
559571
@pytest.mark.external_serving
560-
def test_fp8_gpu_quantization(self):
572+
def test_fp8_gpu_quantization_nnx(self):
561573
self.quantization_config("fp8_gpu", grad_tolerance=1.5, enable_nnx=True, pure_nnx_decoder=True, pure_nnx=True)
562574

563575
@pytest.mark.gpu_only
564576
@pytest.mark.external_serving
565577
def test_fp8_nanoo_quantization(self):
566578
self.quantization_config("fp8_nanoo", grad_tolerance=1.5)
567579

580+
@pytest.mark.gpu_only
581+
@pytest.mark.external_serving
582+
def test_fp8_nanoo_quantization_nnx(self):
583+
self.quantization_config("fp8_nanoo", grad_tolerance=1.5, enable_nnx=True, pure_nnx_decoder=True, pure_nnx=True)
584+
568585
@pytest.mark.skip(reason="No runner with GPU arch >= 89 is available")
569586
@pytest.mark.gpu_only
570587
def test_fp8_te_fp8_delayedscaling_quantization(self):

0 commit comments

Comments
 (0)