Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 1 addition & 16 deletions src/maxtext/layers/nnx_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,22 +671,7 @@ def layer_fn(carry, scanned_vars):
params = nnx_ensure_scan_leading_axis(params, length)
state = nnx_ensure_scan_leading_axis(state, length)

# Linen FP8 ops keep amax_history in mutable Linen scope; jax.lax.scan
# leaks the tracer and hits UnexpectedTracerError. Use a Python for-loop
# for FP8 instead.
uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu")
if uses_linen_fp8_mutable_state:
carry = x_in
per_layer_states = []
for i in range(length):
current_params = jax.tree.map(lambda x, i=i: x[i], params)
current_state = jax.tree.map(lambda x, i=i: x[i], state)
carry, new_state_i = layer_fn(carry, (current_params, current_state))
per_layer_states.append(new_state_i)
final_carry = carry
scanned_state = jax.tree.map(lambda *xs: jnp.stack(list(xs)), *per_layer_states)
else:
final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
returned_kv_stacked = None

if scan_axis != 0:
Expand Down
11 changes: 10 additions & 1 deletion src/maxtext/layers/nnx_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,16 @@ def maybe_unbox(x):
for path, _ in unknown_state_flat.items():
paths_str += f"\n - {'/'.join(map(str, path))}"

warnings.warn(f"Found unknown module paths in incoming state:{paths_str}")
# Dynamically reconstruct the unknown variables
curr = module
for p in path[:-1]:
if not hasattr(curr, p):
setattr(curr, p, nnx.Module())
curr = getattr(curr, p)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟠 The dynamic reconstruction of modules here is fragile and will likely fail in several scenarios:

1. **Integer paths**: If any part of `path[:-1]` is an integer (e.g., when traversing a list of layers or an `nnx.Sequential`), `setattr(curr, p, ...)` will raise a `TypeError` because attribute names must be strings.
2. **Container types**: It always uses `nnx.Module()`. If the path expects a list or another container type, the structure will be incorrect.
3. **Callable wrappers**: For Qwix quantization, these attributes are often intended to be `ToNNX` wrappers. If they are pre-created as plain `nnx.Module` instances, they will fail when called during the forward pass (specifically in `_apply_linen_module_in_nnx`, which expects the attribute to be callable).

Consider a more robust way to ensure the model structure matches the incoming state, or ensure that the model is fully initialized (including quantization wrappers) before `nnx.update` is called.

Suggested change
# Dynamically reconstruct the unknown variables
curr = module
for p in path[:-1]:
if not hasattr(curr, p):
setattr(curr, p, nnx.Module())
curr = getattr(curr, p)

warnings.warn(
f"Found unknown module paths in incoming state:{paths_str}. Intermediate modules have been reconstructed."
)

nnx.update(module, new_state)
_refresh_variable_trace_state(module)
Expand Down
48 changes: 44 additions & 4 deletions src/maxtext/layers/quantizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,13 @@
from flax.linen import fp8_ops
from flax.linen import initializers as flax_initializers
import flax.linen as nn
from flax import nnx

from qwix._src import flax_util

from maxtext.common.common_types import DType, Config
from maxtext.inference.kvcache import KVQuant
from maxtext.layers import nnx_wrappers

# Params used to define mixed precision quantization configs
DEFAULT = "__default__" # default config
Expand Down Expand Up @@ -707,6 +711,28 @@ def configure_kv_quant(config):
return None if not config.quantize_kvcache else KVQuant(config)


def _apply_linen_module_in_nnx(linen_module_cls, op_id, *args, **kwargs):
"""Applies a Linen module within an NNX context."""
try:
parent = flax_util.get_current_module()
is_nnx = isinstance(parent, nnx.Module)
except Exception: # pylint: disable=broad-exception-caught
is_nnx = False

if is_nnx:
attr_name = f"_qwix_fp8_gpu_{op_id}"
if not hasattr(parent, attr_name):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Hardcoded fallback RNG seed `nnx.Rngs(0)` (used whenever the parent has no `qwix_rngs` attribute). While likely acceptable for initialization of quantization stats, it's generally safer to fork an existing RNG or use a properly seeded one to avoid potential collisions if multiple modules are initialized this way.

rngs = getattr(parent, "qwix_rngs", None)
if rngs is None:
rngs = nnx.Rngs(0)
wrapper = nnx_wrappers.ToNNX(linen_module_cls(name=op_id), rngs=rngs)
wrapper.lazy_init(*args, **kwargs)
setattr(parent, attr_name, wrapper)
return getattr(parent, attr_name)(*args, mutable=["_overwrite_with_gradient"], **kwargs)
else:
return linen_module_cls(name=op_id)(*args, **kwargs)


class NvidaFp8Provider(qwix.QtProvider):
"""Wraps nn.Fp8DirectDotGeneralOp with Qwix's provider interface."""

Expand All @@ -715,13 +741,13 @@ def dot_general(self, *args, **kwargs):
rule, op_id = self._get_current_rule_and_op_id("dot_general")
if rule is None:
return jax.lax.dot_general(*args, **kwargs)
return nn.Fp8DirectDotGeneralOp(name=op_id)(*args, **kwargs)
return _apply_linen_module_in_nnx(nn.Fp8DirectDotGeneralOp, op_id, *args, **kwargs)

def einsum(self, *args, **kwargs):
rule, op_id = self._get_current_rule_and_op_id("einsum")
if rule is None:
return jnp.einsum(*args, **kwargs)
return nn.Fp8Einsum(name=op_id)(*args, **kwargs)
return _apply_linen_module_in_nnx(nn.Fp8Einsum, op_id, *args, **kwargs)


class NANOOFp8Provider(qwix.QtProvider):
Expand All @@ -731,7 +757,7 @@ def dot_general(self, *args, **kwargs):
rule, op_id = self._get_current_rule_and_op_id("dot_general")
if rule is None:
return jax.lax.dot_general(*args, **kwargs)
return nn.NANOOFp8DotGeneralOp(name=op_id)(*args, **kwargs)
return _apply_linen_module_in_nnx(nn.NANOOFp8DotGeneralOp, op_id, *args, **kwargs)


def get_fp8_full_qwix_rule_w_sparsity(config: Config):
Expand Down Expand Up @@ -812,7 +838,21 @@ def maybe_quantize_model(model, config):
if config.use_qwix_quantization and not config.use_batch_split_schedule:
quantization_provider = get_qt_provider(config)
if quantization_provider:
model = qwix.quantize_model(model, quantization_provider)
if config.pure_nnx:
input_shape = (config.micro_batch_size_to_train_on, config.max_target_length)
dummy_tokens = jnp.ones(input_shape, dtype=jnp.int32)
dummy_positions = jnp.ones(input_shape, dtype=jnp.int32)
dummy_segment_ids = jnp.ones(input_shape, dtype=jnp.int32)
model = qwix.quantize_model(
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Dummy inputs for `qwix.quantize_model` are hardcoded to `jnp.ones`. This might not be suitable for all models or all input types. Consider making this more flexible or deriving it from the model configuration if possible.

model,
quantization_provider,
dummy_tokens,
dummy_positions,
dummy_segment_ids,
enable_dropout=False,
)
else:
model = qwix.quantize_model(model, quantization_provider)
return model


Expand Down
20 changes: 13 additions & 7 deletions src/maxtext/trainers/pre_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,10 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
is_train=True,
)
else:
model_graphdef, curr_params, rest = nnx.split(state.model, nnx.Param, ...)
OverwriteWithGradient = nnx.variablelib.variable_type_from_name(
maxtext_utils.OVERWRITE_WITH_GRADIENT, allow_register=True
)
model_graphdef, curr_params, overwrite_vars, rest = nnx.split(state.model, nnx.Param, OverwriteWithGradient, ...)
if config.parameter_memory_host_offload:
# Params are kept on host (pinned_host) in in_shardings. Move only Param
# variables to device before the forward/backward pass so that all dot_general
Expand All @@ -371,15 +374,18 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
)
nnx.update(state.model, curr_params)

def diff_wrapper(param, rest, config, data):
local_model = nnx.merge(model_graphdef, param, rest, copy=True)
def diff_wrapper(param, overwrite_vars, rest, config, data):
local_model = nnx.merge(model_graphdef, param, overwrite_vars, rest, copy=True)
loss, aux = loss_fn(local_model, config, data, None, None, is_train=True)
_, _, new_rest = nnx.split(local_model, nnx.Param, ...)
return loss, (aux, new_rest)
_, _, new_overwrite_vars, new_rest = nnx.split(local_model, nnx.Param, OverwriteWithGradient, ...)
return loss, (aux, new_overwrite_vars, new_rest)

grad_func = jax.value_and_grad(diff_wrapper, argnums=0, has_aux=True)
(loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, config, data)
grad_func = jax.value_and_grad(diff_wrapper, argnums=(0, 1), has_aux=True)
(loss, (aux, _, new_rest)), (raw_grads, overwrite_grads) = grad_func(
curr_params, overwrite_vars, rest, config, data
)
nnx.update(state.model, new_rest)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 The logic for handling `overwrite_grads` and `new_rest` seems partially redundant:

- You are extracting `new_overwrite_vars` in the `diff_wrapper` aux but then ignoring it in the caller: `(loss, (aux, _, new_rest))`.
- Instead, you are using `overwrite_grads` to update the model.
- If `overwrite_vars` (variables of type `OverwriteWithGradient`) use the "custom gradient trick" to return their updated values as gradients, then `overwrite_grads` and `new_overwrite_vars` should be identical.
- However, if any variable of this type does *not* use this trick, its gradient will be 0, and `nnx.update(state.model, overwrite_grads)` will incorrectly zero it out, whereas `new_overwrite_vars` would have preserved the updated (or original) value.

Is there a reason to prefer `overwrite_grads` over `new_overwrite_vars`? If the goal is to support distributed all-reduce of these stats, that makes sense, but it should be documented.

Suggested change
nnx.update(state.model, new_rest)
grad_func = jax.value_and_grad(diff_wrapper, argnums=(0, 1), has_aux=True)
(loss, (aux, _, new_rest)), (raw_grads, overwrite_grads) = grad_func(
curr_params, overwrite_vars, rest, config, data
)
nnx.update(state.model, new_rest)
nnx.update(state.model, overwrite_grads)

nnx.update(state.model, overwrite_grads)

raw_grads = jax.tree_util.tree_map(
lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x,
Expand Down
Loading
Loading