Commit a049f9a

Merge pull request #3470 from AI-Hypercomputer:feat/migrate-nnx-utils
PiperOrigin-RevId: 901445861
2 parents 7f78228 + 6a0f895 commit a049f9a

7 files changed: 338 additions & 90 deletions


src/maxtext/trainers/pre_train/train.py

Lines changed: 32 additions & 0 deletions
```diff
@@ -355,6 +355,32 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
     grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, config.gradient_clipping_threshold)
   else:
     grads = raw_grads
+
+  # fp8 fix: sanitize NaN OWG (overwrite-with-gradient) stats before apply_gradients.
+  # Under FSDP, the fp8 output gradient amax can be NaN at step 0, which propagates into
+  # amax_history and corrupts future steps. Replace NaN OWG entries with the current state
+  # values (skip the amax update for that step) instead of letting NaN flow through.
+  # Also restore OWG values after apply_gradients to bypass optimizer corruption
+  # (Adam should not update fp8 scale/amax_history).
+  fp8_stats = dict(grads).get(maxtext_utils.OVERWRITE_WITH_GRADIENT, None)
+  if fp8_stats is not None:
+    if maxtext_utils.OVERWRITE_WITH_GRADIENT in state.params:
+      current_fp8 = state.params[maxtext_utils.OVERWRITE_WITH_GRADIENT]
+      fp8_stats = jax.tree_util.tree_map(
+          lambda new, cur: jnp.where(jnp.isnan(new), cur, new),
+          fp8_stats,
+          current_fp8,
+      )
+    else:
+      fp8_stats = jax.tree_util.tree_map(lambda x: jnp.nan_to_num(x, nan=0.0), fp8_stats)
+    grads = dict(grads)
+    grads[maxtext_utils.OVERWRITE_WITH_GRADIENT] = fp8_stats
+    # Zero out any remaining NaN in float gradients to prevent param corruption
+    grads = jax.tree_util.tree_map(
+        lambda x: jnp.nan_to_num(x, nan=0.0) if jnp.issubdtype(x.dtype, jnp.floating) else x,
+        grads,
+    )
+
   if config.optimizer_memory_host_offload:
     state = state.replace(
         opt_state=jax.device_put(
@@ -394,6 +420,12 @@ def move(path, value):
   else:
     new_state = state.apply_gradients(grads=grads)
 
+  # fp8 fix: restore sanitized OWG values, bypassing any optimizer update to fp8 stats.
+  if fp8_stats is not None:
+    new_params = dict(new_state.params)
+    new_params[maxtext_utils.OVERWRITE_WITH_GRADIENT] = fp8_stats
+    new_state = new_state.replace(params=new_params)
+
   # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family
   if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None:
     target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias")
```
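
Reviewer note: the sanitization pattern in the hunk above can be exercised in isolation. Below is a minimal standalone sketch (toy dict; the `amax_history`/`scale` keys are illustrative, not the real OWG tree layout) of replacing NaN entries in an incoming stats tree with the current state's values:

```python
# Minimal sketch of the NaN-masking pattern used above: wherever the new
# stats tree has NaN, fall back to the corresponding current-state value.
import jax
import jax.numpy as jnp

new_stats = {"amax_history": jnp.array([jnp.nan, 2.0]), "scale": jnp.array(1.5)}
cur_stats = {"amax_history": jnp.array([0.5, 1.0]), "scale": jnp.array(1.0)}

sanitized = jax.tree_util.tree_map(
    lambda new, cur: jnp.where(jnp.isnan(new), cur, new),
    new_stats,
    cur_stats,
)
print(sanitized["amax_history"])  # [0.5 2.0]: the NaN entry kept the current
                                  # value, skipping that step's amax update
```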

src/maxtext/utils/maxtext_utils.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import functools
1919
import pickle
2020
import os
21+
from typing import Sequence
2122

2223
from flax import linen as nn
2324
from flax.linen import partitioning as nn_partitioning
@@ -27,6 +28,7 @@
2728

2829
from jax.experimental import mesh_utils
2930
from jax.experimental.serialize_executable import deserialize_and_load
31+
from jax.sharding import AxisType, Mesh
3032

3133
import jax
3234
import jax.numpy as jnp
@@ -36,7 +38,8 @@
3638
import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
3739
import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager
3840

39-
from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
41+
from maxtext.configs import pyconfig
42+
from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE, ShardMode
4043
from maxtext.configs import types
4144
from maxtext.inference.page_manager import PageState
4245
from maxtext.common import checkpointing
@@ -1681,3 +1684,27 @@ def maybe_dump_jaxpr(config, p_train_step, train_step_inputs):
16811684
delete_local_after=config.dump_jaxpr_delete_local_after, # Keeping local for debugging
16821685
all_host_upload=False, # Only upload from lead host (Host 0)
16831686
)
1687+
1688+
1689+
def get_mesh_from_config(
1690+
config: pyconfig.HyperParameters,
1691+
devices: Sequence[jax.Device] | None = None,
1692+
) -> Mesh:
1693+
"""
1694+
Geh mesh from the configuration.
1695+
1696+
Args:
1697+
config: the configuration
1698+
devices: the devices
1699+
1700+
Returns:
1701+
the device mesh
1702+
"""
1703+
devices_array = create_device_mesh(config, devices)
1704+
1705+
if config.shard_mode == ShardMode.EXPLICIT:
1706+
axis_types = tuple([AxisType.Explicit] * len(config.mesh_axes))
1707+
else:
1708+
axis_types = tuple([AxisType.Auto] * len(config.mesh_axes))
1709+
1710+
return Mesh(devices_array, config.mesh_axes, axis_types=axis_types)
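
Reviewer note: a standalone sketch of the logic `get_mesh_from_config` centralizes; the axis names, mesh shape, and `explicit` flag below are illustrative stand-ins for `config.mesh_axes` and `config.shard_mode`, which the real helper reads from the MaxText config:

```python
# Sketch: build a Mesh whose axes are all Explicit or all Auto, mirroring
# the ShardMode branch in get_mesh_from_config.
import jax
from jax.experimental import mesh_utils
from jax.sharding import AxisType, Mesh

def build_mesh(mesh_axes, mesh_shape, explicit=False):
  devices_array = mesh_utils.create_device_mesh(mesh_shape)
  # Every axis gets the same mode: Explicit opts into explicit sharding;
  # Auto leaves sharding decisions to the compiler.
  axis_type = AxisType.Explicit if explicit else AxisType.Auto
  return Mesh(devices_array, mesh_axes, axis_types=tuple([axis_type] * len(mesh_axes)))

mesh = build_mesh(("data", "model"), (jax.device_count(), 1))
print(mesh)
```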

src/maxtext/utils/maxtext_utils_nnx.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,18 @@
2222

2323
from maxtext.utils import max_logging
2424
from maxtext.configs import pyconfig
25+
from maxtext.common.common_types import MODEL_MODE_TRAIN
2526

2627

2728
def create_nnx_rngs(
28-
config: pyconfig.HyperParameters, is_training: bool = True, rng_key: jax.Array | None = None
29+
config: pyconfig.HyperParameters, model_mode: str = MODEL_MODE_TRAIN, rng_key: jax.Array | None = None
2930
) -> nnx.Rngs:
3031
"""
3132
Create NNX Rngs
3233
3334
Args:
3435
config: the configuration
35-
is_training: if the Rngs are for training
36+
model_mode: the model mode. See maxtext.common.common_types for valid values.
3637
rng_key: the Rng key
3738
3839
Returns:
@@ -41,7 +42,9 @@ def create_nnx_rngs(
4142
if rng_key is None:
4243
rng_key = jax.random.PRNGKey(config.init_weights_seed)
4344

44-
if is_training:
45+
if model_mode == MODEL_MODE_TRAIN:
46+
# Use fold_in to derive independent keys for each stream from a single seed.
47+
# aqt is needed for quantization-aware training.
4548
return nnx.Rngs(
4649
params=jax.random.fold_in(rng_key, 0), dropout=jax.random.fold_in(rng_key, 1), aqt=jax.random.fold_in(rng_key, 2)
4750
)
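
Reviewer note: the `fold_in` pattern in the train branch, shown in isolation; a minimal sketch assuming a single integer seed, with the same `params`/`dropout`/`aqt` stream names as the diff:

```python
# Sketch: derive independent RNG streams from one seed via fold_in, as the
# MODEL_MODE_TRAIN branch above does.
import jax
from flax import nnx

seed_key = jax.random.PRNGKey(0)
rngs = nnx.Rngs(
    params=jax.random.fold_in(seed_key, 0),
    dropout=jax.random.fold_in(seed_key, 1),
    aqt=jax.random.fold_in(seed_key, 2),
)
print(rngs.params())   # each call draws the next key from that named stream
print(rngs.dropout())
```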

src/maxtext/utils/model_creation_utils.py

Lines changed: 18 additions & 25 deletions
```diff
@@ -19,20 +19,18 @@
 from collections.abc import Sequence
 from functools import partial
 from typing import overload
-
 from etils import epath
 from flax import nnx
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from jax.sharding import AxisType, Mesh
+from jax.sharding import Mesh
 from maxtext.configs import pyconfig
-from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode
+from maxtext.common.common_types import MODEL_MODE_TRAIN
 from maxtext.layers import quantizations
 from maxtext.models import models
 from maxtext.utils import max_logging
-from maxtext.utils import max_utils
-from maxtext.utils import maxtext_utils
+from maxtext.utils import max_utils, maxtext_utils, maxtext_utils_nnx
 from orbax import checkpoint as ocp
 
 try:
@@ -154,6 +152,7 @@ def from_config(
     mesh: Mesh | None = None,
     *,
     model_mode: str = MODEL_MODE_TRAIN,
+    rngs: None = None,
 ) -> nn.Module:
   ...
 
@@ -194,15 +193,7 @@ def from_config(
       model = from_config(config)
   """
   if mesh is None:
-    devices_array = maxtext_utils.create_device_mesh(config, devices)
-
-    if config.shard_mode == ShardMode.EXPLICIT:
-      axis_types = tuple([AxisType.Explicit] * len(config.mesh_axes))
-    else:
-      axis_types = tuple([AxisType.Auto] * len(config.mesh_axes))
-
-    mesh = Mesh(devices_array, config.mesh_axes, axis_types=axis_types)
-
+    mesh = maxtext_utils.get_mesh_from_config(config, devices)
   model = create_model(config, mesh, model_mode=model_mode, rngs=rngs)
 
   # Return only the model
@@ -245,9 +236,7 @@ def create_nnx_abstract_model(config, mesh, model_mode=MODEL_MODE_TRAIN, rng_key
   """
 
   def _create_model(rng_key=None):
-    if rng_key is None:
-      rng_key = jax.random.PRNGKey(config.init_weights_seed)
-    rngs = nnx.Rngs(params=rng_key, dropout=1)
+    rngs = maxtext_utils_nnx.create_nnx_rngs(config, model_mode=model_mode, rng_key=rng_key)
     return from_config(config, mesh=mesh, rngs=rngs, model_mode=model_mode)
 
   _create_model_partial = partial(_create_model, rng_key=rng_key)
@@ -262,14 +251,7 @@ def create_nnx_model(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAI
   """Creates a NNX model with sharded parameters, possibly loading from a checkpoint."""
 
   def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, rng_key: jax.Array | None = None):
-    if rng_key is None:
-      rng_key = jax.random.PRNGKey(config.init_weights_seed)
-
-    if model_mode == MODEL_MODE_TRAIN:
-      rngs = nnx.Rngs(params=rng_key, dropout=1)
-    else:
-      rngs = nnx.Rngs(params=rng_key)  # disable dropout RNG for inference
-
+    rngs = maxtext_utils_nnx.create_nnx_rngs(config, model_mode=model_mode, rng_key=rng_key)
     return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode)
 
   _create_model_partial = partial(_create_model, mesh=mesh, model_mode=model_mode, rng_key=rng_key)
@@ -282,6 +264,17 @@ def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN,
   if mesh is None:
     mesh = abstract_model.mesh
 
+  # Note for pure_nnx:
+  # Currently, the NNX model returned has a linen decoder wrapped to NNX. So it is not a pure NNX model and
+  # we still need to use nn.logical_axis_rules(config.logical_axis_rules) to get the out sharding from the linen
+  # LogicallyPartitioned structure.
+  # In the future if the pure NNX model is used, with pure NNX's eager sharding, there will be no LogicallyPartitioned
+  # structure in the abstract state and we can get the sharded state with the following code:
+  #   graphdef, state = nnx.get_abstract_model(_create_model_partial, mesh)
+  #   abstract_model = nnx.merge(graphdef, state)
+  #   model = maxtext_utils_nnx.create_nnx_sharded_model(abstract_model, _create_model_partial, mesh=mesh)
+  #   sharded_state = nnx.state(model)
+
   # JIT a function that creates the model state with proper sharding from the start.
   # By providing out_shardings, we instruct JAX to produce sharded output directly,
   # avoiding a large intermediate allocation on a single device.
```
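
Reviewer note: the closing comment describes initializing state directly into its target sharding via `out_shardings`. A minimal sketch of that idea, assuming toy shapes and a one-axis mesh (none of these names come from MaxText):

```python
# Sketch: jit an initializer with out_shardings so the output is
# materialized already sharded across the mesh, avoiding a full
# single-device intermediate.
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ("data",))
sharding = NamedSharding(mesh, PartitionSpec("data"))

def init_params():
  return jnp.zeros((8 * jax.device_count(), 16))

init_sharded = jax.jit(init_params, out_shardings=sharding)
params = init_sharded()
print(params.sharding)  # NamedSharding over the "data" axis
```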
