Skip to content

Adds initial Keras Orbax checkpointer V2 implementation.#74

Open
copybara-service[bot] wants to merge 1 commit intomainfrom
test_774946771
Open

Adds initial Keras Orbax checkpointer V2 implementation.#74
copybara-service[bot] wants to merge 1 commit intomainfrom
test_774946771

Conversation

@copybara-service
Copy link
Copy Markdown

Adds initial Keras Orbax checkpointer V2 implementation.

First step in creating a memory efficient Keras + Jax checkpointer that uses nested PyTrees instead of flat tuples to enable model surgery.

  • Checkpoints the serialized model config as metadata.
  • Upgrades the checkpointing logic to the new Orbax API.
  • Writes checkpoints as a dict instead of a tuple.
  • Removes unnecessary expensive jax_state_sync calls.

Reverts changelist 793734230

First step in creating a memory efficient Keras + Jax checkpointer that uses nested PyTrees instead of flat tuples to enable model surgery.

- Checkpoints the serialized model config as metadata.
- Upgrades the checkpointing logic to the new Orbax API.
- Writes checkpoints as a dict instead of a tuple.
- Removes unnecessary expensive `jax_state_sync` calls.

Reverts changelist 793734230

PiperOrigin-RevId: 774946771
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant