Skip to content

Commit 1432864

Browse files
Merge pull request #4198 from AI-Hypercomputer:fix/fp8_qwix
PiperOrigin-RevId: 934430447
2 parents eacf996 + 9b4d96c commit 1432864

5 files changed

Lines changed: 243 additions & 104 deletions

File tree

src/maxtext/layers/nnx_decoders.py

Lines changed: 4 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import jax.numpy as jnp
2626
from flax import linen as nn
2727
from flax import nnx
28-
from flax.nnx import wrappers as nnx_wrappers
28+
from maxtext.layers import nnx_wrappers
2929
from jax.ad_checkpoint import checkpoint_name
3030
from jax.sharding import Mesh
3131

@@ -939,17 +939,8 @@ def pure_layer_fn(state_in, y_in):
939939
out = merged_layer(y_in, **kwargs)
940940
return out, nnx.state(merged_layer)
941941

942-
# Linen FP8 ops keep amax_history in mutable Linen scope; jax.checkpoint
943-
# re-traces and hits UnexpectedTracerError. Skip remat for FP8.
944-
uses_linen_fp8_mutable_state = self.config.quantization in {
945-
"fp8_nanoo",
946-
"fp8_gpu",
947-
}
948-
if uses_linen_fp8_mutable_state:
949-
out, new_state = pure_layer_fn(state, y)
950-
else:
951-
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
952-
out, new_state = checkpointed_fn(state, y)
942+
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
943+
out, new_state = checkpointed_fn(state, y)
953944
nnx.update(layer, new_state)
954945

955946
return out
@@ -1077,26 +1068,7 @@ def layer_fn(carry, scanned_vars):
10771068
params = nnx_ensure_scan_leading_axis(params, length)
10781069
state = nnx_ensure_scan_leading_axis(state, length)
10791070

1080-
# Linen FP8 ops keep amax_history in mutable Linen scope; jax.lax.scan
1081-
# leaks the tracer and hits UnexpectedTracerError. Use a Python for-loop
1082-
# for FP8 instead.
1083-
uses_linen_fp8_mutable_state = self.config.quantization in {
1084-
"fp8_nanoo",
1085-
"fp8_gpu",
1086-
}
1087-
if uses_linen_fp8_mutable_state:
1088-
carry = x_in
1089-
per_layer_states = []
1090-
for i in range(length):
1091-
current_params = jax.tree.map(lambda x, i=i: x[i], params)
1092-
current_state = jax.tree.map(lambda x, i=i: x[i], state)
1093-
carry, new_state_i = layer_fn(carry, (current_params, current_state))
1094-
per_layer_states.append(new_state_i)
1095-
final_carry = carry
1096-
# pylint: disable-next=no-value-for-parameter (*per_layer_states supplies the `tree` arg)
1097-
scanned_state = jax.tree.map(lambda *xs: jnp.stack(list(xs)), *per_layer_states)
1098-
else:
1099-
final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
1071+
final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
11001072
returned_kv_stacked = None
11011073

11021074
if scan_axis != 0:

src/maxtext/layers/nnx_wrappers.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,14 +285,18 @@ def __call__(
285285
# Get `mutable` from top level bridge.Module context if any
286286
if mutable is not None:
287287
pass
288-
elif (m := bdg_module.current_module()) is not None:
288+
elif getattr(bdg_module.MODULE_CONTEXT, "module_stack", None) and (m := bdg_module.current_module()) is not None:
289289
assert m.scope is not None
290290
mutable = m.scope.mutable
291291
elif (m := current_linen_module()) is not None:
292292
assert m.scope is not None
293293
mutable = m.scope.mutable
294294
else:
295-
mutable = False
295+
# Safe fallback mutability: when running functionally isolated inside standard JAX transforms,
296+
# we determine which collections (such as "stats" or "amax_history") are present and mark them mutable.
297+
mutable = [k for k in variables.keys() if k != "params"]
298+
if not mutable:
299+
mutable = False
296300

297301
out = self.to_nnx__module.apply(variables, *args, rngs=_rngs, method=method, mutable=mutable, **kwargs)
298302

@@ -509,7 +513,31 @@ def maybe_unbox(x):
509513
for path, _ in unknown_state_flat.items():
510514
paths_str += f"\n - {'/'.join(map(str, path))}"
511515

512-
warnings.warn(f"Found unknown module paths in incoming state:{paths_str}")
516+
# Dynamically reconstruct the unknown variables
517+
curr = module
518+
for p in path[:-1]:
519+
if isinstance(curr, dict):
520+
if p not in curr:
521+
curr[p] = nnx.Module()
522+
curr = curr[p]
523+
elif isinstance(curr, list):
524+
if not isinstance(p, int):
525+
raise TypeError(f"Expected int index for list, got {type(p)}: {p}")
526+
while len(curr) <= p:
527+
curr.append(nnx.Module())
528+
curr = curr[p]
529+
elif isinstance(curr, tuple):
530+
raise ValueError(f"Cannot dynamically reconstruct elements within a tuple at path {path}.")
531+
else:
532+
if not isinstance(p, str):
533+
p = str(p)
534+
if not hasattr(curr, p):
535+
setattr(curr, p, nnx.Module())
536+
curr = getattr(curr, p)
537+
538+
warnings.warn(
539+
f"Found unknown module paths in incoming state:{paths_str}. Intermediate modules have been reconstructed."
540+
)
513541

514542
_fix_for_qwix_quantization(module)
515543
nnx.update(module, new_state)

src/maxtext/layers/quantizations.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@
3838
from flax.linen import fp8_ops
3939
from flax.linen import initializers as flax_initializers
4040
import flax.linen as nn
41+
from flax import nnx
42+
# Support different packaging structures across environments even within
43+
# the same Qwix version identifier (imports from _src.utils vs _src).
44+
try:
45+
from qwix._src.utils import flax_util
46+
except ImportError:
47+
from qwix._src import flax_util # pytype: disable=import-error
48+
from maxtext.layers import nnx_wrappers
4149

4250
from maxtext.common.common_types import DType, Config
4351
from maxtext.inference.kvcache import KVQuant
@@ -710,6 +718,32 @@ def configure_kv_quant(config):
710718
return None if not config.quantize_kvcache else KVQuant(config)
711719

712720

721+
def _apply_linen_module_in_nnx(linen_module_cls, op_id, *args, **kwargs):
  """Invokes a Linen op module, bridging it through the current NNX module if there is one.

  The current module is looked up via `flax_util.get_current_module()`. If that
  lookup raises `ValueError`, or the result is not an `nnx.Module`, the Linen
  module is simply constructed with `name=op_id` and called directly.

  Otherwise a `nnx_wrappers.ToNNX` wrapper around the Linen module is created
  once, lazily initialised with the call arguments, and stored on the current
  module under the attribute `_qwix_fp8_gpu_<op_id>`; later calls reuse the
  stored wrapper. The wrapper is always called with
  `mutable=["_overwrite_with_gradient"]`.

  Args:
    linen_module_cls: Linen module class; instantiated as `linen_module_cls(name=op_id)`.
    op_id: Identifier used both as the Linen module name and as the suffix of
      the attribute that stores the wrapper on the current module.
    *args: Positional arguments forwarded to the module call (and to `lazy_init`).
    **kwargs: Keyword arguments forwarded to the module call (and to `lazy_init`).

  Returns:
    Whatever the invoked module returns.
  """
  try:
    owner = flax_util.get_current_module()
  except ValueError:
    # No current module could be resolved; treat it as the non-NNX case.
    owner = None

  # Plain Linen path: nothing to bridge.
  if not isinstance(owner, nnx.Module):
    return linen_module_cls(name=op_id)(*args, **kwargs)

  slot = f"_qwix_fp8_gpu_{op_id}"
  if not hasattr(owner, slot):
    # RNG preference order: `qwix_rngs` on the owner, then a fork of the
    # owner's `rngs` (when it exposes `fork`), then a fixed-seed fallback.
    bridge_rngs = getattr(owner, "qwix_rngs", None)
    if bridge_rngs is None:
      owner_rngs = getattr(owner, "rngs", None)
      can_fork = owner_rngs is not None and hasattr(owner_rngs, "fork")
      bridge_rngs = owner_rngs.fork() if can_fork else nnx.Rngs(0)
    bridged = nnx_wrappers.ToNNX(linen_module_cls(name=op_id), rngs=bridge_rngs)
    bridged.lazy_init(*args, **kwargs)
    setattr(owner, slot, bridged)

  # Re-read the attribute so the call goes through whatever is stored on the owner.
  cached_op = getattr(owner, slot)
  return cached_op(*args, mutable=["_overwrite_with_gradient"], **kwargs)
745+
746+
713747
class NvidaFp8Provider(qwix.QtProvider):
714748
"""Wraps nn.Fp8DirectDotGeneralOp with Qwix's provider interface."""
715749

@@ -718,13 +752,13 @@ def dot_general(self, *args, **kwargs):
718752
rule, op_id = self._get_current_rule_and_op_id("dot_general")
719753
if rule is None:
720754
return jax.lax.dot_general(*args, **kwargs)
721-
return nn.Fp8DirectDotGeneralOp(name=op_id)(*args, **kwargs)
755+
return _apply_linen_module_in_nnx(nn.Fp8DirectDotGeneralOp, op_id, *args, **kwargs)
722756

723757
def einsum(self, *args, **kwargs):
724758
rule, op_id = self._get_current_rule_and_op_id("einsum")
725759
if rule is None:
726760
return jnp.einsum(*args, **kwargs)
727-
return nn.Fp8Einsum(name=op_id)(*args, **kwargs)
761+
return _apply_linen_module_in_nnx(nn.Fp8Einsum, op_id, *args, **kwargs)
728762

729763

730764
class NANOOFp8Provider(qwix.QtProvider):
@@ -734,7 +768,7 @@ def dot_general(self, *args, **kwargs):
734768
rule, op_id = self._get_current_rule_and_op_id("dot_general")
735769
if rule is None:
736770
return jax.lax.dot_general(*args, **kwargs)
737-
return nn.NANOOFp8DotGeneralOp(name=op_id)(*args, **kwargs)
771+
return _apply_linen_module_in_nnx(nn.NANOOFp8DotGeneralOp, op_id, *args, **kwargs)
738772

739773

740774
def get_fp8_full_qwix_rule_w_sparsity(config: Config):
@@ -815,7 +849,21 @@ def maybe_quantize_model(model, config):
815849
if config.use_qwix_quantization and not config.use_batch_split_schedule:
816850
quantization_provider = get_qt_provider(config)
817851
if quantization_provider:
818-
model = qwix.quantize_model(model, quantization_provider)
852+
if config.pure_nnx:
853+
input_shape = (config.micro_batch_size_to_train_on, config.max_target_length)
854+
dummy_tokens = jnp.ones(input_shape, dtype=jnp.int32)
855+
dummy_positions = jnp.ones(input_shape, dtype=jnp.int32)
856+
dummy_segment_ids = jnp.ones(input_shape, dtype=jnp.int32)
857+
model = qwix.quantize_model(
858+
model,
859+
quantization_provider,
860+
dummy_tokens,
861+
dummy_positions,
862+
dummy_segment_ids,
863+
enable_dropout=False,
864+
)
865+
else:
866+
model = qwix.quantize_model(model, quantization_provider)
819867
return model
820868

821869

src/maxtext/trainers/pre_train/train.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838

3939
from flax import linen as nn, nnx
4040
from flax.linen import partitioning as nn_partitioning
41+
from flax.nnx import variablelib
4142

4243
from maxtext.configs import pyconfig
4344
from maxtext.utils.globals import EPS
@@ -359,7 +360,9 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
359360
is_train=True,
360361
)
361362
else:
362-
model_graphdef, curr_params, rest = nnx.split(state.model, nnx.Param, ...)
363+
owg_type = variablelib.variable_type_from_name("_overwrite_with_gradient", allow_register=True)
364+
custom_param_filter = nnx.Any(owg_type)
365+
model_graphdef, curr_params, custom_params, rest = nnx.split(state.model, nnx.Param, custom_param_filter, ...)
363366
if config.parameter_memory_host_offload:
364367
# Params are kept on host (pinned_host) in in_shardings. Move only Param
365368
# variables to device before the forward/backward pass so that all dot_general
@@ -381,15 +384,15 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
381384
)
382385
nnx.update(state.model, curr_params)
383386

384-
def diff_wrapper(param, rest, config, data):
385-
local_model = nnx.merge(model_graphdef, param, rest, copy=True)
387+
def diff_wrapper(curr_params, custom_params, rest, config, data):
388+
local_model = nnx.merge(model_graphdef, curr_params, custom_params, rest, copy=True)
386389
loss, aux = loss_fn(local_model, config, data, None, None, is_train=True)
387-
_, _, new_rest = nnx.split(local_model, nnx.Param, ...)
390+
_, _, _, new_rest = nnx.split(local_model, nnx.Param, custom_param_filter, ...)
388391
return loss, (aux, new_rest)
389392

390-
grad_func = jax.value_and_grad(diff_wrapper, argnums=0, has_aux=True)
391-
(loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, config, data)
392-
nnx.update(state.model, new_rest)
393+
grad_func = jax.value_and_grad(diff_wrapper, argnums=(0, 1), has_aux=True)
394+
(loss, (aux, new_rest)), (raw_grads, custom_grads) = grad_func(curr_params, custom_params, rest, config, data)
395+
nnx.update(state.model, nnx.State.merge(custom_grads, new_rest))
393396

394397
raw_grads = jax.tree_util.tree_map(
395398
lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x,

0 commit comments

Comments
 (0)