
Commit 0a9b2fa

Charles Li committed
Combine loss_fn for both train and eval
1 parent 7561720 commit 0a9b2fa

1 file changed

Lines changed: 113 additions & 37 deletions

File tree

src/maxtext/trainers/pre_train/nnx_train.py

@@ -22,9 +22,10 @@
 Architecture
 
 ┌─────────────────────────────────┬──────────────────────────────────────────────────────────────────────────┐
-│ Layer                           │ What it does                                                             │
+│ function                        │ What it does                                                             │
 ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
-│ loss_fn / eval_loss_fn          │ Forward-pass + cross-entropy; called directly on an nnx.Module           │
+│ loss_fn                         │ Forward-pass + cross-entropy; for both train and eval;                   │
+│                                 │ called directly on an nnx.Module                                         │
 ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
 │ train_step                      │ Functional step — merges (graphdef, opt_state) → runs nnx.value_and_grad │
 │                                 │ → updates optimizer → returns new nnx.State                              │
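To make the new contract in the table concrete — one loss_fn serving both training and evaluation, called directly on an nnx.Module — here is a minimal, self-contained sketch of that calling pattern. The model, batch keys, and loss are toy stand-ins, not the MaxText API; only the is_train gating mirrors the change in this commit.

# Toy sketch: one loss function shared by the train and eval paths, called
# directly on an nnx.Module; `is_train` gates dropout (hypothetical names).
import jax.numpy as jnp
from flax import nnx


class TinyModel(nnx.Module):

  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(8, 4, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)

  def __call__(self, x, *, deterministic: bool):
    return self.linear(self.dropout(x, deterministic=deterministic))


def loss_fn(model: nnx.Module, batch, is_train: bool = True):
  # Dropout runs only on the training path, mirroring
  # `enable_dropout=config.enable_dropout if is_train else False` later in the diff.
  preds = model(batch["inputs"], deterministic=not is_train)
  loss = jnp.mean((preds - batch["targets"]) ** 2)
  return loss, {"total_loss": loss}


model = TinyModel(nnx.Rngs(0))
batch = {"inputs": jnp.ones((2, 8)), "targets": jnp.zeros((2, 4))}
train_loss, _ = loss_fn(model, batch, is_train=True)   # dropout enabled
eval_loss, _ = loss_fn(model, batch, is_train=False)   # dropout disabled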
@@ -80,7 +81,7 @@
 from jax.sharding import Mesh
 
 from maxtext.common import checkpointing, profiler
-from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode
+from maxtext.common.common_types import ShardMode
 from maxtext.common.data_loader import create_dataloader
 from maxtext.common.gcloud_stub import cloud_diagnostics as _cloud_diag
 from maxtext.common.gcloud_stub import is_decoupled, vertex_tensorboard_modules
@@ -96,6 +97,7 @@
 from maxtext.common.metric_logger import MetricLogger
 from maxtext.configs import pyconfig
 from maxtext.input_pipeline.input_pipeline_interface import create_data_iterator
+from maxtext.layers.multi_token_prediction import calculate_mtp_acceptance_rate, calculate_mtp_loss
 from maxtext.optimizers import optimizers
 from maxtext.utils import exceptions, max_logging, max_utils, maxtext_utils, model_creation_utils, sharding
 from maxtext.utils.globals import EPS
@@ -107,11 +109,11 @@
 
 
 # ---------------------------------------------------------------------------
-# Loss computation
+# Loss computation for both train and eval
 # ---------------------------------------------------------------------------
 
 
-def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng: jax.Array):
+def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng: jax.Array, is_train=True):
   """Compute cross-entropy loss for one batch using an NNX model.
 
   Args:
@@ -121,58 +123,117 @@ def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng:
     data: Batch dict with keys "inputs", "inputs_position", "inputs_segmentation",
       "targets", "targets_segmentation".
     dropout_rng: PRNG key used to seed dropout layers.
+    is_train: True for train_step and False for eval_step.
 
   Returns:
     (loss, aux) where loss is a scalar and aux is a dict of auxiliary metrics.
   """
   rng1, aqt_rng = jax.random.split(dropout_rng)
 
   # Trim to micro-batch size (handles per_device_batch_size < 1 cases)
-  batch = {k: v[: config.micro_batch_size_to_train_on, :] for k, v in data.items()}
+  # decimate proportion of data when per_device_batch_size < 1
+  if is_train:
+    batch = {k: v[: config.micro_batch_size_to_train_on, :] for k, v in data.items()}
+  else:
+    batch = {k: v[: config.micro_batch_size_to_eval_on, :] for k, v in data.items()}
 
+  # Flax NNX model
   logits = model(
       decoder_input_tokens=batch["inputs"],
       decoder_positions=batch["inputs_position"],
       decoder_segment_ids=batch["inputs_segmentation"],
-      enable_dropout=config.enable_dropout,
+      encoder_images=batch["images"] if config.use_multimodal else None,
+      encoder_image_masks=batch["image_masks"] if config.use_multimodal and "image_masks" in batch else None,
+      enable_dropout=config.enable_dropout if is_train else False,
+      decoder_target_tokens=batch["targets"],
+      decoder_target_mask=batch["targets_segmentation"],
   )
-
+  intermediate_outputs = {}
   one_hot_targets = jax.nn.one_hot(batch["targets"], config.vocab_size)
   xent, z_loss = max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=config.z_loss_multiplier)
 
-  # Zero out padding positions
-  target_mask = batch["targets_segmentation"] != 0
-  xent = xent * target_mask
-  z_loss = z_loss * target_mask
+  xent = nn.with_logical_constraint(xent, ("activation_embed_and_logits_batch", "activation_length"))
+  z_loss = nn.with_logical_constraint(z_loss, ("activation_embed_and_logits_batch", "activation_length"))
 
-  total_loss = jnp.sum(xent)
-  total_weights = jnp.sum(target_mask)
-  total_z_loss = jnp.sum(z_loss) / (total_weights + EPS)
+  # Mask out paddings at the end of each example.
+  xent = xent * (batch["targets_segmentation"] != 0)
+  z_loss = z_loss * (batch["targets_segmentation"] != 0)
 
-  loss = total_loss / (total_weights + EPS)
+  total_loss = jnp.sum(xent)
+  total_z_loss = jnp.sum(z_loss)
+
+  total_weights = jnp.sum(batch["targets_segmentation"] != 0)
+  # If gradient accumulation is enabled, we don't need to divide total_loss
+  # by total_weights and then multiply the computed gradient by total_weights,
+  # since it's equivalent to computing the gradient from total_loss.
+  # This simplification reduces the number of operations and makes it easier
+  # for XLA to move all-reduce out of the gradient accumulation loop when using
+  # Zero1+GA to reduce communication overhead.
+  # EPS was used to avoid division by zero, but it's not needed when gradient
+  # accumulation is enabled since there's no division.
+  if config.gradient_accumulation_steps > 1 and not config.use_tunix_gradient_accumulation:
+    loss = total_loss
+  else:
+    # When using Tunix gradient accumulation, we revert to standard normalization.
+    # Unlike the manual accumulation path above, Tunix (via optax.MultiSteps) expects
+    # a normalized loss for each step. It handles the accumulation state
+    # updates and scaling internally.
+    loss = total_loss / (total_weights + EPS)
+
+  # We keep z-loss normalized by total_weights.
+  total_z_loss = total_z_loss / (total_weights + EPS)
+
+  # Calculate and add MTP loss.
+  mtp_loss = 0.0
+  if config.mtp_num_layers > 0 and is_train:
+    mtp_loss = calculate_mtp_loss(intermediate_outputs, config)
+    loss += mtp_loss
+
+  # Get MoE load-balance loss.
+  moe_lb_loss = 0.0
+  if config.num_experts > 1:
+    # Note: the key is affected by the model implementation
+    possible_keys = [
+        ("intermediates", "decoder", "layers", "moe_lb_loss"),
+        ("intermediates", "decoder", "moe_layers", "moe_lb_loss"),
+    ]
+
+    total_moe_lb_loss = 0.0
+    found_loss = False
+    for nested_key in possible_keys:
+      total_moe_lb_loss = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, 0.0)
+      if total_moe_lb_loss != 0.0:
+        found_loss = True
+        break
+
+    if not found_loss:
+      max_logging.debug("\nNo MoE load balance loss found. Defaulting to 0.0.")
+
+    moe_lb_loss = jnp.mean(jnp.array(total_moe_lb_loss))
+    loss += moe_lb_loss
+
+  # Get MoE routed bias term updates.
+  moe_bias_updates = None
+  if config.routed_bias and config.routed_bias_update_rate > 0.0:
+    nested_key = ("intermediates", "decoder", "moe_layers", "moe_bias_updates")
+    moe_bias_updates = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, None)
+
+  # Add the model's primary output to the intermediates dict so it can be used
+  # by the acceptance rate calculation in eval_step.
+  intermediate_outputs["logits"] = logits
 
   aux = {
+      "intermediate_outputs": intermediate_outputs,
       "total_loss": total_loss,
       "z_loss": total_z_loss,
       "total_weights": total_weights,
+      "moe_lb_loss": moe_lb_loss,
+      "moe_bias_updates": moe_bias_updates,
+      "mtp_loss": mtp_loss,
   }
   return loss, aux
 
-
-def eval_loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng: jax.Array):
-  """Evaluation variant of loss_fn (no dropout, full batch size)."""
-  batch = {k: v[: config.micro_batch_size_to_eval_on, :] for k, v in data.items()}
-
-  logits = model(
-      decoder_input_tokens=batch["inputs"],
-      decoder_positions=batch["inputs_position"],
-      decoder_segment_ids=batch["inputs_segmentation"],
-      enable_dropout=False,
-  )
-
-  one_hot_targets = jax.nn.one_hot(batch["targets"], config.vocab_size)
-  xent, z_loss = max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=config.z_loss_multiplier)
-
+  # Zero out padding positions
   target_mask = batch["targets_segmentation"] != 0
   xent = xent * target_mask
   z_loss = z_loss * target_mask
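The gradient-accumulation comment in this hunk carries the key piece of reasoning: with manual accumulation you can skip the per-micro-batch division by total_weights, sum the unnormalized losses, and rescale once at the end, because the gradient of the sum is total_weights times the gradient of the mean. A small numeric check of that claim, using a toy quadratic loss with no padding weights (assumed names, not MaxText code):

# Toy check: accumulating gradients of the unnormalized sum-loss over
# micro-batches equals total_weights times the gradient of the normalized
# (mean) loss, so one rescaling at the end replaces a division per micro-batch.
import jax
import jax.numpy as jnp


def sum_loss(w, x, y):   # unnormalized, like `loss = total_loss`
  return jnp.sum((x @ w - y) ** 2)


def mean_loss(w, x, y):  # normalized, like `loss = total_loss / (total_weights + EPS)`
  return jnp.mean((x @ w - y) ** 2)


w = jax.random.normal(jax.random.PRNGKey(0), (4,))
x = jax.random.normal(jax.random.PRNGKey(1), (8, 4))
y = jax.random.normal(jax.random.PRNGKey(2), (8,))

# Two micro-batches of 4 examples; gradients of the sum-loss are accumulated.
g_accum = jax.grad(sum_loss)(w, x[:4], y[:4]) + jax.grad(sum_loss)(w, x[4:], y[4:])
g_mean = jax.grad(mean_loss)(w, x, y)

print(jnp.allclose(g_accum, 8.0 * g_mean))  # True; 8 plays the role of total_weights

The padding mask in the real loss_fn only changes what total_weights counts; the algebra is the same. The else branch keeps the per-step normalized loss because optax.MultiSteps-style accumulation (the Tunix path named in the comment) handles the accumulation state and scaling internally.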
@@ -227,7 +288,7 @@ def train_step(
   # nnx.value_and_grad differentiates only through nnx.Param variables,
   # keeping non-differentiable state (RNGs, cache, etc.) frozen.
   grad_fn = nnx.value_and_grad(loss_fn, argnums=0, has_aux=True)
-  (loss, aux), raw_grads = grad_fn(model, config, data, dropout_rng)
+  (loss, aux), raw_grads = grad_fn(model, config, data, dropout_rng, is_train=True)
 
   # Cast gradients to configured dtype before clipping / accumulation
   raw_grads = jax.tree.map(
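The train_step hunk above changes only the grad_fn call, but it sits inside the functional pattern the architecture table describes: merge (graphdef, state) back into a module, differentiate the nnx.Param leaves with nnx.value_and_grad, apply the update, and return a new nnx.State. A rough, self-contained sketch of that flow, with a toy model and a plain SGD update standing in for MaxText's config, optimizer, and sharding:

# Toy sketch of the functional train-step pattern (not MaxText code).
import jax
import jax.numpy as jnp
from flax import nnx


def loss_fn(model: nnx.Module, batch, is_train: bool = True):
  preds = model(batch["x"])
  loss = jnp.mean((preds - batch["y"]) ** 2)
  return loss, {"total_loss": loss}


@jax.jit
def train_step(graphdef, state, batch):
  model = nnx.merge(graphdef, state)  # rebuild the nnx.Module from its functional parts
  grad_fn = nnx.value_and_grad(loss_fn, argnums=0, has_aux=True)
  (loss, aux), grads = grad_fn(model, batch, is_train=True)  # grads cover nnx.Param leaves only
  params = nnx.state(model, nnx.Param)
  new_params = jax.tree.map(lambda p, g: p - 1e-2 * g, params, grads)  # stand-in for the optimizer update
  nnx.update(model, new_params)
  return nnx.state(model), loss


model = nnx.Linear(4, 1, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)
batch = {"x": jnp.ones((2, 4)), "y": jnp.zeros((2, 1))}
state, loss = train_step(graphdef, state, batch)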
@@ -280,16 +341,31 @@ def eval_step(
     metrics: Dict of scalar evaluation metrics.
   """
   model: nnx.Module = nnx.merge(model_graphdef, model_state)
-  loss, aux = eval_loss_fn(model, config, data, dropout_rng)
+  loss, aux = loss_fn(model, config, data, dropout_rng, is_train=False)
+
+  mtp_acceptance_rate = 0.0
+  if config.mtp_eval_target_module > 0:
+    mtp_acceptance_rate = calculate_mtp_acceptance_rate(aux["intermediate_outputs"], config)
 
+  total_loss = aux["total_loss"]
+  z_loss = aux["z_loss"]
+  total_weights = aux["total_weights"]
+  moe_lb_loss = aux["moe_lb_loss"]
+  mtp_loss = aux["mtp_loss"]
   metrics = {
       "scalar": {
           "evaluation/loss": loss,
-          "evaluation/z_loss": aux["z_loss"],
-          "evaluation/total_loss": aux["total_loss"],
-          "evaluation/total_weights": aux["total_weights"],
-      }
+          "evaluation/z_loss": z_loss,
+          "evaluation/total_loss": total_loss,
+          "evaluation/total_weights": total_weights,
+          "evaluation/moe_lb_loss": moe_lb_loss,
+          "evaluation/mtp_loss": mtp_loss,
+          "evaluation/mtp_acceptance_rate_percent": mtp_acceptance_rate,
+      },
   }
+  # if config.use_dpo:
+  #   metrics["scalar"]["evaluation/dpo_reward_accuracy"] = aux["reward_accuracy"]
+
   return metrics
 
 

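With the combined loss function, eval_step is now just the functional shell: merge the graphdef and state, call the shared loss_fn with is_train=False, and flatten the aux values into "evaluation/*" scalars. A rough, self-contained sketch of that shape, with a toy model and without the MTP/MoE/DPO metrics (assumed names, not the MaxText API):

# Toy sketch: eval reuses the shared loss_fn with is_train=False and
# repackages aux values into a flat "scalar" metrics dict.
import jax.numpy as jnp
from flax import nnx


def loss_fn(model: nnx.Module, batch, is_train: bool = True):
  preds = model(batch["x"])
  loss = jnp.mean((preds - batch["y"]) ** 2)
  return loss, {"total_loss": loss, "total_weights": jnp.asarray(batch["y"].size, jnp.float32)}


def eval_step(model_graphdef, model_state, batch):
  model = nnx.merge(model_graphdef, model_state)     # rebuild module from functional parts
  loss, aux = loss_fn(model, batch, is_train=False)  # same loss path as training, dropout off
  return {
      "scalar": {
          "evaluation/loss": loss,
          "evaluation/total_loss": aux["total_loss"],
          "evaluation/total_weights": aux["total_weights"],
      },
  }


model = nnx.Linear(4, 1, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)
batch = {"x": jnp.ones((2, 4)), "y": jnp.zeros((2, 1))}
metrics = eval_step(graphdef, state, batch)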