     maybe_record_goodput,
     record_goodput,
 )
-from maxtext.common.metric_logger import MetricLogger
+from maxtext.common.metric_logger import MetricLogger, record_activation_metrics
 from maxtext.configs import pyconfig
 from maxtext.input_pipeline.input_pipeline_interface import create_data_iterator
 from maxtext.layers.multi_token_prediction import calculate_mtp_acceptance_rate, calculate_mtp_loss
 from maxtext.optimizers import optimizers
 from maxtext.utils import exceptions, max_logging, max_utils, maxtext_utils, model_creation_utils, sharding
 from maxtext.utils.globals import EPS
+from maxtext.utils.gradient_accumulation import nnx_gradient_accumulation_loss_and_grad
 from maxtext.utils.rampup_batch import create_rampup_manager

 _diag_modules = _cloud_diag()
@@ -127,7 +128,7 @@ def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng:
   Returns:
     (loss, aux) where loss is a scalar and aux is a dict of auxiliary metrics.
   """
-  rng1, aqt_rng = jax.random.split(dropout_rng)
+  # rng1, aqt_rng = jax.random.split(dropout_rng)

   # Trim to micro-batch size (handles per_device_batch_size < 1 cases)
   # decimate proportion of data when per_device_batch_size<1
@@ -188,6 +189,24 @@ def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng:
     mtp_loss = calculate_mtp_loss(intermediate_outputs, config)
     loss += mtp_loss

+  # get indexer loss
+  indexer_loss = 0.0
+  if config.use_sparse_indexer and config.indexer_loss_scaling_factor > 0.0:
+    indexer_losses = []
+    # Extract 'indexer_loss' from model intermediates.
+    # We check for paths ending in ('self_attention', 'indexer_loss').
+    # This handles varying paths caused by different layer names.
+    for path, val in jax.tree_util.tree_leaves_with_path(intermediate_outputs):
+      path_keys = tuple(k.key for k in path if hasattr(k, "key"))
+      if path_keys[-2:] == ("self_attention", "indexer_loss"):
+        indexer_losses.append(jnp.ravel(val))
+
+    if indexer_losses:
+      indexer_loss = jnp.mean(jnp.concatenate(indexer_losses))
+      loss += indexer_loss
+    else:
+      max_logging.debug("No indexer loss found.")
+
   # get MoE load balance loss
   moe_lb_loss = 0.0
   if config.num_experts > 1:
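The new indexer-loss block relies on the `jax.tree_util.tree_leaves_with_path` path-matching idiom. A minimal, self-contained sketch (toy intermediates dict with hypothetical layer names, not the model's real sown outputs) of how it picks out every `indexer_loss` leaf regardless of layer naming:

```python
# Standalone illustration only; the real intermediates come from the model's sown outputs.
import jax
import jax.numpy as jnp

intermediates = {  # hypothetical structure with differing layer names
    "decoder": {
        "layers_0": {"self_attention": {"indexer_loss": jnp.array([0.10, 0.20])}},
        "dense_layers": {"self_attention": {"indexer_loss": jnp.array([0.30])}},
    }
}

losses = []
for path, val in jax.tree_util.tree_leaves_with_path(intermediates):
  # Dict entries in the path are DictKey objects exposing .key; other key types are skipped.
  keys = tuple(k.key for k in path if hasattr(k, "key"))
  if keys[-2:] == ("self_attention", "indexer_loss"):
    losses.append(jnp.ravel(val))

indexer_loss = jnp.mean(jnp.concatenate(losses))  # scalar mean over all matched leaves
```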
@@ -227,29 +246,12 @@ def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng:
227246 "z_loss" : total_z_loss ,
228247 "total_weights" : total_weights ,
229248 "moe_lb_loss" : moe_lb_loss ,
249+ "indexer_loss" : indexer_loss ,
230250 "moe_bias_updates" : moe_bias_updates ,
231251 "mtp_loss" : mtp_loss ,
232252 }
233253 return loss , aux
234254
235- # Zero out padding positions
236- target_mask = batch ["targets_segmentation" ] != 0
237- xent = xent * target_mask
238- z_loss = z_loss * target_mask
239-
240- total_loss = jnp .sum (xent )
241- total_weights = jnp .sum (target_mask )
242- total_z_loss = jnp .sum (z_loss ) / (total_weights + EPS )
243-
244- loss = total_loss / (total_weights + EPS )
245-
246- aux = {
247- "total_loss" : total_loss ,
248- "z_loss" : total_z_loss ,
249- "total_weights" : total_weights ,
250- }
251- return loss , aux
252-
253255
254256# ---------------------------------------------------------------------------
255257# Train / eval steps (purely functional, JIT-able)
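The block removed above duplicated the padding-masked loss normalization already computed earlier in `loss_fn`. For reference, the masking arithmetic works like this (toy values; `EPS` is assumed here to mirror `maxtext.utils.globals.EPS`):

```python
# Sketch of the padding-masked loss normalization used by loss_fn (toy shapes).
import jax.numpy as jnp

EPS = 1e-8  # assumed small constant guarding against an all-padding batch
xent = jnp.array([[1.2, 0.7, 0.9], [0.4, 0.8, 0.5]])  # per-token cross-entropy
segmentation = jnp.array([[1, 1, 0], [1, 0, 0]])       # 0 marks padding positions
target_mask = segmentation != 0

total_loss = jnp.sum(xent * target_mask)   # 1.2 + 0.7 + 0.4 = 2.3
total_weights = jnp.sum(target_mask)       # 3 real tokens
loss = total_loss / (total_weights + EPS)  # ~0.7667, mean loss per real token
```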
@@ -282,41 +284,139 @@ def train_step(
282284 """
283285 model : nnx .Module = nnx .merge (model_graphdef , model_state )
284286 optimizer : nnx .Optimizer = nnx .merge (opt_graphdef , opt_state )
287+ if config .use_dpo :
288+ # Need impl on NNX
289+ pass
290+ # state, reference_params = _split_dpo_state(state)
291+ # state_mesh_shardings, reference_params_sharding = _split_dpo_state(state_mesh_shardings)
292+ # extra_dpo_args = [reference_params]
293+ # loss_fn = dpo_loss_fn
285294
286295 # Compute loss and gradients w.r.t. model parameters.
287296 # nnx.value_and_grad differentiates only through nnx.Param variables,
288297 # keeping non-differentiable state (RNGs, cache, etc.) frozen.
289- grad_fn = nnx .value_and_grad (loss_fn , argnums = 0 , has_aux = True )
290- (loss , aux ), raw_grads = grad_fn (model , config , data , dropout_rng , is_train = True )
298+ if config .gradient_accumulation_steps > 1 :
299+ loss , aux , raw_grads = nnx_gradient_accumulation_loss_and_grad (loss_fn , model , config , data , dropout_rng )
300+ else :
301+ if config .optimizer_memory_host_offload :
302+ # Need impl on NNX
303+ pass
304+ # if config.use_dpo:
305+ # reference_params = jax.device_put(
306+ # reference_params,
307+ # max_utils.with_memory_kind(reference_params_sharding, "device"),
308+ # )
309+ # extra_dpo_args = [reference_params]
310+ if config .shard_optimizer_over_data :
311+ # Need impl on NNX
312+ pass
313+ # params = jax.tree.map(
314+ # functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode),
315+ # params,
316+ # params_shardings,
317+ # )
318+ grad_fn = nnx .value_and_grad (loss_fn , argnums = 0 , has_aux = True )
319+ (loss , aux ), raw_grads = grad_fn (model , config , data , dropout_rng , is_train = True )

   # Cast gradients to configured dtype before clipping / accumulation
   raw_grads = jax.tree.map(
       lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x,
       raw_grads,
   )
+  intermediate_outputs = aux["intermediate_outputs"]
+  total_weights = aux["total_weights"]
+  moe_lb_loss = aux["moe_lb_loss"]
+  indexer_loss = aux["indexer_loss"]
+  z_loss = aux["z_loss"]
+  moe_bias_updates = aux["moe_bias_updates"]
+  mtp_loss = aux["mtp_loss"]

   # Gradient clipping (implemented directly to avoid Linen TrainState dependency)
   if config.gradient_clipping_threshold > 0:
     clip_tx = optax.clip_by_global_norm(config.gradient_clipping_threshold)
     grads, _ = clip_tx.update(raw_grads, clip_tx.init(raw_grads), None)
   else:
     grads = raw_grads
+  if config.optimizer_memory_host_offload:
+    # Need impl on NNX
+    pass
+    # state = state.replace(
+    #     opt_state=jax.device_put(
+    #         state.opt_state,
+    #         jax.tree_util.tree_map(
+    #             lambda x: x.with_memory_kind(kind="device"),
+    #             state_mesh_shardings.opt_state,
+    #         ),
+    #     )
+    # )
+  # Move all parameters to device before optimizer update
+  if config.parameter_memory_host_offload:
+    max_logging.log("\nMoving all parameters to device before optimizer update")
+    # Need impl on NNX
+    # def move(path, value):
+    #   max_logging.log(f"train.py: Moving f{path} to device")
+    #   return value.with_memory_kind(kind="device")
+
+    # state = state.replace(
+    #     params=jax.device_put(
+    #         state.params,
+    #         jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params),
+    #     )
+    # )

   # NNX 0.11+: update takes (model, grads) explicitly.
   optimizer.update(model, grads)

   new_model_state = nnx.state(model)
   new_opt_state = nnx.state(optimizer)

+  # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family
+  if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None:
+    # Need impl on NNX
+    pass
+    # target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias")
+    # Flax 'sow' returns a tuple, so we take the first element [0].
+    # Updates the shape to be aligned with state.
+    # moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose()
+    # new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates)
+
   scalar_metrics = {
       "learning/loss": loss,
-      "learning/z_loss": aux["z_loss"],
-      "learning/total_weights": aux["total_weights"],
-      "learning/grad_norm": max_utils.l2norm_pytree(grads),
-      "learning/raw_grad_norm": max_utils.l2norm_pytree(raw_grads),
-      "learning/param_norm": max_utils.l2norm_pytree(nnx.state(model, nnx.Param)),
+      "learning/z_loss": z_loss,
+      "learning/moe_lb_loss": moe_lb_loss,
+      "learning/indexer_loss": indexer_loss,
+      "learning/mtp_loss": mtp_loss,
+      "learning/total_weights": total_weights,
   }
-  metrics = {"scalar": scalar_metrics, "scalars": {}}
+  if config.use_qk_clip:
+    # Apply QK-Clip
+    # Need impl on NNX
+    pass
+    # new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config)
+
+    # Report max_logits metric
+    # global_max_logit = qk_clip_utils.calculate_max_logit_metric(intermediate_outputs)
+    # if global_max_logit is not None:
+    #   scalar_metrics["learning/max_logits"] = global_max_logit
+
+  if not config.optimizer_memory_host_offload:
+    scalar_metrics["learning/grad_norm"] = max_utils.l2norm_pytree(grads)
+    scalar_metrics["learning/raw_grad_norm"] = max_utils.l2norm_pytree(raw_grads)
+    scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(nnx.state(model, nnx.Param))
+  if config.use_dpo:
+    scalar_metrics["learning/dpo_reward_accuracy"] = aux["reward_accuracy"]
+  metrics = {
+      "scalar": scalar_metrics,
+      "scalars": {},
+  }
+
+  if config.record_internal_nn_metrics:
+    record_activation_metrics(metrics, intermediate_outputs, config)
+
+  if config.use_dpo:
+    # Need impl on NNX
+    pass
+    # new_state = _merge_dpo_state(new_state, reference_params)
   return (new_model_state, new_opt_state), metrics


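The accumulation branch above calls `nnx_gradient_accumulation_loss_and_grad(loss_fn, model, config, data, dropout_rng)` and expects `(loss, aux, raw_grads)` back. As a rough sketch of that contract only (not the actual `maxtext.utils.gradient_accumulation` implementation, whose internals may differ), a micro-batched loss-and-grad helper could look like:

```python
# Hedged sketch: illustrates the (loss, aux, raw_grads) contract used above.
import jax
import jax.numpy as jnp
from flax import nnx

def accumulate_loss_and_grad(loss_fn, model, config, data, dropout_rng):
  steps = config.gradient_accumulation_steps
  grad_fn = nnx.value_and_grad(loss_fn, argnums=0, has_aux=True)
  # Split the global batch into `steps` micro-batches along the leading axis
  # (assumes the batch dimension divides evenly by `steps`).
  micro_batches = jax.tree.map(lambda x: x.reshape(steps, -1, *x.shape[1:]), data)
  loss_sum, grad_sum, aux = 0.0, None, None
  for i in range(steps):  # unrolled under jit
    micro = jax.tree.map(lambda x: x[i], micro_batches)
    (loss, aux), grads = grad_fn(model, config, micro, dropout_rng, is_train=True)
    loss_sum = loss_sum + loss
    grad_sum = grads if grad_sum is None else jax.tree.map(jnp.add, grad_sum, grads)
  # Average over micro-batches; aux is simply taken from the last micro-batch here.
  raw_grads = jax.tree.map(lambda g: g / steps, grad_sum)
  return loss_sum / steps, aux, raw_grads
```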
@@ -350,6 +450,7 @@ def eval_step(
   z_loss = aux["z_loss"]
   total_weights = aux["total_weights"]
   moe_lb_loss = aux["moe_lb_loss"]
+  indexer_loss = aux["indexer_loss"]
   mtp_loss = aux["mtp_loss"]
   metrics = {
       "scalar": {
@@ -358,6 +459,7 @@ def eval_step(
358459 "evaluation/total_loss" : total_loss ,
359460 "evaluation/total_weights" : total_weights ,
360461 "evaluation/moe_lb_loss" : moe_lb_loss ,
462+ "evaluation/indexer_loss" : indexer_loss ,
361463 "evaluation/mtp_loss" : mtp_loss ,
362464 "evaluation/mtp_acceptance_rate_percent" : mtp_acceptance_rate ,
363465 },
@@ -415,8 +517,8 @@ def _create_and_shard_optimizer(model: nnx.Module, config, mesh: Mesh):
   _, opt_state = nnx.split(optimizer)

   @functools.partial(jax.jit, out_shardings=(model_shardings, opt_shardings))
-  def shard_states(ms, os):
-    return ms, os
+  def shard_states(mshard, oshard):
+    return mshard, oshard

   with mesh:
     model_state, opt_state = shard_states(model_state, opt_state)
@@ -608,7 +710,9 @@ def train_loop(config, recorder, state=None):
   shaped_batch = maxtext_utils.get_shaped_batch(config)
   init_rng = jax.random.PRNGKey(config.init_weights_seed)
   example_rng = jax.jit(jax.random.fold_in)(init_rng, 0)
-  if config.compiled_trainstep_file == "":
+  # Need to implement the function below on NNX
+  # maxtext_utils.maybe_dump_jaxpr(config, p_train_step, (model_state, opt_state, shaped_batch, example_rng))
+  if config.compiled_trainstep_file == "":  # compile only when there is no pre-compiled file loaded
     compiled = p_train_step.lower(model_state, opt_state, shaped_batch, example_rng).compile()
     compiled_stats = compiled.memory_analysis()
     max_utils.print_compiled_memory_stats(compiled_stats)
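The pre-compilation above uses JAX's ahead-of-time lowering API. A minimal illustration with a toy jitted function (placeholder names, not the MaxText code; `memory_analysis()` may return `None` on some backends):

```python
# Toy illustration of the lower/compile/memory_analysis idiom used above.
import jax
import jax.numpy as jnp

@jax.jit
def toy_step(x):
  return (x * 2.0).sum()

lowered = toy_step.lower(jnp.ones((8, 128)))  # trace and lower without executing
compiled = lowered.compile()                  # XLA compilation happens here
print(compiled.memory_analysis())             # compiler-reported memory estimates
```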
@@ -624,14 +728,14 @@ def train_loop(config, recorder, state=None):
   _job_completed_gracefully = False
   try:
     last_step_completion = datetime.datetime.now()
+    max_logging.info(f"Entering train loop from start_step={start_step}")

     for step in np.arange(start_step, config.steps):
       prof.maybe_activate_profiler(step, opt_state)

       with jax.profiler.StepTraceAnnotation("train", step_num=step):
         example_batch = data_loader.load_next_batch(rampup_manager=rampup_manager)
         nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
-
         with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
           with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
             (model_state, opt_state), metrics = p_train_step(model_state, opt_state, example_batch, nextrng)
@@ -649,15 +753,18 @@ def train_loop(config, recorder, state=None):
           and (step + 1) % config.eval_interval == 0
       ):
         assert eval_data_iterator
+        # Explicitly reset the eval iterator and counters before starting the eval loop
         eval_data_iterator.reset()
         metric_logger.reset_eval_metrics()
+
         eval_step_count = 0
         for eval_batch in eval_data_iterator:
           if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
             break
           with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
             eval_metrics = p_eval_step(model_state, eval_batch, nextrng)
           metric_logger.record_eval_metrics(step, metrics=eval_metrics)
+          max_logging.log(f"Completed eval step {eval_step_count}")
           eval_step_count += 1

         metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count)
@@ -678,6 +785,7 @@ def train_loop(config, recorder, state=None):
         checkpoint_manager, model_state, opt_state, config, data_iterator, step=int(config.steps - 1)
     )
     if checkpoint_manager is not None:
+      # in case the last checkpoint_period checkpoint is still in progress
       checkpoint_manager.wait_until_finished()

     _job_completed_gracefully = True
@@ -727,8 +835,10 @@ def initialize(argv: Sequence[str]):
   if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"):
     vertex_tensorboard_manager.configure_vertex_tensorboard(config)

+  # Create the Goodput recorder
   recorder = create_goodput_recorder(config)

+  # Stack traces configurations
   debug_config = debug_configuration.DebugConfig(
       stack_trace_config=stack_trace_configuration.StackTraceConfig(
           collect_stack_trace=config.collect_stack_trace,
@@ -741,13 +851,20 @@ def initialize(argv: Sequence[str]):


 def run(config, recorder, diagnostic_config):
-  """Run the NNX training job."""
+  """Run the NNX training job.
+
+  In decoupled mode (DECOUPLE_GCLOUD=TRUE) cloud diagnostics may be stubbed; if so, skip wrapping.
+  """
+  # Use nullcontext when diagnostics are stubbed or in decoupled mode
   diagnostics_context = (
       contextlib.nullcontext()
       if is_decoupled() or getattr(diagnostic, "__class__", None).__name__ == "_StubDiag"
       else diagnostic.diagnose(diagnostic_config)
   )

+  if is_decoupled() or getattr(diagnostic, "__class__", None).__name__ == "_StubDiag":
+    max_logging.log("[DECOUPLED NO-OP] skipping cloud diagnostics wrapper.")
+
   with (
       diagnostics_context,
       max_utils.maybe_get_transformer_engine_context(config),