Commit 86e20ea
committed
NNX post-train fixes: unpack MultimodalInput for NNX decoder; support scalar LR in adam_pax
- models.py: NNX Transformer was passing `multimodal_input=MultimodalInput(...)` to
NNXDecoder, which expects individual keyword args (image_embeddings, image_masks,
audio_embeddings, audio_masks, bidirectional_mask). Unpack the object at the call site.
- optimizers.py: adam_pax called `learning_rate_fn(count)` unconditionally, failing when
`optax.inject_hyperparams` passes a pre-evaluated scalar instead of a callable schedule.
Add `callable()` guard to handle both cases.1 parent 9895925 commit 86e20ea
2 files changed
Lines changed: 8 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
517 | 517 | | |
518 | 518 | | |
519 | 519 | | |
520 | | - | |
| 520 | + | |
| 521 | + | |
| 522 | + | |
| 523 | + | |
| 524 | + | |
521 | 525 | | |
522 | 526 | | |
523 | 527 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
336 | 336 | | |
337 | 337 | | |
338 | 338 | | |
339 | | - | |
| 339 | + | |
| 340 | + | |
| 341 | + | |
340 | 342 | | |
341 | 343 | | |
342 | 344 | | |
| |||
0 commit comments