Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/maxtext/inference/maxengine/maxengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,12 +446,16 @@ def _load_params_nnx(self, params, rng):
rest_dict = rest_state.to_pure_dict()

def _overlay(dst, src):
if isinstance(dst, dict):
if isinstance(dst, dict) and isinstance(src, dict):
for k, v in dst.items():
if k in src:
dst[k] = _overlay(v, src[k])
return dst
return src if not isinstance(src, dict) else dst
# On structural mismatch keep dst (PREFILL); swapping a leaf for a subtree
# (or the other way) would corrupt the model. Both-leaves is the overlay case.
if isinstance(dst, dict) or isinstance(src, dict):
return dst
return src

rest_dict = _overlay(rest_dict, loaded_rest_dict)
nnx.replace_by_pure_dict(rest_state, rest_dict)
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/layers/nnx_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,7 @@ def apply_output_head(self, shared_embedding, y, deterministic, model_mode):
if cfg.logits_via_embedding:
# Use the transpose of embedding matrix for logit transform.
if isinstance(shared_embedding, nnx.Module):
embedding_table = shared_embedding.embedding.value
embedding_table = shared_embedding.embedding[...]
else:
embedding_table = shared_embedding.variables["params"]["embedding"]
if isinstance(embedding_table, nn.spmd.LogicallyPartitioned):
Expand Down
5 changes: 0 additions & 5 deletions src/maxtext/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,6 @@ def __init__(
else:
decoder_linen = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)
self.decoder = nnx_wrappers.ToNNX(decoder_linen, rngs=rngs)
self.hidden_states = None

batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=model_mode)
dummy_decoder_input_tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
Expand Down Expand Up @@ -567,10 +566,6 @@ def __call__(
mutable=mutable_collections,
) # pytype: disable=wrong-keyword-args

# Materialize hidden state when vocab tiling is enabled
if self.config.num_vocab_tiling > 1:
self.hidden_states = hidden_state

# If we are initializing the model AND MTP is enabled, we must create
# dummy target tensors. This allows Flax to trace the MTPBlock and create
# all its necessary parameters, without requiring the main training pipeline
Expand Down
18 changes: 8 additions & 10 deletions src/maxtext/utils/model_creation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,16 +1039,14 @@ def _build_value_target(v):
def _free_device_memory(node):
if isinstance(node, nnx.Variable) and not isinstance(node, (nnx.RngState, nnx.Cache)):
inner = node.get_value() if hasattr(node, "get_value") else node[...]
# Same QTensor caveat as `_build_value_target`: AQT serve-mode `qrhs.frozen`
# wraps a QTensor whose `__getitem__` fails on `LogicallyPartitioned`.
# We only need to free a single jax.Array leaf — for composite values
# there's nothing to free at this level, so skip.
val = inner if hasattr(inner, "shape") else None
else:
val = node

if isinstance(val, jax.Array) and not val.is_deleted():
val.delete()
# AQT serve-mode `qrhs.frozen` wraps a QTensor (composite pytree) rather
# than a single jax.Array. Walking via tree_leaves frees the qvalue/scale
# arrays too; the single-leaf case is a 1-element tree.
for leaf in jax.tree_util.tree_leaves(inner):
if isinstance(leaf, jax.Array) and not leaf.is_deleted():
leaf.delete()
elif isinstance(node, jax.Array) and not node.is_deleted():
node.delete()

return node

Expand Down
194 changes: 156 additions & 38 deletions src/maxtext/utils/vocabulary_tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,29 @@
from maxtext.utils import max_utils


# Submodule names whose params are used by logits_from_hidden_states_for_vocab_tiling:
# the final norm, the LM-head dense, and the embedding table when logits are tied.
# vocab_tiling_nnx_loss splits these out as the only params the loss differentiates.
_OUTPUT_HEAD_PATH_KEYS = ("token_embedder", "shared_embedding", "decoder_norm", "logits_dense")


def _is_output_head_param_path(path, _value):
"""Filter for nnx.split: True when the param path belongs to the output head."""

# JAX path entries differ by key type: DictKey uses .key, GetAttrKey uses .name
# in newer Flax and .attr in older. Check all three so the filter survives
# version upgrades.
def _name(k):
for attr in ("key", "attr", "name"):
v = getattr(k, attr, None)
if v is not None:
return str(v)
return str(k)

keys = [_name(k) for k in path]
return any(k in keys for k in _OUTPUT_HEAD_PATH_KEYS)
Comment thread
ecnal-cienet marked this conversation as resolved.


def vocab_tiling_linen_loss(
hidden_states,
data,
Expand Down Expand Up @@ -253,12 +276,12 @@ def _bwd_scan_body(grad_params_acc, chunk_data):
def vocab_tiling_nnx_loss(model, hidden_states, data, config, is_train):
"""Computes cross-entropy loss with vocab tiling for NNX models.

NNX equivalent of ``vocab_tiling_linen_loss``. Scans the vocab dimension
and calls ``model.logits_from_hidden_states_for_vocab_tiling`` per chunk. The NNX model
carries its own parameters, so no explicit gather is needed.

Uses default autograd; a custom_vjp for backward memory savings can be
added later if needed.
NNX equivalent of `vocab_tiling_linen_loss`. A `custom_vjp` runs the loss in
vocab chunks via `jax.lax.scan` so the backward only holds one chunk's logits
at a time, matching the Linen path's memory profile. `nnx.split` separates the
output-head params (which the loss differentiates) from everything else; the
rest of the model is passed through but not differentiated, so the scan's
residuals stay small.

Args:
model: NNX model exposing ``logits_from_hidden_states_for_vocab_tiling``.
Expand Down Expand Up @@ -320,42 +343,137 @@ def _reshape(inputs, out_shape, out_sharding):
labels = _maybe_shard_with_name(labels, label_spec)
segmentation = _maybe_shard_with_name(segmentation, label_spec)

batch_size, seq_len, emb_dim = hidden_states.shape
vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling
# head_params is what the loss differentiates; other_params (transformer layers) and
# rest (rngs) are passed through the custom_vjp but not differentiated. They go through
# as primals rather than closure captures: capturing them leaks tracers across the
# custom_vjp + lax.scan boundary, which fails for tied embeddings.
graphdef, head_params, other_params, rest = nnx.split(model, _is_output_head_param_path, nnx.Param, ...)

reshaped_hidden_states = _reshape(
hidden_states, (config.num_vocab_tiling, vocab_tile_size, emb_dim), reshaped_hidden_spec
)
reshaped_labels = _reshape(labels, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec)
reshaped_segmentation = _reshape(segmentation, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec)

# Rebuild the model per chunk inside the scan: the output head pulls an rng stream, and
# mutating the outer model's rng inside scan's sub-trace raises TraceContextError.
# nnx.merge(..., copy=True) makes fresh Variables local to each iteration.
graphdef, model_state = nnx.split(model)

def _scan_body(accumulators, chunk_data):
loss_accumulator, z_loss_accumulator = accumulators
hidden_chunk, label_chunk, segmentation_chunk = chunk_data
hidden_chunk = _maybe_shard_with_name(hidden_chunk, chunked_hidden_spec)
label_chunk = _maybe_shard_with_name(label_chunk, chunked_data_spec)
segmentation_chunk = _maybe_shard_with_name(segmentation_chunk, chunked_data_spec)

chunk_model = nnx.merge(graphdef, model_state, copy=True)
chunk_logits = chunk_model.logits_from_hidden_states_for_vocab_tiling(hidden_chunk, deterministic, model_mode)
chunk_logits = _maybe_shard_with_name(chunk_logits, chunked_logits_spec)
one_hot_label_chunk = jax.nn.one_hot(label_chunk, config.vocab_size)
chunk_xent, chunk_z_loss = max_utils.cross_entropy_with_logits(
chunk_logits, one_hot_label_chunk, z_loss=config.z_loss_multiplier
def _logits_for_chunk(chunk_head_params, chunk_other_params, chunk_rest, hidden_chunk):
local_model = nnx.merge(graphdef, chunk_head_params, chunk_other_params, chunk_rest, copy=True)
chunk_logits = local_model.logits_from_hidden_states_for_vocab_tiling(hidden_chunk, deterministic, model_mode)
return _maybe_shard_with_name(chunk_logits, chunked_logits_spec)

@jax.custom_vjp
def chunked_cross_entropy_loss(chunk_head_params, chunk_other_params, chunk_rest, hidden_states, labels, segmentation):
(total_loss, total_z_loss), _ = _chunked_cross_entropy_loss_fwd(
chunk_head_params, chunk_other_params, chunk_rest, hidden_states, labels, segmentation
)
return total_loss, total_z_loss

def _chunked_cross_entropy_loss_fwd(
chunk_head_params, chunk_other_params, chunk_rest, hidden_states, labels, segmentation
):
batch_size, seq_len, emb_dim = hidden_states.shape
vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling

reshaped_hidden_states = _reshape(
hidden_states, (config.num_vocab_tiling, vocab_tile_size, emb_dim), reshaped_hidden_spec
)
reshaped_labels = _reshape(labels, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec)
reshaped_segmentation = _reshape(segmentation, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec)

def _fwd_scan_body(accumulators, chunk_data):
loss_accumulator, z_loss_accumulator = accumulators
hidden_chunk, label_chunk, segmentation_chunk = chunk_data
hidden_chunk = _maybe_shard_with_name(hidden_chunk, chunked_hidden_spec)
label_chunk = _maybe_shard_with_name(label_chunk, chunked_data_spec)
segmentation_chunk = _maybe_shard_with_name(segmentation_chunk, chunked_data_spec)

masked_xent = jnp.sum(chunk_xent * (segmentation_chunk != 0))
masked_z_loss = jnp.sum(chunk_z_loss * (segmentation_chunk != 0))
chunk_logits = _logits_for_chunk(chunk_head_params, chunk_other_params, chunk_rest, hidden_chunk)
one_hot_label_chunk = jax.nn.one_hot(label_chunk, config.vocab_size)
chunk_xent, chunk_z_loss = max_utils.cross_entropy_with_logits(
chunk_logits, one_hot_label_chunk, z_loss=config.z_loss_multiplier
)

return (loss_accumulator + masked_xent, z_loss_accumulator + masked_z_loss), None
masked_xent = jnp.sum(chunk_xent * (segmentation_chunk != 0))
masked_z_loss = jnp.sum(chunk_z_loss * (segmentation_chunk != 0))

return (loss_accumulator + masked_xent, z_loss_accumulator + masked_z_loss), None

initial_acc = (jnp.zeros((), dtype=hidden_states.dtype), jnp.zeros((), dtype=hidden_states.dtype))
(total_loss, total_z_loss), _ = jax.lax.scan(
_scan_body, initial_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation)
# Always accumulate in fp32 — `cross_entropy_with_logits` returns fp32 regardless of
# logits dtype, and a bf16 carry would mismatch the body output type under lax.scan.
initial_acc = (jnp.zeros((), dtype=jnp.float32), jnp.zeros((), dtype=jnp.float32))
(total_loss, total_z_loss), _ = jax.lax.scan(
_fwd_scan_body, initial_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation)
)
residuals = (
chunk_head_params,
chunk_other_params,
chunk_rest,
reshaped_hidden_states,
reshaped_labels,
reshaped_segmentation,
batch_size,
seq_len,
emb_dim,
)
return (total_loss, total_z_loss), residuals

def _chunked_cross_entropy_loss_bwd(residuals, cotangents):
# z_loss is folded into the xent loss inside cross_entropy_with_logits.
loss_cotangent, _ = cotangents

(
chunk_head_params,
chunk_other_params,
chunk_rest,
reshaped_hidden_states,
reshaped_labels,
reshaped_segmentation,
batch_size,
seq_len,
emb_dim,
) = residuals

def _single_chunk_loss_fn(input_head_params, input_hidden_chunk, input_label_chunk, input_segmentation_chunk):
chunk_logits = _logits_for_chunk(input_head_params, chunk_other_params, chunk_rest, input_hidden_chunk)
one_hot_label_chunk = jax.nn.one_hot(input_label_chunk, config.vocab_size)
xent, _ = max_utils.cross_entropy_with_logits(chunk_logits, one_hot_label_chunk, z_loss=config.z_loss_multiplier)
return jnp.sum(xent * (input_segmentation_chunk != 0))

def _bwd_scan_body(grad_head_acc, chunk_data):
hidden_chunk, label_chunk, segmentation_chunk = chunk_data
hidden_chunk = _maybe_shard_with_name(hidden_chunk, chunked_hidden_spec)
label_chunk = _maybe_shard_with_name(label_chunk, chunked_data_spec)
segmentation_chunk = _maybe_shard_with_name(segmentation_chunk, chunked_data_spec)

# pylint: disable=unnecessary-lambda-assignment
loss_fn_for_vjp = lambda p, h: _single_chunk_loss_fn(p, h, label_chunk, segmentation_chunk)
_, vjp_fn = jax.vjp(loss_fn_for_vjp, chunk_head_params, hidden_chunk)
(grad_head_update, grad_hidden_chunk) = vjp_fn(1.0)
grad_hidden_chunk = _maybe_shard_with_name(grad_hidden_chunk, chunked_hidden_spec)

grad_head_acc = jax.tree_util.tree_map(lambda acc, update: acc + update, grad_head_acc, grad_head_update)
return grad_head_acc, grad_hidden_chunk

initial_grad_head = jax.tree_util.tree_map(jnp.zeros_like, chunk_head_params)

grad_head, grad_reshaped_hidden_states = jax.lax.scan(
_bwd_scan_body, initial_grad_head, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation)
)
grad_reshaped_hidden_states = _maybe_shard_with_name(grad_reshaped_hidden_states, reshaped_hidden_spec)
grad_head = jax.tree_util.tree_map(lambda g: g * loss_cotangent, grad_head)
grad_head = jax.tree_util.tree_map(lambda x, y: y.astype(x.dtype), chunk_head_params, grad_head)
grad_reshaped_hidden_states = _reshape(grad_reshaped_hidden_states, (batch_size, seq_len, emb_dim), hidden_spec)

# Return explicit zeros for other_params and rest, not None. With None, JAX builds
# the zero cotangents with the wrong layer-axis order for scanned params, and the
# AOT trace fails the cotangent shape check.
grad_other = jax.tree_util.tree_map(jnp.zeros_like, chunk_other_params)
grad_rest = jax.tree_util.tree_map(jnp.zeros_like, chunk_rest)
return (
grad_head,
grad_other,
grad_rest,
grad_reshaped_hidden_states.astype(reshaped_hidden_states.dtype),
None,
None,
)

chunked_cross_entropy_loss.defvjp(_chunked_cross_entropy_loss_fwd, _chunked_cross_entropy_loss_bwd)

total_loss, total_z_loss = chunked_cross_entropy_loss(
head_params, other_params, rest, hidden_states, labels, segmentation
)
return total_loss, total_z_loss
Loading
Loading