Skip to content

Commit bff6a28

Browse files
committed
Support NNX checkpoints in the logits check (workaround).
Note: the loaded params are still broadcast among devices.
1 parent 5478bad commit bff6a28

3 files changed

Lines changed: 113 additions & 3 deletions

File tree

src/maxtext/utils/maxtext_utils.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,6 +1087,58 @@ def setup_decode_state(model, config, rng, mesh, checkpoint_manager):
10871087
return state, state_mesh_annotations
10881088

10891089

1090+
def setup_decode_state_from_nnx(model, config, rng, mesh):
  """Setup decode state by loading params from an NNX checkpoint.

  The checkpoint is read with `load_orbax_checkpoint(config)`, NNX-specific
  wrappers are stripped, the arrays are pulled to host memory, and each
  parameter is then placed on device with the sharding recorded on the
  matching leaf of the abstract state.

  Args:
    model: the flax model to initialize
    config: config object
    rng: jax.prng key
    mesh: jax.devices() mesh

  Returns:
    state: state with decode params loaded from the NNX checkpoint
    state_mesh_annotations: the mesh annotations for the state
  """
  from maxtext.checkpoint_conversion.utils.utils import load_orbax_checkpoint  # pylint: disable=import-outside-toplevel

  def unwrap_nnx_values(tree):
    """Drops NNX rng entries and collapses `{'value': array}` wrappers into the bare array."""
    if not isinstance(tree, dict):
      return tree
    filtered_tree = {k: v for k, v in tree.items() if 'to_nnx__rngs' not in k}
    is_value_wrapper = (
        len(filtered_tree) == 1
        and 'value' in filtered_tree
        and isinstance(filtered_tree['value'], (jax.Array, np.ndarray))
    )
    if is_value_wrapper:
      return filtered_tree['value']
    return {k: unwrap_nnx_values(v) for k, v in filtered_tree.items()}

  def to_numpy(tree):
    """Converts every jax.Array leaf to a host numpy array; other leaves pass through."""

    def convert_leaf(x):
      if isinstance(x, jax.Array):
        return np.array(x)
      return x

    return jax.tree_util.tree_map(convert_leaf, tree)

  def reshard_to_match(target_tree, source_tree):
    """Places each source array on device using the sharding of the matching target leaf."""

    def copy_sharding(target_leaf, source_leaf):
      if not isinstance(source_leaf, (jax.Array, np.ndarray)):
        return source_leaf
      # Key on the presence of a sharding rather than on the leaf being a
      # concrete jax.Array: the target tree comes from the abstract state, whose
      # leaves are presumably jax.ShapeDtypeStruct (TODO confirm). Those expose
      # `.sharding` but are not jax.Array, so an isinstance(jax.Array) check
      # would skip the device_put and leave every param as an unsharded host array.
      sharding = getattr(target_leaf, 'sharding', None)
      if sharding is None:
        return source_leaf
      return jax.device_put(source_leaf, sharding)

    return jax.tree_util.tree_map(copy_sharding, target_tree, source_tree)

  unboxed_abstract_state, state_mesh_annotations, _ = get_abstract_state(model, None, config, rng, mesh, False)

  loaded_params = load_orbax_checkpoint(config)
  loaded_params = unwrap_nnx_values(loaded_params)
  # NNX-RL checkpoints may have an extra 'base' nesting level, while NNX-SFT may not
  loaded_params = {'params': loaded_params['base'] if 'base' in loaded_params else loaded_params}
  loaded_params = to_numpy(loaded_params)
  params = reshard_to_match(unboxed_abstract_state.params, loaded_params)

  state = init_decode_state(model.apply, params)
  return state, state_mesh_annotations
1140+
1141+
10901142
def setup_training_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager):
10911143
is_training = True
10921144
return setup_initial_state(

src/maxtext/utils/model_creation_utils.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,16 @@ def create_sharded_state():
183183
metadata = ckptr.metadata(config.load_parameters_path)
184184

185185
is_nnx_checkpoint = True
186+
has_base_key = False
187+
188+
# Helper function to enable padding/truncation for mesh mismatches
189+
def create_flexible_restore_args(tree):
  """Builds a restore-args tree that carries only each leaf's sharding.

  Any node exposing a `sharding` attribute is treated as a leaf and mapped to
  `ocp.type_handlers.ArrayRestoreArgs(sharding=...)`; every other leaf maps
  to None.
  """

  def has_sharding(node):
    return hasattr(node, "sharding")

  def to_restore_arg(node):
    if has_sharding(node):
      return ocp.type_handlers.ArrayRestoreArgs(sharding=node.sharding)
    return None

  return jax.tree.map(to_restore_arg, tree, is_leaf=has_sharding)
195+
186196
if (
187197
"params" in metadata.item_metadata.tree.keys()
188198
and "params" in metadata.item_metadata.tree.get("params", {}).keys()
@@ -196,7 +206,19 @@ def create_sharded_state():
196206
)
197207

198208
item_to_restore = {"params": {"params": target_for_restore}}
199-
restore_args = {"params": {"params": ocp.checkpoint_utils.construct_restore_args(target_for_restore)}}
209+
# restore_args = {"params": {"params": ocp.checkpoint_utils.construct_restore_args(target_for_restore)}}
210+
restore_args = {"params": {"params": create_flexible_restore_args(target_for_restore)}}
211+
elif "base" in metadata.item_metadata.tree.keys():
212+
# structure of nnx-rl checkpoint: {'base': {'decoder': {..., 'value': ...}}}
213+
has_base_key = True
214+
target_for_restore = jax.tree.map(
215+
lambda v: {"value": v.value},
216+
sharded_state,
217+
is_leaf=lambda n: isinstance(n, nnx.Variable),
218+
)
219+
item_to_restore = {"base": target_for_restore}
220+
# restore_args = {"base": ocp.checkpoint_utils.construct_restore_args(target_for_restore)}
221+
restore_args = {"base": create_flexible_restore_args(target_for_restore)}
200222
else:
201223
# structure of nnx checkpoint: {'decoder': {'value': ...}}
202224
target_for_restore = jax.tree.map(
@@ -205,7 +227,8 @@ def create_sharded_state():
205227
is_leaf=lambda n: isinstance(n, nnx.Variable),
206228
)
207229
item_to_restore = target_for_restore
208-
restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)
230+
# restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)
231+
restore_args = create_flexible_restore_args(target_for_restore)
209232

210233
restored = ckptr.restore(
211234
epath.Path(config.load_parameters_path),
@@ -215,6 +238,10 @@ def create_sharded_state():
215238
)
216239

217240
if is_nnx_checkpoint:
241+
# Unwrap 'base' key if present (NNX-RL format)
242+
if has_base_key:
243+
restored = restored.get("base", restored)
244+
218245
checkpoint = jax.tree.map(
219246
lambda v: v["value"],
220247
restored,

tests/utils/forward_pass_logit_checker.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,10 @@ def main(config, test_args): # pylint: disable=W0621
425425
mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)
426426
quant = quantizations.configure_quantization(config)
427427
maxtext_model = models.transformer_as_linen(config, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
428-
maxtext_state, _ = maxtext_utils.setup_decode_state(maxtext_model, config, rng1, mesh, None)
428+
if test_args.ckpt_type == "linen":
429+
maxtext_state, _ = maxtext_utils.setup_decode_state(maxtext_model, config, rng1, mesh, None)
430+
else:
431+
maxtext_state, _ = maxtext_utils.setup_decode_state_from_nnx(maxtext_model, config, rng1, mesh)
429432

430433
prompts = ["I love to", "Today is a", "What is the"]
431434
all_data_to_save = []
@@ -537,6 +540,34 @@ def main(config, test_args): # pylint: disable=W0621
537540
test_args, remaining_args = parser.parse_known_args()
538541
# Reconstruct model_args (script name + the args MaxText needs)
539542
model_args = [sys.argv[0]] + remaining_args
543+
parser.add_argument(
    "--ckpt_type",
    type=str,
    required=False,
    default="linen",
    choices=["linen", "nnx"],
    help="Checkpoint format to load: 'linen' (default) or 'nnx'.",
)
# Re-parse now that --ckpt_type is registered. parse_known_args returns every
# token argparse did not consume, so it already excludes all test-only flags
# together with their values, whether passed as `--flag=value` or `--flag value`.
# A hand-maintained prefix filter over sys.argv would leave the value token
# behind in the second form and would need updating for every new flag.
test_args, remaining_args = parser.parse_known_args()

# Reconstruct model_args (script name + the args MaxText needs)
model_args = [sys.argv[0]] + remaining_args
540571

541572
cfg = pyconfig.initialize(model_args)
542573
assert (

0 commit comments

Comments
 (0)