Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
609 changes: 609 additions & 0 deletions src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py

Large diffs are not rendered by default.

581 changes: 581 additions & 0 deletions src/maxtext/checkpoint_conversion/linen_nnx_converter.py

Large diffs are not rendered by default.

32 changes: 27 additions & 5 deletions src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from absl import flags
import datetime
from etils import epath
from flax import nnx
from flax.training import train_state
import jax
from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
Expand Down Expand Up @@ -536,7 +537,7 @@ def load_state_if_possible(
load_parameters_from_path: str,
load_full_state_from_path: str,
checkpoint_storage_concurrent_gb: int,
abstract_unboxed_pre_state: train_state.TrainState,
abstract_unboxed_pre_state: train_state.TrainState | nnx.State,
enable_single_replica_ckpt_restoring: bool | None = False,
dataset_type: str | None = "tfds",
step: int = -1, # -1 means latest
Expand Down Expand Up @@ -604,8 +605,13 @@ def map_to_pspec(data):
)
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)

restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state)
checkpoint_args = ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args)
# Convert nnx.State to pure dict to match how checkpoints are saved for NNX
restore_target = abstract_unboxed_pre_state
if isinstance(abstract_unboxed_pre_state, nnx.State):
restore_target = abstract_unboxed_pre_state.to_pure_dict()

restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target)
checkpoint_args = ocp.args.PyTreeRestore(item=restore_target, restore_args=restore_args)

match (checkpoint_manager, dataset_type, data_iterator):
# Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager
Expand Down Expand Up @@ -640,9 +646,14 @@ def map_to_pspec(data):
return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None)

if load_parameters_from_path != "":
if isinstance(abstract_unboxed_pre_state, nnx.State):
_, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...)
else:
params = abstract_unboxed_pre_state.params

restored_params = load_params_from_path(
load_parameters_from_path,
abstract_unboxed_pre_state.params,
params,
checkpoint_storage_concurrent_gb,
use_ocdbt=use_ocdbt,
use_zarr3=use_zarr3,
Expand Down Expand Up @@ -734,7 +745,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step
# Determine the effective step for saving a checkpoint.
# If 'step' is not provided, this call is for a potential final checkpoint
# and use the last completed step from the state.
actual_step = (int(state.step) - 1) if step is None else int(step)
if step is not None:
actual_step = int(step)
else:
if config.pure_nnx:
actual_step = int(state.optimizer.step) - 1
else:
# Linen TrainState has .step attribute
actual_step = int(state.step) - 1

if config.pure_nnx:
# Convert nnx.State to dict.
state = state.to_pure_dict()

# Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic.
# This occurs if this function was called:
Expand Down
8 changes: 8 additions & 0 deletions src/maxtext/layers/nnx_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
MODEL_MODE_TRAIN,
Config,
DecoderBlockType,
MultimodalInput,
ShardMode,
)
from maxtext.inference import page_manager
Expand Down Expand Up @@ -904,7 +905,14 @@ def __call__(
audio_embeddings: None | jnp.ndarray = None,
audio_masks: None | jnp.ndarray = None,
deepstack_visual_embeds: None | list[jnp.ndarray] = None,
multimodal_input: None | MultimodalInput = None,
):
if multimodal_input is not None:
image_embeddings = multimodal_input.image_embeddings
image_masks = multimodal_input.image_masks
audio_embeddings = multimodal_input.audio_embeddings
audio_masks = multimodal_input.audio_masks
bidirectional_mask = multimodal_input.bidirectional_mask
cfg = self.config
assert decoder_input_tokens.ndim == 2 # [batch, len]

Expand Down
Loading
Loading