Skip to content

Commit 4f4c0b0

Browse files
committed
fix(nnx): support Zero-1 input shardings on NNX flat state
Under shard_optimizer_over_data, train_compile builds the AOT train-step input shardings by calling state_mesh_shardings.replace(params=params_shardings). That's a TrainState (flax.struct) method; with PR#11's NNX defaults, state_mesh_shardings is a flat nnx.State and the call dies with 'No attribute replace in State'. Add sharding.build_zero1_input_state_mesh_shardings that overlays params_shardings' Param leaves onto state_mesh_shardings.model for the NNX path while keeping the existing .replace behavior for Linen, and route both train_compile call sites through it. Fixes test_zero1_optimizer_sharding.
1 parent c40004a commit 4f4c0b0

2 files changed

Lines changed: 60 additions & 95 deletions

File tree

src/maxtext/trainers/pre_train/train_compile.py

Lines changed: 31 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,9 @@
6161
def validate_config(config):
6262
"""Validates the config is is setup correctly to compile, returning a useful error message if not."""
6363
assert config.compile_topology != "", (
64-
"You must pass your desired target hardware in compile_topology, e.g."
65-
" compile_topology=v5e-256"
64+
"You must pass your desired target hardware in compile_topology, e.g." " compile_topology=v5e-256"
6665
)
67-
assert (
68-
config.compile_topology_num_slices > 0
69-
), "You must set compile_topology_num_slices to a positive integer"
66+
assert config.compile_topology_num_slices > 0, "You must set compile_topology_num_slices to a positive integer"
7067

7168

7269
def get_topology_mesh(config):
@@ -78,18 +75,12 @@ def get_topology_mesh(config):
7875
num_slices=config.compile_topology_num_slices,
7976
).devices
8077
else:
81-
target_hardware = accelerator_to_spec_map.get_system_characteristics(
82-
config.compile_topology
83-
)
78+
target_hardware = accelerator_to_spec_map.get_system_characteristics(config.compile_topology)
8479
if target_hardware.platform == "gpu":
8580
# Disable sharded autotuning. This is an optimization to distribute
8681
# autotuning across the fleet, but can cause hangs with AoT compilation.
87-
os.environ["XLA_FLAGS"] = (
88-
os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=false"
89-
)
90-
jax.config.update(
91-
"mock_num_gpu_processes", config.compile_topology_num_slices
92-
)
82+
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=false"
83+
jax.config.update("mock_num_gpu_processes", config.compile_topology_num_slices)
9384
topology_devices = jax.devices()
9485
else:
9586
topology_devices = get_topology_desc(
@@ -104,14 +95,8 @@ def get_topology_mesh(config):
10495
"jax_remove_size_one_mesh_axis_from_type",
10596
config.remove_size_one_mesh_axis_from_type,
10697
)
107-
topology_device_mesh = maxtext_utils.create_device_mesh(
108-
config, topology_devices
109-
)
110-
mesh_axis_type = (
111-
AxisType.Explicit
112-
if config.shard_mode == ShardMode.EXPLICIT
113-
else AxisType.Auto
114-
)
98+
topology_device_mesh = maxtext_utils.create_device_mesh(config, topology_devices)
99+
mesh_axis_type = AxisType.Explicit if config.shard_mode == ShardMode.EXPLICIT else AxisType.Auto
115100
topology_mesh = Mesh(
116101
topology_device_mesh,
117102
config.mesh_axes,
@@ -129,9 +114,7 @@ def _collect_nnx_activation_shardings(create_model_fn, config, mesh):
129114
input_shape = (config.micro_batch_size_to_train_on, config.max_target_length)
130115
abstract_input = jax.ShapeDtypeStruct(input_shape, jnp.int32)
131116

132-
def _nnx_forward(
133-
decoder_input_tokens, decoder_positions, decoder_segment_ids
134-
):
117+
def _nnx_forward(decoder_input_tokens, decoder_positions, decoder_segment_ids):
135118
model_instance = create_model_fn()
136119
return model_instance(
137120
decoder_input_tokens=decoder_input_tokens,
@@ -140,9 +123,7 @@ def _nnx_forward(
140123
enable_dropout=False,
141124
)
142125

143-
with jax.set_mesh(mesh), nn_partitioning.axis_rules(
144-
config.logical_axis_rules
145-
):
126+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
146127
jax.eval_shape(_nnx_forward, abstract_input, abstract_input, abstract_input)
147128

148129

@@ -151,13 +132,9 @@ def get_shaped_inputs(topology_mesh, config):
151132
# Construct the model and optimizer to get shaped versions of the state
152133
quant = quantizations.configure_quantization(config)
153134
if config.pure_nnx:
154-
_create_model_partial, model = (
155-
model_creation_utils.create_nnx_abstract_model(config, topology_mesh)
156-
)
135+
_create_model_partial, model = model_creation_utils.create_nnx_abstract_model(config, topology_mesh)
157136
else:
158-
model = Transformer(
159-
config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN
160-
)
137+
model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
161138
# The learning_rate_schedule is baked into the compiled object.
162139
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
163140
# pass in model for muon
@@ -176,20 +153,14 @@ def create_train_state_fn():
176153

177154
init_state_fn = create_train_state_fn
178155
else:
179-
init_state_fn = functools.partial(
180-
maxtext_utils.init_initial_state, model, tx, config, True, example_rng
181-
)
156+
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, example_rng)
182157

183158
# Shaped state
184-
abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state(
185-
config, topology_mesh, init_state_fn, True
186-
)
159+
abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state(config, topology_mesh, init_state_fn, True)
187160

188161
if config.pure_nnx:
189162
# NNX doesn't use Linen logical annotations; derive PartitionSpecs from the physical shardings.
190-
logical_annotations = maxtext_utils_nnx.get_partition_spec_nnx(
191-
state_mesh_shardings
192-
)
163+
logical_annotations = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings)
193164
# For NNX, get_functional_train_with_signature expects the graphdef (static structure),
194165
# not the raw model — mirroring how the training loop does nnx.split(train_state).
195166
with nn_partitioning.axis_rules(config.logical_axis_rules):
@@ -198,9 +169,7 @@ def create_train_state_fn():
198169
model = graphdef
199170
else:
200171
# unsharded logical annotations
201-
logical_annotations = maxtext_utils.get_logical_annotations(
202-
config, topology_mesh, init_state_fn
203-
)
172+
logical_annotations = maxtext_utils.get_logical_annotations(config, topology_mesh, init_state_fn)
204173

205174
# Shaped batch
206175
shaped_batch = maxtext_utils.get_shaped_batch(config)
@@ -217,9 +186,7 @@ def create_train_state_fn():
217186
# Collect NNX activation shardings via an abstract forward pass (must run
218187
# after get_abstract_state, which only traces __init__).
219188
if config.debug_sharding and config.pure_nnx:
220-
_collect_nnx_activation_shardings(
221-
_create_model_partial, config, topology_mesh
222-
)
189+
_collect_nnx_activation_shardings(_create_model_partial, config, topology_mesh)
223190

224191
return (
225192
shaped_train_args,
@@ -256,9 +223,7 @@ def jit_and_compile(
256223
maxtext_utils.maybe_dump_jaxpr(config, jitted, func_input_args)
257224
lowered = jitted.lower(*func_input_args, **func_input_kwargs)
258225
# Import libtpu flags as compiler options. Defaults to empty dict if string is empty.
259-
compiler_options = max_utils.parse_libtpu_flags_to_dict(
260-
config.compile_xla_flags
261-
)
226+
compiler_options = max_utils.parse_libtpu_flags_to_dict(config.compile_xla_flags)
262227
compiled = lowered.compile(compiler_options=compiler_options)
263228
return compiled
264229

@@ -293,20 +258,11 @@ def is_oom(argv: Sequence[str]) -> bool:
293258
) = get_shaped_inputs(topology_mesh, config)
294259

295260
# Update params_shardings when shard_optimizer_over_data is enabled (Zero-1)
296-
params_shardings, state_mesh_shardings = (
297-
sharding.maybe_update_params_sharding_with_opt(
298-
config, state_mesh_shardings
299-
)
300-
)
261+
params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings)
301262

302-
# When ZeRO-1 is enabled, we need to use the original params_shardings for input shardings
303-
# but keep the updated state_mesh_shardings for the optimizer state
304-
if config.shard_optimizer_over_data:
305-
input_state_mesh_shardings = state_mesh_shardings.replace(
306-
params=params_shardings
307-
)
308-
else:
309-
input_state_mesh_shardings = state_mesh_shardings
263+
input_state_mesh_shardings = sharding.build_zero1_input_state_mesh_shardings(
264+
config, state_mesh_shardings, params_shardings
265+
)
310266

311267
# Get data sharding
312268
data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
@@ -355,8 +311,7 @@ def is_oom(argv: Sequence[str]) -> bool:
355311
def main(argv: Sequence[str]) -> None:
356312
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
357313
os.environ["LIBTPU_INIT_ARGS"] = (
358-
os.environ.get("LIBTPU_INIT_ARGS", "")
359-
+ " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
314+
os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
360315
)
361316
print("Starting train_compile.py...", flush=True)
362317

@@ -381,41 +336,26 @@ def main(argv: Sequence[str]) -> None:
381336
) = get_shaped_inputs(topology_mesh, config)
382337

383338
# Update params_shardings when shard_optimizer_over_data is enabled (Zero-1)
384-
params_shardings, state_mesh_shardings = (
385-
sharding.maybe_update_params_sharding_with_opt(
386-
config, state_mesh_shardings
387-
)
388-
)
339+
params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings)
389340

390-
# When ZeRO-1 is enabled, we need to use the original params_shardings for input shardings
391-
# but keep the updated state_mesh_shardings for the optimizer state
392-
if config.shard_optimizer_over_data:
393-
input_state_mesh_shardings = state_mesh_shardings.replace(
394-
params=params_shardings
395-
)
396-
else:
397-
input_state_mesh_shardings = state_mesh_shardings
341+
input_state_mesh_shardings = sharding.build_zero1_input_state_mesh_shardings(
342+
config, state_mesh_shardings, params_shardings
343+
)
398344

399345
# Get data sharding
400346
data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
401347
if config.enable_diloco:
402348
# Build abstract DiLoCo state and shardings for AOT compilation
403349
abstract_state = shaped_train_args[0]
404-
diloco_state, state_mesh_shardings, inner_state_shardings = (
405-
diloco.build_abstract_diloco_state(
406-
config, abstract_state, state_mesh_shardings, topology_mesh
407-
)
350+
diloco_state, state_mesh_shardings, inner_state_shardings = diloco.build_abstract_diloco_state(
351+
config, abstract_state, state_mesh_shardings, topology_mesh
408352
)
409353
# For NNX, shaped_train_args has 2 elements (state, batch) — no rng; pass None for prng.
410-
shaped_rng_arg = (
411-
shaped_train_args[2] if len(shaped_train_args) > 2 else None
412-
)
354+
shaped_rng_arg = shaped_train_args[2] if len(shaped_train_args) > 2 else None
413355
shaped_train_args = (diloco_state, shaped_train_args[1], shaped_rng_arg)
414356

415357
# Wrap train_step with diloco
416-
train_step_partial = functools.partial(
417-
train.train_step, model, config, inner_state_shardings, params_shardings
418-
)
358+
train_step_partial = functools.partial(train.train_step, model, config, inner_state_shardings, params_shardings)
419359
train_step_fn = diloco.build_diloco_train_step(config, train_step_partial)
420360

421361
# For DiLoCo, the train_step_fn is already fully wrapped and takes (state, batch, prng)
@@ -480,10 +420,7 @@ def main(argv: Sequence[str]) -> None:
480420
if config.compiled_trainstep_file != "":
481421
print("Saving compiled object...")
482422
save_compiled(compiled, config.compiled_trainstep_file)
483-
print(
484-
"Successfully saved compiled object as"
485-
f" {config.compiled_trainstep_file}"
486-
)
423+
print("Successfully saved compiled object as" f" {config.compiled_trainstep_file}")
487424
print("Finished train_compile.py successfully!", flush=True)
488425
print(f"Cost analysis: {compiled.cost_analysis()}")
489426
print(f"Memory analysis: {compiled.memory_analysis()}")

src/maxtext/utils/sharding.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
# pylint: disable=line-too-long, disable=bare-except, consider-using-generator
16-
""" Utils that are only interesting to MaxText and sharding related. """
16+
"""Utils that are only interesting to MaxText and sharding related."""
1717

1818
from flax import linen as nn, nnx
1919

@@ -620,6 +620,34 @@ def _update_model_var(path, var):
620620
return prev_params_shardings, updated_state
621621

622622

623+
def build_zero1_input_state_mesh_shardings(config, state_mesh_shardings, params_shardings):
624+
"""Build the train-step input shardings under shard_optimizer_over_data (Zero-1).
625+
626+
Model params on input use the original pre-Zero-1 sharding (params_shardings), while the rest
627+
of the state — including the optimizer state — keeps the Zero-1 layout from state_mesh_shardings,
628+
so the optimizer state input matches its output. When shard_optimizer_over_data is False,
629+
state_mesh_shardings passes through unchanged.
630+
"""
631+
if not config.shard_optimizer_over_data:
632+
return state_mesh_shardings
633+
if not config.pure_nnx:
634+
return state_mesh_shardings.replace(params=params_shardings)
635+
# nnx.State has no .replace. tree_map below shallow-copies state_mesh_shardings preserving
636+
# nested container types; then we walk params_shardings and overwrite the matching keys under
637+
# input_state.model (the NNX home of model params).
638+
input_state = jax.tree_util.tree_map(lambda x: x, state_mesh_shardings, is_leaf=lambda x: isinstance(x, nnx.Variable))
639+
640+
def _overlay(model_node, params_node):
641+
for k, pv in params_node.items():
642+
if isinstance(pv, nnx.Variable):
643+
model_node[k] = pv
644+
elif hasattr(pv, "items"):
645+
_overlay(model_node[k], pv)
646+
647+
_overlay(input_state.model, params_shardings)
648+
return input_state
649+
650+
623651
def logical_axis_rules_pp_act_as_dp(logical_rules):
624652
"""Add stage as a physical axes before data for each rule, so stage acts just like data instead of PP.
625653
This is used when we want to pipeline only a subset of layers, and leave the rest like DP.

0 commit comments

Comments
 (0)