Skip to content

Commit 7c1bf78

Browse files
Support Qwix quantization on NNX
1 parent 7c68a9d commit 7c1bf78

4 files changed

Lines changed: 174 additions & 84 deletions

File tree

src/maxtext/layers/nnx_decoders.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -671,22 +671,7 @@ def layer_fn(carry, scanned_vars):
671671
params = nnx_ensure_scan_leading_axis(params, length)
672672
state = nnx_ensure_scan_leading_axis(state, length)
673673

674-
# Linen FP8 ops keep amax_history in mutable Linen scope; jax.lax.scan
675-
# leaks the tracer and hits UnexpectedTracerError. Use a Python for-loop
676-
# for FP8 instead.
677-
uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu")
678-
if uses_linen_fp8_mutable_state:
679-
carry = x_in
680-
per_layer_states = []
681-
for i in range(length):
682-
current_params = jax.tree.map(lambda x, i=i: x[i], params)
683-
current_state = jax.tree.map(lambda x, i=i: x[i], state)
684-
carry, new_state_i = layer_fn(carry, (current_params, current_state))
685-
per_layer_states.append(new_state_i)
686-
final_carry = carry
687-
scanned_state = jax.tree.map(lambda *xs: jnp.stack(list(xs)), *per_layer_states)
688-
else:
689-
final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
674+
final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
690675
returned_kv_stacked = None
691676

692677
if scan_axis != 0:

src/maxtext/layers/quantizations.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,30 @@ def configure_kv_quant(config):
707707
return None if not config.quantize_kvcache else KVQuant(config)
708708

709709

710+
def _current_nnx_parent():
  """Return the NNX module Qwix is currently intercepting, or None.

  Returns:
    The object reported by ``qwix._src.flax_util.get_current_module()`` when it
    is an ``nnx.Module``; ``None`` when it is anything else, or when the lookup
    (including the imports it needs) fails for any reason.
  """
  try:
    # Imported lazily so that a missing/incompatible qwix or flax only disables
    # the NNX path instead of breaking module import.
    from flax import nnx  # pylint: disable=import-outside-toplevel
    from qwix._src import flax_util  # pylint: disable=import-outside-toplevel

    parent = flax_util.get_current_module()
    return parent if isinstance(parent, nnx.Module) else None
  except Exception:  # pylint: disable=broad-exception-caught
    # Deliberate best-effort: any failure to resolve the current module means
    # "not running under NNX", and the caller uses the plain Linen path.
    return None


def _apply_linen_module_in_nnx(linen_module_cls, op_id, *args, **kwargs):
  """Apply a Linen op module, bridging it into NNX when the caller is an NNX module.

  Args:
    linen_module_cls: Linen module class to instantiate; it is constructed as
      ``linen_module_cls(name=op_id)``.
    op_id: Operation identifier. Used as the Linen module name and as the
      suffix of the attribute under which the NNX wrapper is cached.
    *args: Positional arguments forwarded to the op (and to ``lazy_init``).
    **kwargs: Keyword arguments forwarded to the op (and to ``lazy_init``).

  Returns:
    Whatever the underlying op call returns.
  """
  parent = _current_nnx_parent()
  if parent is None:
    # Linen (or undetectable) context: instantiate and call the op directly.
    return linen_module_cls(name=op_id)(*args, **kwargs)

  # The wrapper is stored on the parent module so it is built only once per
  # op_id and reused on subsequent calls.
  attr_name = f"_qwix_fp8_gpu_{op_id}"
  if not hasattr(parent, attr_name):
    from flax import nnx  # pylint: disable=import-outside-toplevel
    from maxtext.layers import nnx_wrappers  # pylint: disable=import-outside-toplevel

    # Prefer RNGs provided on the parent; otherwise fall back to a fixed seed.
    rngs = getattr(parent, "qwix_rngs", None)
    if rngs is None:
      rngs = nnx.Rngs(0)
    wrapper = nnx_wrappers.ToNNX(linen_module_cls(name=op_id), rngs=rngs)
    wrapper.lazy_init(*args, **kwargs)
    setattr(parent, attr_name, wrapper)

  # NOTE(review): presumably "_overwrite_with_gradient" is the Linen collection
  # holding the FP8 scale/amax state that must stay writable - confirm.
  return getattr(parent, attr_name)(*args, mutable=["_overwrite_with_gradient"], **kwargs)
732+
733+
710734
class NvidaFp8Provider(qwix.QtProvider):
711735
"""Wraps nn.Fp8DirectDotGeneralOp with Qwix's provider interface."""
712736

@@ -715,13 +739,13 @@ def dot_general(self, *args, **kwargs):
715739
rule, op_id = self._get_current_rule_and_op_id("dot_general")
716740
if rule is None:
717741
return jax.lax.dot_general(*args, **kwargs)
718-
return nn.Fp8DirectDotGeneralOp(name=op_id)(*args, **kwargs)
742+
return _apply_linen_module_in_nnx(nn.Fp8DirectDotGeneralOp, op_id, *args, **kwargs)
719743

720744
def einsum(self, *args, **kwargs):
721745
rule, op_id = self._get_current_rule_and_op_id("einsum")
722746
if rule is None:
723747
return jnp.einsum(*args, **kwargs)
724-
return nn.Fp8Einsum(name=op_id)(*args, **kwargs)
748+
return _apply_linen_module_in_nnx(nn.Fp8Einsum, op_id, *args, **kwargs)
725749

726750

727751
class NANOOFp8Provider(qwix.QtProvider):
@@ -731,7 +755,7 @@ def dot_general(self, *args, **kwargs):
731755
rule, op_id = self._get_current_rule_and_op_id("dot_general")
732756
if rule is None:
733757
return jax.lax.dot_general(*args, **kwargs)
734-
return nn.NANOOFp8DotGeneralOp(name=op_id)(*args, **kwargs)
758+
return _apply_linen_module_in_nnx(nn.NANOOFp8DotGeneralOp, op_id, *args, **kwargs)
735759

736760

737761
def get_fp8_full_qwix_rule_w_sparsity(config: Config):
@@ -812,7 +836,15 @@ def maybe_quantize_model(model, config):
812836
if config.use_qwix_quantization and not config.use_batch_split_schedule:
813837
quantization_provider = get_qt_provider(config)
814838
if quantization_provider:
815-
model = qwix.quantize_model(model, quantization_provider)
839+
if config.pure_nnx:
840+
input_shape = (config.micro_batch_size_to_train_on, config.max_target_length)
841+
import jax.numpy as jnp
842+
dummy_tokens = jnp.ones(input_shape, dtype=jnp.int32)
843+
dummy_positions = jnp.ones(input_shape, dtype=jnp.int32)
844+
dummy_segment_ids = jnp.ones(input_shape, dtype=jnp.int32)
845+
model = qwix.quantize_model(model, quantization_provider, dummy_tokens, dummy_positions, dummy_segment_ids, enable_dropout=False)
846+
else:
847+
model = qwix.quantize_model(model, quantization_provider)
816848
return model
817849

818850

src/maxtext/trainers/pre_train/train.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,8 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
349349
is_train=True,
350350
)
351351
else:
352-
model_graphdef, curr_params, rest = nnx.split(state.model, nnx.Param, ...)
352+
OverwriteWithGradient = nnx.variablelib.variable_type_from_name(maxtext_utils.OVERWRITE_WITH_GRADIENT, allow_register=True)
353+
model_graphdef, curr_params, overwrite_vars, rest = nnx.split(state.model, nnx.Param, OverwriteWithGradient, ...)
353354
if config.parameter_memory_host_offload:
354355
# Params are kept on host (pinned_host) in in_shardings. Move only Param
355356
# variables to device before the forward/backward pass so that all dot_general
@@ -371,15 +372,16 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
371372
)
372373
nnx.update(state.model, curr_params)
373374

374-
def diff_wrapper(param, rest, config, data):
375-
local_model = nnx.merge(model_graphdef, param, rest, copy=True)
375+
def diff_wrapper(param, overwrite_vars, rest, config, data):
376+
local_model = nnx.merge(model_graphdef, param, overwrite_vars, rest, copy=True)
376377
loss, aux = loss_fn(local_model, config, data, None, None, is_train=True)
377-
_, _, new_rest = nnx.split(local_model, nnx.Param, ...)
378-
return loss, (aux, new_rest)
378+
_, _, new_overwrite_vars, new_rest = nnx.split(local_model, nnx.Param, OverwriteWithGradient, ...)
379+
return loss, (aux, new_overwrite_vars, new_rest)
379380

380-
grad_func = jax.value_and_grad(diff_wrapper, argnums=0, has_aux=True)
381-
(loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, config, data)
381+
grad_func = jax.value_and_grad(diff_wrapper, argnums=(0, 1), has_aux=True)
382+
(loss, (aux, new_overwrite_vars, new_rest)), (raw_grads, overwrite_grads) = grad_func(curr_params, overwrite_vars, rest, config, data)
382383
nnx.update(state.model, new_rest)
384+
nnx.update(state.model, overwrite_grads)
383385

384386
raw_grads = jax.tree_util.tree_map(
385387
lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x,

tests/unit/quantizations_test.py

Lines changed: 128 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -387,17 +387,117 @@ def compare_fn(path, x, y):
387387

388388
jax.tree_util.tree_map_with_path(compare_fn, a, b)
389389

390-
def quantization_config(self, quant, logits_tolerance=2e-1, grad_tolerance=5e-1):
390+
def quantization_config(self, quant, logits_tolerance=2e-1, grad_tolerance=5e-1, **kwargs):
391391
"""Run forward pass and backward pass for quantized model and compare with base model."""
392392
# pylint: disable=protected-access
393-
cfg = self.init_pyconfig(quantization=quant)
394-
qt_model = model_creation_utils.create_model(cfg, self.mesh)
395-
393+
cfg = self.init_pyconfig(quantization=quant, **kwargs)
396394
ids, decoder_segment_ids, decoder_positions = self.get_data()
397395

398-
if not hasattr(self.__class__, "_cached_base_results"):
399-
model = model_creation_utils.create_model(self.cfg, self.mesh)
400-
var = model.init(
396+
if cfg.pure_nnx:
397+
qt_model = model_creation_utils.create_model(cfg, self.mesh, rngs=nnx.Rngs(0))
398+
if getattr(self.__class__, "_cached_base_results_nnx", None) is None:
399+
base_cfg = self.init_pyconfig(quantization="", **kwargs)
400+
base_model = model_creation_utils.create_model(base_cfg, self.mesh, rngs=nnx.Rngs(0))
401+
402+
def loss_base(model):
403+
logits = model(
404+
decoder_input_tokens=ids,
405+
decoder_positions=decoder_positions,
406+
decoder_segment_ids=decoder_segment_ids,
407+
enable_dropout=False,
408+
)
409+
return jnp.mean((logits) ** 2)
410+
411+
grads_base = nnx.grad(loss_base)(base_model)
412+
logits_base = base_model(
413+
decoder_input_tokens=ids,
414+
decoder_positions=decoder_positions,
415+
decoder_segment_ids=decoder_segment_ids,
416+
enable_dropout=False,
417+
)
418+
self.__class__._cached_base_results_nnx = (grads_base, logits_base)
419+
420+
grads_base, logits = self.__class__._cached_base_results_nnx
421+
422+
def loss_quant(model):
423+
logits_q = model(
424+
decoder_input_tokens=ids,
425+
decoder_positions=decoder_positions,
426+
decoder_segment_ids=decoder_segment_ids,
427+
enable_dropout=False,
428+
)
429+
return jnp.mean((logits_q) ** 2)
430+
431+
grads_quant = nnx.grad(loss_quant)(qt_model)
432+
quant_logits = qt_model(
433+
decoder_input_tokens=ids,
434+
decoder_positions=decoder_positions,
435+
decoder_segment_ids=decoder_segment_ids,
436+
enable_dropout=False,
437+
)
438+
439+
print("relative error in logits:" f" {jnp.abs(quant_logits - logits).mean() / jnp.abs(logits).mean()}")
440+
assert jnp.abs(quant_logits - logits).mean() / jnp.abs(logits).mean() < logits_tolerance
441+
442+
# nnx.grad returns a State object which is a mapping of paths to gradients.
443+
# Flatten them to check for tolerance.
444+
from flax.nnx import traversals
445+
grads_base_flat = traversals.flatten_mapping(grads_base)
446+
grads_quant_flat = traversals.flatten_mapping(grads_quant)
447+
448+
# Filter for param collections to compare only parameters and not stats/buffers if any
449+
# Note: NNX grads structure might contain variables like 'kernel', 'bias'.
450+
# For simplicity we compare all matching keys.
451+
def flatten_and_filter(grads_flat):
452+
return {k: v for k, v in grads_flat.items() if hasattr(v, "shape") and "quant_stats" not in str(k)}
453+
454+
gb_f = flatten_and_filter(grads_base_flat)
455+
gq_f = flatten_and_filter(grads_quant_flat)
456+
457+
for k in gb_f:
458+
if k in gq_f:
459+
diff = jnp.abs(gb_f[k] - gq_f[k]).mean() / (jnp.abs(gb_f[k]).mean() + 1e-8)
460+
if diff > grad_tolerance:
461+
print(f"Gradient mismatch for {k}: rel_error = {diff}")
462+
assert diff <= grad_tolerance
463+
else:
464+
qt_model = model_creation_utils.create_model(cfg, self.mesh)
465+
if not hasattr(self.__class__, "_cached_base_results"):
466+
model = model_creation_utils.create_model(self.cfg, self.mesh)
467+
var = model.init(
468+
{"params": self.rng, "aqt": self.rng, "dropout": self.rng},
469+
ids,
470+
decoder_positions,
471+
decoder_segment_ids,
472+
enable_dropout=False,
473+
mutable=True,
474+
)
475+
476+
def loss_base_linen(all_vars, inputs):
477+
logits_b, _ = model.apply(
478+
all_vars,
479+
*inputs,
480+
enable_dropout=False,
481+
rngs={"params": self.rng},
482+
mutable=True,
483+
)
484+
return jnp.mean((logits_b) ** 2)
485+
486+
grads_base_linen = jax.grad(loss_base_linen)(var, (ids, decoder_positions, decoder_segment_ids))
487+
logits_b, _ = model.apply(
488+
var,
489+
ids,
490+
decoder_positions,
491+
decoder_segment_ids,
492+
enable_dropout=False,
493+
rngs={"params": self.rng},
494+
mutable=True,
495+
)
496+
self.__class__._cached_base_results = (grads_base_linen, logits_b)
497+
498+
grads_base_linen, logits = self.__class__._cached_base_results
499+
500+
quantized_vars = qt_model.init(
401501
{"params": self.rng, "aqt": self.rng, "dropout": self.rng},
402502
ids,
403503
decoder_positions,
@@ -406,71 +506,37 @@ def quantization_config(self, quant, logits_tolerance=2e-1, grad_tolerance=5e-1)
406506
mutable=True,
407507
)
408508

409-
def loss_base(all_vars, inputs):
410-
logits, _ = model.apply(
509+
def loss_quant_linen(all_vars, inputs):
510+
logits_q, _ = qt_model.apply(
411511
all_vars,
412512
*inputs,
413513
enable_dropout=False,
414514
rngs={"params": self.rng},
415515
mutable=True,
416516
)
417-
return jnp.mean((logits) ** 2)
517+
return jnp.mean((logits_q) ** 2)
518+
519+
grads_quant_linen = jax.grad(loss_quant_linen)(quantized_vars, (ids, decoder_positions, decoder_segment_ids))
418520

419-
grads_base = jax.grad(loss_base)(var, (ids, decoder_positions, decoder_segment_ids))
420-
logits, _ = model.apply(
421-
var,
521+
quant_logits, _ = qt_model.apply(
522+
quantized_vars,
422523
ids,
423524
decoder_positions,
424525
decoder_segment_ids,
425526
enable_dropout=False,
426527
rngs={"params": self.rng},
427528
mutable=True,
428529
)
429-
self.__class__._cached_base_results = (grads_base, logits)
430-
431-
grads_base, logits = self.__class__._cached_base_results
432-
433-
quantized_vars = qt_model.init(
434-
{"params": self.rng, "aqt": self.rng, "dropout": self.rng},
435-
ids,
436-
decoder_positions,
437-
decoder_segment_ids,
438-
enable_dropout=False,
439-
mutable=True,
440-
)
441-
442-
def loss_quant(all_vars, inputs):
443-
logits, _ = qt_model.apply(
444-
all_vars,
445-
*inputs,
446-
enable_dropout=False,
447-
rngs={"params": self.rng},
448-
mutable=True,
530+
print("relative error in logits:" f" {jnp.abs(quant_logits - logits).mean() / jnp.abs(logits).mean()}")
531+
assert jnp.abs(quant_logits - logits).mean() / jnp.abs(logits).mean() < logits_tolerance
532+
self.print_grad_diff(grads_base_linen["params"], grads_quant_linen["params"])
533+
self.assertTrue(
534+
self.pytree_allclose(
535+
grads_base_linen["params"],
536+
grads_quant_linen["params"],
537+
tolerance=grad_tolerance,
538+
)
449539
)
450-
return jnp.mean((logits) ** 2)
451-
452-
# Compute gradients w.r.t. both models
453-
grads_quant = jax.grad(loss_quant)(quantized_vars, (ids, decoder_positions, decoder_segment_ids))
454-
455-
quant_logits, _ = qt_model.apply(
456-
quantized_vars,
457-
ids,
458-
decoder_positions,
459-
decoder_segment_ids,
460-
enable_dropout=False,
461-
rngs={"params": self.rng},
462-
mutable=True,
463-
)
464-
print("relative error in logits:" f" {jnp.abs(quant_logits - logits).mean() / jnp.abs(logits).mean()}")
465-
assert jnp.abs(quant_logits - logits).mean() / jnp.abs(logits).mean() < logits_tolerance
466-
self.print_grad_diff(grads_base["params"], grads_quant["params"])
467-
self.assertTrue(
468-
self.pytree_allclose(
469-
grads_base["params"],
470-
grads_quant["params"],
471-
tolerance=grad_tolerance,
472-
)
473-
)
474540

475541
@pytest.mark.tpu_only
476542
def test_int8_quantization(self):
@@ -489,6 +555,11 @@ def test_fp8_full_quantization(self):
489555
def test_fp8_gpu_quantization(self):
490556
self.quantization_config("fp8_gpu", grad_tolerance=1.5)
491557

558+
# NOTE(review): the gpu_only marker is disabled here, unlike the sibling
# test_fp8_nanoo_quantization test - confirm whether that is intentional.
# @pytest.mark.gpu_only
@pytest.mark.external_serving
def test_fp8_gpu_quantization_nnx(self):
  """Run the fp8_gpu forward/backward comparison with the pure-NNX config flags."""
  # Named distinctly from test_fp8_gpu_quantization: a second method with the
  # same name in one class body rebinds it, so the earlier (Linen) test would
  # silently never be collected.
  self.quantization_config("fp8_gpu", grad_tolerance=1.5, enable_nnx=True, pure_nnx_decoder=True, pure_nnx=True)
562+
492563
@pytest.mark.gpu_only
493564
@pytest.mark.external_serving
494565
def test_fp8_nanoo_quantization(self):

0 commit comments

Comments
 (0)