 Architecture
 
 ┌─────────────────────────────────┬──────────────────────────────────────────────────────────────────────────┐
-│ Layer │ What it does │
+│ Function │ What it does │
 ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
-│ loss_fn / eval_loss_fn │ Forward-pass + cross-entropy; called directly on an nnx.Module │
+│ loss_fn │ Forward-pass + cross-entropy for both train and eval; │
+│ │ called directly on an nnx.Module │
 ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
 │ train_step │ Functional step — merges (graphdef, opt_state) → runs nnx.value_and_grad │
 │ │ → updates optimizer → returns new nnx.State │
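
The train_step row above describes a purely functional update. As a reading aid (not part of this diff; signatures are simplified assumptions, `loss_fn` is the function defined later in this file, and `config.learning_rate` is a stand-in for the real optimizer), a minimal sketch of that pattern with Flax NNX looks roughly like this:

```python
# Illustrative sketch only; not the actual MaxText train_step.
import jax
from flax import nnx


def sketch_train_step(model_graphdef, model_state, config, data, dropout_rng):
  # Rebuild a live module from its static graph definition and current state.
  model = nnx.merge(model_graphdef, model_state)

  # Differentiate loss_fn with respect to the module (argument 0).
  grad_fn = nnx.value_and_grad(loss_fn, argnums=0, has_aux=True)
  (loss, aux), grads = grad_fn(model, config, data, dropout_rng, is_train=True)

  # Plain SGD update for illustration; the real code applies an optax optimizer.
  params = nnx.state(model, nnx.Param)
  new_params = jax.tree.map(lambda p, g: p - config.learning_rate * g, params, grads)
  nnx.update(model, new_params)

  # Split back into (graphdef, state) so the caller keeps a functional training loop.
  _, new_state = nnx.split(model)
  return new_state, loss, aux
```
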
 from jax.sharding import Mesh
 
 from maxtext.common import checkpointing, profiler
-from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode
+from maxtext.common.common_types import ShardMode
 from maxtext.common.data_loader import create_dataloader
 from maxtext.common.gcloud_stub import cloud_diagnostics as _cloud_diag
 from maxtext.common.gcloud_stub import is_decoupled, vertex_tensorboard_modules
 from maxtext.common.metric_logger import MetricLogger
 from maxtext.configs import pyconfig
 from maxtext.input_pipeline.input_pipeline_interface import create_data_iterator
+from maxtext.layers.multi_token_prediction import calculate_mtp_acceptance_rate, calculate_mtp_loss
 from maxtext.optimizers import optimizers
 from maxtext.utils import exceptions, max_logging, max_utils, maxtext_utils, model_creation_utils, sharding
 from maxtext.utils.globals import EPS
 
 
 # ---------------------------------------------------------------------------
-# Loss computation
+# Loss computation for both train and eval
 # ---------------------------------------------------------------------------
 
 
-def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng: jax.Array):
+def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng: jax.Array, is_train=True):
   """Compute cross-entropy loss for one batch using an NNX model.
 
   Args:
@@ -121,58 +123,117 @@ def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng:
     data: Batch dict with keys "inputs", "inputs_position", "inputs_segmentation",
       "targets", "targets_segmentation".
     dropout_rng: PRNG key used to seed dropout layers.
+    is_train: True for train_step and False for eval_step.
 
   Returns:
     (loss, aux) where loss is a scalar and aux is a dict of auxiliary metrics.
   """
   rng1, aqt_rng = jax.random.split(dropout_rng)
 
   # Trim to micro-batch size (handles per_device_batch_size < 1 cases)
-  batch = {k: v[: config.micro_batch_size_to_train_on, :] for k, v in data.items()}
+  # Use the train or eval micro-batch size depending on the caller.
+  if is_train:
+    batch = {k: v[: config.micro_batch_size_to_train_on, :] for k, v in data.items()}
+  else:
+    batch = {k: v[: config.micro_batch_size_to_eval_on, :] for k, v in data.items()}
 
+  # Forward pass through the Flax NNX model.
   logits = model(
       decoder_input_tokens=batch["inputs"],
       decoder_positions=batch["inputs_position"],
       decoder_segment_ids=batch["inputs_segmentation"],
-      enable_dropout=config.enable_dropout,
+      encoder_images=batch["images"] if config.use_multimodal else None,
+      encoder_image_masks=batch["image_masks"] if config.use_multimodal and "image_masks" in batch else None,
+      enable_dropout=config.enable_dropout if is_train else False,
+      decoder_target_tokens=batch["targets"],
+      decoder_target_mask=batch["targets_segmentation"],
   )
-
+  intermediate_outputs = {}  # container for auxiliary outputs; the MTP and MoE terms below are read from here
   one_hot_targets = jax.nn.one_hot(batch["targets"], config.vocab_size)
   xent, z_loss = max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=config.z_loss_multiplier)
 
-  # Zero out padding positions
-  target_mask = batch["targets_segmentation"] != 0
-  xent = xent * target_mask
-  z_loss = z_loss * target_mask
+  xent = nn.with_logical_constraint(xent, ("activation_embed_and_logits_batch", "activation_length"))
+  z_loss = nn.with_logical_constraint(z_loss, ("activation_embed_and_logits_batch", "activation_length"))
 
-  total_loss = jnp.sum(xent)
-  total_weights = jnp.sum(target_mask)
-  total_z_loss = jnp.sum(z_loss) / (total_weights + EPS)
+  # Mask out padding at the end of each example.
+  xent = xent * (batch["targets_segmentation"] != 0)
+  z_loss = z_loss * (batch["targets_segmentation"] != 0)
 
-  loss = total_loss / (total_weights + EPS)
+  total_loss = jnp.sum(xent)
+  total_z_loss = jnp.sum(z_loss)
+
+  total_weights = jnp.sum(batch["targets_segmentation"] != 0)
+  # If gradient accumulation is enabled, we don't need to divide total_loss
+  # by total_weights and then multiply the computed gradient by total_weights,
+  # since that is equivalent to computing the gradient from total_loss directly.
+  # This simplification reduces the number of operations and makes it easier
+  # for XLA to move the all-reduce out of the gradient accumulation loop when
+  # using Zero-1 + GA, which reduces communication overhead.
+  # EPS is only needed to avoid division by zero, so it can be dropped on this
+  # path since there is no division. (A numeric sketch of this equivalence follows after this hunk.)
+  if config.gradient_accumulation_steps > 1 and not config.use_tunix_gradient_accumulation:
+    loss = total_loss
+  else:
+    # When using Tunix gradient accumulation, we revert to standard normalization.
+    # Unlike the manual accumulation path above, Tunix (via optax.MultiSteps) expects
+    # a normalized loss for each step; it handles the accumulation state
+    # updates and scaling internally.
+    loss = total_loss / (total_weights + EPS)
+
+  # The z-loss is kept normalized by total_weights.
+  total_z_loss = total_z_loss / (total_weights + EPS)
+
+  # Calculate and add the MTP (multi-token prediction) loss.
+  mtp_loss = 0.0
+  if config.mtp_num_layers > 0 and is_train:
+    mtp_loss = calculate_mtp_loss(intermediate_outputs, config)
+    loss += mtp_loss
+
+  # Get the MoE load-balance loss.
+  moe_lb_loss = 0.0
+  if config.num_experts > 1:
+    # Note: the key depends on the model implementation.
+    possible_keys = [
+        ("intermediates", "decoder", "layers", "moe_lb_loss"),
+        ("intermediates", "decoder", "moe_layers", "moe_lb_loss"),
+    ]
+
+    total_moe_lb_loss = 0.0
+    found_loss = False
+    for nested_key in possible_keys:
+      total_moe_lb_loss = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, 0.0)
+      if total_moe_lb_loss != 0.0:
+        found_loss = True
+        break
+
+    if not found_loss:
+      max_logging.debug("\nNo MoE load balance loss found. Defaulting to 0.0.")
+
+    moe_lb_loss = jnp.mean(jnp.array(total_moe_lb_loss))
+    loss += moe_lb_loss
+
+  # Get the MoE routed-bias term updates.
+  moe_bias_updates = None
+  if config.routed_bias and config.routed_bias_update_rate > 0.0:
+    nested_key = ("intermediates", "decoder", "moe_layers", "moe_bias_updates")
+    moe_bias_updates = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, None)
+
+  # Add the model's primary output to the intermediates dict so it can be used
+  # by the acceptance-rate calculation in eval_step.
+  intermediate_outputs["logits"] = logits
 
   aux = {
+      "intermediate_outputs": intermediate_outputs,
       "total_loss": total_loss,
       "z_loss": total_z_loss,
       "total_weights": total_weights,
+      "moe_lb_loss": moe_lb_loss,
+      "moe_bias_updates": moe_bias_updates,
+      "mtp_loss": mtp_loss,
   }
   return loss, aux
 
-
-def eval_loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng: jax.Array):
-  """Evaluation variant of loss_fn (no dropout, full batch size)."""
-  batch = {k: v[: config.micro_batch_size_to_eval_on, :] for k, v in data.items()}
-
-  logits = model(
-      decoder_input_tokens=batch["inputs"],
-      decoder_positions=batch["inputs_position"],
-      decoder_segment_ids=batch["inputs_segmentation"],
-      enable_dropout=False,
-  )
-
-  one_hot_targets = jax.nn.one_hot(batch["targets"], config.vocab_size)
-  xent, z_loss = max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=config.z_loss_multiplier)
-
+  # Zero out padding positions
   target_mask = batch["targets_segmentation"] != 0
   xent = xent * target_mask
   z_loss = z_loss * target_mask
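
The gradient-accumulation comment in loss_fn above argues that summing unnormalized per-micro-batch losses and dividing the accumulated gradient once by the total token weight is equivalent to differentiating the normalized loss directly. A small self-contained numeric check of that equivalence (toy loss and shapes only, not MaxText code):

```python
import jax
import jax.numpy as jnp

def per_token_loss(w, x, y):
  # Stand-in for the masked cross-entropy above: per-token squared error.
  return (x * w - y) ** 2

def normalized_loss(w, xs, ys, mask):
  return jnp.sum(per_token_loss(w, xs, ys) * mask) / jnp.sum(mask)

def accumulated_grad(w, xs, ys, mask, num_micro):
  # Sum unnormalized losses (and weights) per micro-batch, divide the gradient once.
  xs_m, ys_m, mask_m = (a.reshape(num_micro, -1) for a in (xs, ys, mask))
  total_grad = jnp.zeros_like(w)
  total_weights = 0.0
  for i in range(num_micro):
    g = jax.grad(lambda w: jnp.sum(per_token_loss(w, xs_m[i], ys_m[i]) * mask_m[i]))(w)
    total_grad += g
    total_weights += jnp.sum(mask_m[i])
  return total_grad / total_weights

xs = jax.random.normal(jax.random.PRNGKey(0), (8,))
ys = jax.random.normal(jax.random.PRNGKey(1), (8,))
mask = jnp.array([1, 1, 1, 0, 1, 1, 1, 0], dtype=jnp.float32)
w = jnp.float32(0.3)

g_direct = jax.grad(normalized_loss)(w, xs, ys, mask)
g_accum = accumulated_grad(w, xs, ys, mask, num_micro=2)
print(jnp.allclose(g_direct, g_accum))  # True (up to float tolerance)
```
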
@@ -227,7 +288,7 @@ def train_step(
   # nnx.value_and_grad differentiates only through nnx.Param variables,
   # keeping non-differentiable state (RNGs, cache, etc.) frozen.
   grad_fn = nnx.value_and_grad(loss_fn, argnums=0, has_aux=True)
-  (loss, aux), raw_grads = grad_fn(model, config, data, dropout_rng)
+  (loss, aux), raw_grads = grad_fn(model, config, data, dropout_rng, is_train=True)
 
   # Cast gradients to configured dtype before clipping / accumulation
   raw_grads = jax.tree.map(
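
The hunk above is cut off inside the jax.tree.map call. For orientation only, a minimal sketch of what a dtype cast over a gradient pytree typically looks like; the grad_dtype parameter is an assumed name, not necessarily the actual MaxText config field:

```python
import jax
import jax.numpy as jnp

def cast_grads(raw_grads, grad_dtype=jnp.float32):
  # Cast every floating-point leaf of the gradient pytree to the target dtype,
  # leaving non-float leaves untouched.
  return jax.tree.map(
      lambda g: g.astype(grad_dtype) if jnp.issubdtype(g.dtype, jnp.floating) else g,
      raw_grads,
  )
```
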
@@ -280,16 +341,31 @@ def eval_step(
     metrics: Dict of scalar evaluation metrics.
   """
   model: nnx.Module = nnx.merge(model_graphdef, model_state)
-  loss, aux = eval_loss_fn(model, config, data, dropout_rng)
+  loss, aux = loss_fn(model, config, data, dropout_rng, is_train=False)
+
+  mtp_acceptance_rate = 0.0
+  if config.mtp_eval_target_module > 0:
+    mtp_acceptance_rate = calculate_mtp_acceptance_rate(aux["intermediate_outputs"], config)
 
+  total_loss = aux["total_loss"]
+  z_loss = aux["z_loss"]
+  total_weights = aux["total_weights"]
+  moe_lb_loss = aux["moe_lb_loss"]
+  mtp_loss = aux["mtp_loss"]
   metrics = {
       "scalar": {
           "evaluation/loss": loss,
-          "evaluation/z_loss": aux["z_loss"],
-          "evaluation/total_loss": aux["total_loss"],
-          "evaluation/total_weights": aux["total_weights"],
-      }
+          "evaluation/z_loss": z_loss,
+          "evaluation/total_loss": total_loss,
+          "evaluation/total_weights": total_weights,
+          "evaluation/moe_lb_loss": moe_lb_loss,
+          "evaluation/mtp_loss": mtp_loss,
+          "evaluation/mtp_acceptance_rate_percent": mtp_acceptance_rate,
+      },
   }
+  # if config.use_dpo:
+  #   metrics["scalar"]["evaluation/dpo_reward_accuracy"] = aux["reward_accuracy"]
+
   return metrics
 
 
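
Not part of this diff: a sketch of how a caller might fold the per-batch metrics emitted by eval_step into token-weighted averages. The helper name, loop structure, and eval_step signature used here are hypothetical; MaxText's actual eval loop lives elsewhere.

```python
import jax.numpy as jnp

def run_eval(eval_batches, eval_step_fn, model_graphdef, model_state, config, rng):
  # Weighted aggregation of per-batch metrics: losses are summed together with
  # their token counts so the final average is over tokens, not over batches.
  cumulative = {"total_loss": 0.0, "total_weights": 0.0}
  for batch in eval_batches:
    metrics = eval_step_fn(model_graphdef, model_state, config, batch, rng)
    cumulative["total_loss"] += metrics["scalar"]["evaluation/total_loss"]
    cumulative["total_weights"] += metrics["scalar"]["evaluation/total_weights"]
  avg_loss = cumulative["total_loss"] / jnp.maximum(cumulative["total_weights"], 1.0)
  return {"evaluation/avg_loss": avg_loss}
```
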