Commit 9f9629c

Author: Charles Li (committed)

Support gradient_accumulation and align to latest train.py

1 parent c311ead  commit 9f9629c

2 files changed: 272 additions & 33 deletions

File tree

src/maxtext/trainers/pre_train/nnx_train.py (150 additions & 33 deletions)
@@ -93,13 +93,14 @@
     maybe_record_goodput,
     record_goodput,
 )
-from maxtext.common.metric_logger import MetricLogger
+from maxtext.common.metric_logger import MetricLogger, record_activation_metrics
 from maxtext.configs import pyconfig
 from maxtext.input_pipeline.input_pipeline_interface import create_data_iterator
 from maxtext.layers.multi_token_prediction import calculate_mtp_acceptance_rate, calculate_mtp_loss
 from maxtext.optimizers import optimizers
 from maxtext.utils import exceptions, max_logging, max_utils, maxtext_utils, model_creation_utils, sharding
 from maxtext.utils.globals import EPS
+from maxtext.utils.gradient_accumulation import nnx_gradient_accumulation_loss_and_grad
 from maxtext.utils.rampup_batch import create_rampup_manager

 _diag_modules = _cloud_diag()
@@ -127,7 +128,7 @@ def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng:
   Returns:
     (loss, aux) where loss is a scalar and aux is a dict of auxiliary metrics.
   """
-  rng1, aqt_rng = jax.random.split(dropout_rng)
+  # rng1, aqt_rng = jax.random.split(dropout_rng)

   # Trim to micro-batch size (handles per_device_batch_size < 1 cases)
   # decimate proportion of data when per_device_batch_size<1
@@ -188,6 +189,24 @@ def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng:
     mtp_loss = calculate_mtp_loss(intermediate_outputs, config)
     loss += mtp_loss

+  # get indexer loss
+  indexer_loss = 0.0
+  if config.use_sparse_indexer and config.indexer_loss_scaling_factor > 0.0:
+    indexer_losses = []
+    # Extract 'indexer_loss' from model intermediates.
+    # We check for paths ending in ('self_attention', 'indexer_loss').
+    # This handles varying paths caused by different layer names.
+    for path, val in jax.tree_util.tree_leaves_with_path(intermediate_outputs):
+      path_keys = tuple(k.key for k in path if hasattr(k, "key"))
+      if path_keys[-2:] == ("self_attention", "indexer_loss"):
+        indexer_losses.append(jnp.ravel(val))
+
+    if indexer_losses:
+      indexer_loss = jnp.mean(jnp.concatenate(indexer_losses))
+      loss += indexer_loss
+    else:
+      max_logging.debug("No indexer loss found.")
+
   # get MoE load balance loss
   moe_lb_loss = 0.0
   if config.num_experts > 1:
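
The path filtering above relies on jax.tree_util.tree_leaves_with_path, which yields (path, leaf) pairs whose dict entries are DictKey objects carrying a .key attribute; tuple indices (for example from Flax sow) appear as SequenceKey and are dropped by the hasattr check, so only the trailing dict keys are compared and the match is independent of the layer names in front of them. A self-contained toy version of the same extraction, where the nested intermediates dict is invented purely for illustration:

import jax
import jax.numpy as jnp

# Made-up stand-in for model intermediates; real MaxText paths differ per layer.
intermediates = {
    "decoder": {
        "layers_0": {"self_attention": {"indexer_loss": (jnp.array([0.1, 0.3]),)}},
        "layers_1": {"self_attention": {"indexer_loss": (jnp.array([0.2]),)}},
    }
}

indexer_losses = []
for path, val in jax.tree_util.tree_leaves_with_path(intermediates):
  # Keep only dict keys; SequenceKey entries (tuple indices) have no .key attribute.
  path_keys = tuple(k.key for k in path if hasattr(k, "key"))
  if path_keys[-2:] == ("self_attention", "indexer_loss"):
    indexer_losses.append(jnp.ravel(val))

indexer_loss = jnp.mean(jnp.concatenate(indexer_losses))
print(indexer_loss)  # mean over all collected layer losses, here ~0.2
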
@@ -227,29 +246,12 @@ def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng:
       "z_loss": total_z_loss,
       "total_weights": total_weights,
       "moe_lb_loss": moe_lb_loss,
+      "indexer_loss": indexer_loss,
       "moe_bias_updates": moe_bias_updates,
       "mtp_loss": mtp_loss,
   }
   return loss, aux

-  # Zero out padding positions
-  target_mask = batch["targets_segmentation"] != 0
-  xent = xent * target_mask
-  z_loss = z_loss * target_mask
-
-  total_loss = jnp.sum(xent)
-  total_weights = jnp.sum(target_mask)
-  total_z_loss = jnp.sum(z_loss) / (total_weights + EPS)
-
-  loss = total_loss / (total_weights + EPS)
-
-  aux = {
-      "total_loss": total_loss,
-      "z_loss": total_z_loss,
-      "total_weights": total_weights,
-  }
-  return loss, aux
-

 # ---------------------------------------------------------------------------
 # Train / eval steps (purely functional, JIT-able)
@@ -282,41 +284,139 @@ def train_step(
   """
   model: nnx.Module = nnx.merge(model_graphdef, model_state)
   optimizer: nnx.Optimizer = nnx.merge(opt_graphdef, opt_state)
+  if config.use_dpo:
+    # Need impl on NNX
+    pass
+    # state, reference_params = _split_dpo_state(state)
+    # state_mesh_shardings, reference_params_sharding = _split_dpo_state(state_mesh_shardings)
+    # extra_dpo_args = [reference_params]
+    # loss_fn = dpo_loss_fn

   # Compute loss and gradients w.r.t. model parameters.
   # nnx.value_and_grad differentiates only through nnx.Param variables,
   # keeping non-differentiable state (RNGs, cache, etc.) frozen.
-  grad_fn = nnx.value_and_grad(loss_fn, argnums=0, has_aux=True)
-  (loss, aux), raw_grads = grad_fn(model, config, data, dropout_rng, is_train=True)
+  if config.gradient_accumulation_steps > 1:
+    loss, aux, raw_grads = nnx_gradient_accumulation_loss_and_grad(loss_fn, model, config, data, dropout_rng)
+  else:
+    if config.optimizer_memory_host_offload:
+      # Need impl on NNX
+      pass
+      # if config.use_dpo:
+      #   reference_params = jax.device_put(
+      #       reference_params,
+      #       max_utils.with_memory_kind(reference_params_sharding, "device"),
+      #   )
+      #   extra_dpo_args = [reference_params]
+    if config.shard_optimizer_over_data:
+      # Need impl on NNX
+      pass
+      # params = jax.tree.map(
+      #     functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode),
+      #     params,
+      #     params_shardings,
+      # )
+    grad_fn = nnx.value_and_grad(loss_fn, argnums=0, has_aux=True)
+    (loss, aux), raw_grads = grad_fn(model, config, data, dropout_rng, is_train=True)

   # Cast gradients to configured dtype before clipping / accumulation
   raw_grads = jax.tree.map(
       lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x,
       raw_grads,
   )
+  intermediate_outputs = aux["intermediate_outputs"]
+  total_weights = aux["total_weights"]
+  moe_lb_loss = aux["moe_lb_loss"]
+  indexer_loss = aux["indexer_loss"]
+  z_loss = aux["z_loss"]
+  moe_bias_updates = aux["moe_bias_updates"]
+  mtp_loss = aux["mtp_loss"]

   # Gradient clipping (implemented directly to avoid Linen TrainState dependency)
   if config.gradient_clipping_threshold > 0:
     clip_tx = optax.clip_by_global_norm(config.gradient_clipping_threshold)
     grads, _ = clip_tx.update(raw_grads, clip_tx.init(raw_grads), None)
   else:
     grads = raw_grads
+  if config.optimizer_memory_host_offload:
+    # Need impl on NNX
+    pass
+    # state = state.replace(
+    #     opt_state=jax.device_put(
+    #         state.opt_state,
+    #         jax.tree_util.tree_map(
+    #             lambda x: x.with_memory_kind(kind="device"),
+    #             state_mesh_shardings.opt_state,
+    #         ),
+    #     )
+    # )
+  # Move all parameters to device before optimizer update
+  if config.parameter_memory_host_offload:
+    max_logging.log("\nMoving all parameters to device before optimizer update")
+    # Need impl on NNX
+    # def move(path, value):
+    #   max_logging.log(f"train.py: Moving f{path} to device")
+    #   return value.with_memory_kind(kind="device")
+
+    # state = state.replace(
+    #     params=jax.device_put(
+    #         state.params,
+    #         jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params),
+    #     )
+    # )

   # NNX 0.11+: update takes (model, grads) explicitly.
   optimizer.update(model, grads)

   new_model_state = nnx.state(model)
   new_opt_state = nnx.state(optimizer)

+  # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family
+  if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None:
+    # Need impl on NNX
+    pass
+    # target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias")
+    # Flax 'sow' returns a tuple, so we take the first element [0].
+    # Updates the shape to be aligned with state.
+    # moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose()
+    # new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates)
+
   scalar_metrics = {
       "learning/loss": loss,
-      "learning/z_loss": aux["z_loss"],
-      "learning/total_weights": aux["total_weights"],
-      "learning/grad_norm": max_utils.l2norm_pytree(grads),
-      "learning/raw_grad_norm": max_utils.l2norm_pytree(raw_grads),
-      "learning/param_norm": max_utils.l2norm_pytree(nnx.state(model, nnx.Param)),
+      "learning/z_loss": z_loss,
+      "learning/moe_lb_loss": moe_lb_loss,
+      "learning/indexer_loss": indexer_loss,
+      "learning/mtp_loss": mtp_loss,
+      "learning/total_weights": total_weights,
   }
-  metrics = {"scalar": scalar_metrics, "scalars": {}}
+  if config.use_qk_clip:
+    # Apply QK-Clip
+    # Need impl on NNX
+    pass
+    # new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config)
+
+    # Report max_logits metric
+    # global_max_logit = qk_clip_utils.calculate_max_logit_metric(intermediate_outputs)
+    # if global_max_logit is not None:
+    #   scalar_metrics["learning/max_logits"] = global_max_logit
+
+  if not config.optimizer_memory_host_offload:
+    scalar_metrics["learning/grad_norm"] = max_utils.l2norm_pytree(grads)
+    scalar_metrics["learning/raw_grad_norm"] = max_utils.l2norm_pytree(raw_grads)
+    scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(nnx.state(model, nnx.Param))
+  if config.use_dpo:
+    scalar_metrics["learning/dpo_reward_accuracy"] = aux["reward_accuracy"]
+  metrics = {
+      "scalar": scalar_metrics,
+      "scalars": {},
+  }
+
+  if config.record_internal_nn_metrics:
+    record_activation_metrics(metrics, intermediate_outputs, config)
+
+  if config.use_dpo:
+    # Need impl on NNX
+    pass
+    # new_state = _merge_dpo_state(new_state, reference_params)
   return (new_model_state, new_opt_state), metrics

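nnx_gradient_accumulation_loss_and_grad itself is defined in maxtext/utils/gradient_accumulation.py rather than in this file, so its body does not appear in this diff. As a rough sketch of what such a helper has to do, the hypothetical version below splits the batch into gradient_accumulation_steps micro-batches, runs nnx.value_and_grad on each, and averages losses and gradients; the function name, the reshape-based micro-batching, the per-micro-batch fold_in of the RNG, and the aux handling are all assumptions for illustration, not the committed implementation.

import jax
import jax.numpy as jnp
from flax import nnx

def accumulation_loss_and_grad_sketch(loss_fn, model, config, data, dropout_rng):
  """Hypothetical sketch: average loss/grads over micro-batches with an unrolled loop."""
  steps = config.gradient_accumulation_steps
  grad_fn = nnx.value_and_grad(loss_fn, argnums=0, has_aux=True)

  # Reshape every [global_batch, ...] array into [steps, micro_batch, ...];
  # assumes the global batch divides evenly by the accumulation steps.
  micro = jax.tree.map(lambda x: x.reshape((steps, -1) + x.shape[1:]), data)

  loss_sum, grad_sum, aux = 0.0, None, None
  for i in range(steps):
    micro_batch = jax.tree.map(lambda x: x[i], micro)
    step_rng = jax.random.fold_in(dropout_rng, i)
    (loss, aux), grads = grad_fn(model, config, micro_batch, step_rng, is_train=True)
    grad_sum = grads if grad_sum is None else jax.tree.map(jnp.add, grad_sum, grads)
    loss_sum = loss_sum + loss

  # aux here is simply the last micro-batch's aux; a real helper may aggregate it.
  return loss_sum / steps, aux, jax.tree.map(lambda g: g / steps, grad_sum)

A production version would more likely accumulate with jax.lax.scan so the loop does not unroll under jit, and would aggregate aux entries (total_weights, moe_lb_loss, and so on) across micro-batches rather than keeping only the last ones.
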
@@ -350,6 +450,7 @@ def eval_step(
   z_loss = aux["z_loss"]
   total_weights = aux["total_weights"]
   moe_lb_loss = aux["moe_lb_loss"]
+  indexer_loss = aux["indexer_loss"]
   mtp_loss = aux["mtp_loss"]
   metrics = {
       "scalar": {
@@ -358,6 +459,7 @@
           "evaluation/total_loss": total_loss,
           "evaluation/total_weights": total_weights,
           "evaluation/moe_lb_loss": moe_lb_loss,
+          "evaluation/indexer_loss": indexer_loss,
           "evaluation/mtp_loss": mtp_loss,
           "evaluation/mtp_acceptance_rate_percent": mtp_acceptance_rate,
       },
@@ -415,8 +517,8 @@ def _create_and_shard_optimizer(model: nnx.Module, config, mesh: Mesh):
   _, opt_state = nnx.split(optimizer)

   @functools.partial(jax.jit, out_shardings=(model_shardings, opt_shardings))
-  def shard_states(ms, os):
-    return ms, os
+  def shard_states(mshard, oshard):
+    return mshard, oshard

   with mesh:
     model_state, opt_state = shard_states(model_state, opt_state)
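
The renamed shard_states helper is the jitted-identity resharding trick: jax.jit with out_shardings lays its outputs out according to the requested shardings, so returning the states unchanged is enough to move them onto the mesh. A small standalone illustration, independent of MaxText; the mesh axis name "data" and the toy array are invented for the example:

import functools
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, ("data",))
target = NamedSharding(mesh, P("data"))

@functools.partial(jax.jit, out_shardings=target)
def reshard(x):
  return x  # identity; jit places the output according to out_shardings

x = jnp.arange(8.0 * len(devices))  # length divisible by the device count
y = reshard(x)
print(y.sharding)  # NamedSharding over the "data" axis
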
@@ -608,7 +710,9 @@ def train_loop(config, recorder, state=None):
   shaped_batch = maxtext_utils.get_shaped_batch(config)
   init_rng = jax.random.PRNGKey(config.init_weights_seed)
   example_rng = jax.jit(jax.random.fold_in)(init_rng, 0)
-  if config.compiled_trainstep_file == "":
+  # Need imple below func on NNX
+  # maxtext_utils.maybe_dump_jaxpr(config, p_train_step, (model_state, opt_state, shaped_batch, example_rng))
+  if config.compiled_trainstep_file == "":  # compile only when there is no pre-compiled file loaded
     compiled = p_train_step.lower(model_state, opt_state, shaped_batch, example_rng).compile()
     compiled_stats = compiled.memory_analysis()
     max_utils.print_compiled_memory_stats(compiled_stats)
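
The lower(...).compile() pair above is JAX's ahead-of-time compilation path; compiling against shaped batches lets the compile-time memory analysis be printed before the first real step runs. A toy illustration of the same API on an arbitrary jitted function (the function and shapes are invented); note that memory_analysis() can return None on backends that do not report these statistics:

import jax
import jax.numpy as jnp

@jax.jit
def toy_step(params, batch):
  # Arbitrary computation standing in for a train step.
  return params * 2.0 + jnp.sum(batch)

params = jnp.ones((4,))
batch = jnp.ones((8, 4))

lowered = toy_step.lower(params, batch)   # trace and lower without executing
compiled = lowered.compile()              # XLA compilation happens here
stats = compiled.memory_analysis()        # may be None on some backends
if stats is not None:
  print(stats.temp_size_in_bytes, stats.output_size_in_bytes)
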
@@ -624,14 +728,14 @@ def train_loop(config, recorder, state=None):
   _job_completed_gracefully = False
   try:
     last_step_completion = datetime.datetime.now()
+    max_logging.info(f"Entering train loop from start_step={start_step}")

     for step in np.arange(start_step, config.steps):
       prof.maybe_activate_profiler(step, opt_state)

       with jax.profiler.StepTraceAnnotation("train", step_num=step):
         example_batch = data_loader.load_next_batch(rampup_manager=rampup_manager)
         nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
-
         with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
           with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
             (model_state, opt_state), metrics = p_train_step(model_state, opt_state, example_batch, nextrng)
@@ -649,15 +753,18 @@
           and (step + 1) % config.eval_interval == 0
       ):
         assert eval_data_iterator
+        # Explicitly reset the eval iterator and counters before starting the eval loop
         eval_data_iterator.reset()
         metric_logger.reset_eval_metrics()
+
         eval_step_count = 0
         for eval_batch in eval_data_iterator:
           if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
             break
           with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
             eval_metrics = p_eval_step(model_state, eval_batch, nextrng)
           metric_logger.record_eval_metrics(step, metrics=eval_metrics)
+          max_logging.log(f"Completed eval step {eval_step_count}")
           eval_step_count += 1

         metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count)
@@ -678,6 +785,7 @@
         checkpoint_manager, model_state, opt_state, config, data_iterator, step=int(config.steps - 1)
     )
     if checkpoint_manager is not None:
+      # in case the last checkpoint_period checkpoint is still in progress
       checkpoint_manager.wait_until_finished()

     _job_completed_gracefully = True
@@ -727,8 +835,10 @@ def initialize(argv: Sequence[str]):
   if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"):
     vertex_tensorboard_manager.configure_vertex_tensorboard(config)

+  # Create the Goodput recorder
   recorder = create_goodput_recorder(config)

+  # Stack traces configurations
   debug_config = debug_configuration.DebugConfig(
       stack_trace_config=stack_trace_configuration.StackTraceConfig(
           collect_stack_trace=config.collect_stack_trace,
@@ -741,13 +851,20 @@


 def run(config, recorder, diagnostic_config):
-  """Run the NNX training job."""
+  """Run the NNX training job.
+
+  In decoupled mode (DECOUPLE_GCLOUD=TRUE) cloud diagnostics may be stubbed; if so, skip wrapping.
+  """
+  # Use nullcontext when diagnostics are stubbed or in decoupled mode
   diagnostics_context = (
       contextlib.nullcontext()
       if is_decoupled() or getattr(diagnostic, "__class__", None).__name__ == "_StubDiag"
       else diagnostic.diagnose(diagnostic_config)
   )

+  if is_decoupled() or getattr(diagnostic, "__class__", None).__name__ == "_StubDiag":
+    max_logging.log("[DECOUPLED NO-OP] skipping cloud diagnostics wrapper.")
+
   with (
       diagnostics_context,
       max_utils.maybe_get_transformer_engine_context(config),
