
Commit 86e20ea

NNX post-train fixes: unpack MultimodalInput for NNX decoder; support scalar LR in adam_pax
- models.py: the NNX Transformer was passing `multimodal_input=MultimodalInput(...)` to NNXDecoder, which expects the individual keyword arguments (image_embeddings, image_masks, audio_embeddings, audio_masks, bidirectional_mask). Unpack the object at the call site.
- optimizers.py: adam_pax called `learning_rate_fn(count)` unconditionally, which fails when `optax.inject_hyperparams` passes a pre-evaluated scalar instead of a callable schedule. Add a `callable()` guard to handle both cases.
1 parent 9895925 commit 86e20ea

2 files changed

Lines changed: 8 additions & 2 deletions


src/maxtext/models/models.py

Lines changed: 5 additions & 1 deletion
@@ -517,7 +517,11 @@ def __call__(
         previous_chunk=previous_chunk,
         slot=slot,
         page_state=page_state,
-        multimodal_input=multimodal_input,
+        image_embeddings=multimodal_input.image_embeddings if multimodal_input is not None else None,
+        image_masks=multimodal_input.image_masks if multimodal_input is not None else None,
+        audio_embeddings=multimodal_input.audio_embeddings if multimodal_input is not None else None,
+        audio_masks=multimodal_input.audio_masks if multimodal_input is not None else None,
+        bidirectional_mask=multimodal_input.bidirectional_mask if multimodal_input is not None else None,
         kv_caches=kv_caches,
         attention_metadata=attention_metadata,
         deepstack_visual_embeds=deepstack_visual_embeds,

src/maxtext/optimizers/optimizers.py

Lines changed: 3 additions & 1 deletion
@@ -336,7 +336,9 @@ def _update_momentum(update, mu, nu):
     else:
       updates = jax.tree_util.tree_map(lambda x, v: x + weight_decay * v, updates, params)

-    step_size = -1.0 * learning_rate_fn(count)
+    # learning_rate_fn may be a callable schedule or a scalar (e.g. when wrapped
+    # by optax.inject_hyperparams, it is passed as a pre-evaluated scalar).
+    step_size = -1.0 * (learning_rate_fn(count) if callable(learning_rate_fn) else learning_rate_fn)
     # Finally, fold in step size.
     updates = jax.tree_util.tree_map(lambda x: step_size * x, updates)

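The `callable()` guard can be exercised on its own. A minimal sketch, outside MaxText and optax: `step_size_for` and `schedule` are hypothetical names here, used only to show that the same code path handles both a schedule (a function of the step count) and a pre-evaluated scalar, which is what `optax.inject_hyperparams` substitutes in.

```python
def step_size_for(learning_rate_fn, count):
    # Mirrors the adam_pax fix: evaluate a callable schedule at the current
    # step, or use a plain scalar learning rate as-is.
    lr = learning_rate_fn(count) if callable(learning_rate_fn) else learning_rate_fn
    return -1.0 * lr

# A toy linear-decay schedule (hypothetical, for illustration only).
def schedule(count):
    return 1e-3 * (1.0 - count / 100.0)

step_size_for(schedule, 50)  # schedule case: evaluated at the step
step_size_for(5e-4, 50)      # scalar case: used directly
```

Without the guard, the scalar case raises `TypeError: 'float' object is not callable`, which is exactly the failure mode the commit fixes.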
