diff --git a/src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py b/src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py new file mode 100644 index 0000000000..c103f234ee --- /dev/null +++ b/src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py @@ -0,0 +1,609 @@ +# Copyright 2023-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Compare checkpoint tree structures, shapes, and values. + +Supports comparing any combination of Linen and NNX checkpoints: +- Linen vs NNX (cross-format comparison) +- Linen vs Linen (same-format comparison) +- NNX vs NNX (same-format comparison) + +The script auto-detects the format of each checkpoint and applies the +appropriate normalization. Cross-format transformations (like layer axis +transposition) are only applied when comparing Linen vs NNX. + +Key differences between Linen and NNX checkpoints: +- Linen: params/params/decoder/layers/0/... (per-layer, double nested) +- NNX: model/decoder/layers/... 
(stacked layers, single nested, {value: array} wrappers) + +The script handles: +- Double 'params' nesting in Linen checkpoints +- 'model' key in NNX checkpoints (vs 'params' in Linen) +- {value: array} wrappers in NNX checkpoints +- Layer axis transposition (NNX stacks layers along axis 0, only for cross-format) +- RNG filtering (NNX has rngs, Linen doesn't) + +Usage: + # Compare Linen vs NNX (structure and shapes only) + python compare_linen_nnx_checkpoint.py \ + --ckpt_path_1="gs://bucket/linen_checkpoint/0/items" \ + --ckpt_path_2="gs://bucket/nnx_checkpoint/0/items" + + # Compare NNX vs NNX + python compare_linen_nnx_checkpoint.py \ + --ckpt_path_1="gs://bucket/nnx_checkpoint_a/0/items" \ + --ckpt_path_2="gs://bucket/nnx_checkpoint_b/0/items" + + # Compare Linen vs Linen + python compare_linen_nnx_checkpoint.py \ + --ckpt_path_1="gs://bucket/linen_checkpoint_a/0/items" \ + --ckpt_path_2="gs://bucket/linen_checkpoint_b/0/items" + + # Compare with value checking + python compare_linen_nnx_checkpoint.py \ + --ckpt_path_1="gs://bucket/checkpoint_a/0/items" \ + --ckpt_path_2="gs://bucket/checkpoint_b/0/items" \ + --compare_values --atol=1e-5 --rtol=1e-5 +""" + +import os +from typing import Any, Dict, Sequence + +# MUST set before importing JAX to force CPU-only mode +os.environ["JAX_PLATFORMS"] = "cpu" + +import jax +import jax.numpy as jnp +from jax.tree_util import tree_flatten_with_path, keystr, tree_structure, tree_map_with_path +import numpy as np +from etils import epath +import orbax.checkpoint as ocp +from absl import app +from absl import flags + +FLAGS = flags.FLAGS + +flags.DEFINE_string( + "ckpt_path_1", + None, + "Path to the first checkpoint items directory. Format is auto-detected.", + required=True, +) +flags.DEFINE_string( + "ckpt_path_2", + None, + "Path to the second checkpoint items directory. 
Format is auto-detected.", + required=True, +) +flags.DEFINE_boolean( + "verbose", + False, + "Print detailed per-parameter information.", +) +flags.DEFINE_boolean( + "transpose_nnx_layers", + False, + "Transpose NNX layer params from (layers, ...) to (...) for comparison. " + "NNX stacks layers along axis 0, while Linen stores per-layer params. " + "Only applied for cross-format (Linen vs NNX) comparisons.", +) +flags.DEFINE_string( + "compare_only", + "params", + "Which parts to compare: 'params' for params only, 'all' for full state.", +) +flags.DEFINE_boolean( + "ignore_rngs", + True, + "Ignore RNG-related paths in comparison (NNX has rngs, Linen doesn't).", +) +flags.DEFINE_boolean( + "compare_values", + False, + "Also compare parameter values (not just structure and shapes).", +) +flags.DEFINE_float( + "atol", + 1e-5, + "Absolute tolerance for value comparison.", +) +flags.DEFINE_float( + "rtol", + 1e-5, + "Relative tolerance for value comparison.", +) + + +def log(message: str) -> None: + """Log a message with prefix.""" + print(f"[compare_ckpt] {message}") + + +def is_rng_path(path: str) -> bool: + """Check if a path is RNG-related.""" + path_lower = path.lower() + return "rngs" in path_lower or "rng" in path_lower + + +def filter_rngs(tree: Dict[str, Any]) -> Dict[str, Any]: + """Filter out RNG-related keys from a tree.""" + if not isinstance(tree, dict): + return tree + + result = {} + for key, value in tree.items(): + # Skip RNG-related keys + if is_rng_path(key): + continue + # Recursively filter nested dicts + if isinstance(value, dict): + filtered = filter_rngs(value) + if filtered: # Only add if not empty after filtering + result[key] = filtered + else: + result[key] = value + return result + + +def detect_format(state: dict) -> str: + """Detects checkpoint format from state structure ('linen' or 'nnx'). + + Linen format: + - Top-level keys: ['params', 'opt_state', 'step'] + - params/params/decoder/... 
(double nested) + + NNX format: + - Top-level keys: ['model', 'optimizer'] (nnx.State style) + - model/decoder/... with {value: array} wrappers + """ + # Check for NNX nnx.State format (has 'model' key instead of 'params') + if "model" in state: + return "nnx" + + if "params" not in state: + raise ValueError(f"Checkpoint does not contain 'params' or 'model' key. Found keys: {list(state.keys())}") + + params = state["params"] + + # Check for Linen's double 'params' nesting + if isinstance(params, dict) and "params" in params: + inner = params["params"] + if isinstance(inner, dict) and ("decoder" in inner or "encoder" in inner): + return "linen" + + # Check for NNX's flat structure (params/decoder/...) + if isinstance(params, dict) and ("decoder" in params or "encoder" in params): + return "nnx" + + # Try to detect by looking for {value: array} wrappers (NNX style) + if _has_value_wrappers(params): + return "nnx" + + raise ValueError( + f"Could not detect checkpoint format. params keys: {list(params.keys()) if isinstance(params, dict) else type(params)}" + ) + + +def _has_value_wrappers(tree: Any) -> bool: + """Check if tree contains {value: array} wrappers (NNX style).""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, (np.ndarray, jnp.ndarray)): + return True + for v in tree.values(): + if _has_value_wrappers(v): + return True + return False + + +def _strip_value_wrappers(tree: Any) -> Any: + """Recursively strips {'value': array} wrappers from a tree.""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, (np.ndarray, jnp.ndarray)): + return inner + return {k: _strip_value_wrappers(v) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_strip_value_wrappers(item) for item in tree) + else: + return tree + + +def _normalize_linen_params(params: dict) -> 
dict: + """Normalize Linen params by removing double 'params' nesting.""" + if isinstance(params, dict) and "params" in params: + inner = params["params"] + if isinstance(inner, dict) and ("decoder" in inner or "encoder" in inner): + return inner + return params + + +def _normalize_nnx_params(params: dict) -> dict: + """Normalize NNX params by stripping {value: array} wrappers.""" + return _strip_value_wrappers(params) + + +def load_checkpoint(checkpoint_path: str, metadata_only: bool = False) -> dict: + """Loads checkpoint from local or GCS path. + + If metadata_only=True, returns a pytree of ArrayMetadata (shape/dtype only) + without downloading any tensor data. This is fast and sufficient for + structure/shape comparison. + """ + log(f"Loading checkpoint from: {checkpoint_path}") + if metadata_only: + log(" Mode: metadata only (no tensor data downloaded)") + + checkpoint_dir = epath.Path(checkpoint_path) + + # Create checkpointer and get metadata + ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler()) + + try: + metadata = ckptr.metadata(checkpoint_dir) + + if metadata_only: + tree = metadata.item_metadata.tree + log(f" Loaded metadata keys: {list(tree.keys())}") + return tree + + # Create a mesh with all available devices for unsharded restoration + devices = np.array(jax.devices()).reshape((-1,)) + single_device_mesh = jax.sharding.Mesh(devices, ("x",)) + unsharded = jax.sharding.NamedSharding(single_device_mesh, jax.sharding.PartitionSpec()) + + # Build restore args that restore arrays without original sharding + restore_args = jax.tree_util.tree_map( + lambda x: ocp.ArrayRestoreArgs(sharding=unsharded) if hasattr(x, "shape") else None, + metadata.item_metadata.tree, + is_leaf=lambda x: hasattr(x, "shape"), + ) + state = ckptr.restore(checkpoint_dir, restore_args=restore_args) + except Exception as e: # pylint: disable=broad-exception-caught + if metadata_only: + log(f" Metadata loading failed: {e}") + raise + # Fallback to simple restore without sharding 
args + log(f" Falling back to simple restore: {e}") + checkpointer = ocp.PyTreeCheckpointer() + state = checkpointer.restore(checkpoint_path) + + if state is None: + raise ValueError(f"Failed to restore checkpoint from {checkpoint_path}") + + log(f" Loaded keys: {list(state.keys())}") + return state + + +def transform_nnx_params_for_comparison(nnx_params: Dict[str, Any]) -> Dict[str, Any]: + """Transform NNX params to match Linen structure for comparison. + + NNX stacks layer parameters along axis 0 (shape: [num_layers, ...]), + while Linen stores per-layer parameters (shape: [...]). + + This function transposes layer params from (layers, d1, d2, ...) to (d1, layers, d2, ...) + to align with how Linen params would look if stacked. + """ + + def _transform(path, leaf: jax.Array) -> jax.Array: + key_str = keystr(path) + + # Only transform arrays in 'layers' with ndim >= 2 + if "layers" in key_str and hasattr(leaf, "ndim") and leaf.ndim >= 2: + # Transpose from (layers, d1, d2, ...) to (d1, layers, d2, ...) 
+ axes = (1, 0) + tuple(range(2, leaf.ndim)) + result = jnp.transpose(leaf, axes=axes) + if FLAGS.verbose: + log(f" TRANSPOSING: {key_str} shape {leaf.shape} -> {result.shape}") + return result + else: + return leaf + + log("Transforming NNX params (transposing layer dimensions)...") + return tree_map_with_path(_transform, nnx_params) + + +def get_tree_structure_info(tree: Dict[str, Any]) -> Dict[str, tuple]: + """Get structure info as dict of path -> (shape, dtype).""" + flat_with_path, _ = tree_flatten_with_path(tree) + return { + keystr(p): ( + getattr(leaf, "shape", "N/A"), + str(getattr(leaf, "dtype", type(leaf).__name__)), + ) + for p, leaf in flat_with_path + } + + +def print_structure_diff(params1: Dict, params2: Dict, name1: str = "Linen", name2: str = "NNX"): + """Print structural differences between two param trees.""" + info1 = get_tree_structure_info(params1) + info2 = get_tree_structure_info(params2) + keys1, keys2 = set(info1.keys()), set(info2.keys()) + + only_in_1 = sorted(keys1 - keys2) + only_in_2 = sorted(keys2 - keys1) + common = keys1 & keys2 + + if only_in_1: + print(f"\n--- Paths only in {name1} ({len(only_in_1)}) ---") + for k in only_in_1: + shape, dtype = info1[k] + print(f" - {k}: shape={shape}, dtype={dtype}") + + if only_in_2: + print(f"\n--- Paths only in {name2} ({len(only_in_2)}) ---") + for k in only_in_2: + shape, dtype = info2[k] + print(f" + {k}: shape={shape}, dtype={dtype}") + + # Check for shape/dtype mismatches in common paths + shape_mismatches = [] + dtype_mismatches = [] + for k in common: + shape1, dtype1 = info1[k] + shape2, dtype2 = info2[k] + if shape1 != shape2: + shape_mismatches.append((k, shape1, shape2)) + if dtype1 != dtype2: + dtype_mismatches.append((k, dtype1, dtype2)) + + if shape_mismatches: + print(f"\n--- Shape mismatches ({len(shape_mismatches)}) ---") + for k, s1, s2 in shape_mismatches: + print(f" {k}: {name1}={s1}, {name2}={s2}") + + if dtype_mismatches: + print(f"\n--- Dtype mismatches 
({len(dtype_mismatches)}) ---") + for k, d1, d2 in dtype_mismatches: + print(f" {k}: {name1}={d1}, {name2}={d2}") + + return only_in_1, only_in_2, shape_mismatches, dtype_mismatches + + +def compare_params( + params1: Dict[str, Any], + params2: Dict[str, Any], + verbose: bool = False, + compare_values: bool = False, + atol: float = 1e-5, + rtol: float = 1e-5, + name1: str = "Ckpt1", + name2: str = "Ckpt2", +) -> bool: + """Compare two parameter trees for structure, shape, and optionally values. + + Returns True if tree structures, shapes, and (optionally) values match. + """ + # First check tree structure + if tree_structure(params1) != tree_structure(params2): + print("\n[✗] Tree structures differ.") + print_structure_diff(params1, params2, name1=name1, name2=name2) + return False + + print("\n[✓] Tree structures are the same.") + + all_match = True + num_params = 0 + shape_mismatches = [] + dtype_mismatches = [] + value_mismatches = [] + value_matches = 0 + + def _compare_leaf(path, x, y): + nonlocal all_match, num_params, shape_mismatches, dtype_mismatches, value_mismatches, value_matches + key_str = keystr(path) + num_params += 1 + + shape1 = getattr(x, "shape", "N/A") + shape2 = getattr(y, "shape", "N/A") + dtype1 = getattr(x, "dtype", type(x).__name__) + dtype2 = getattr(y, "dtype", type(y).__name__) + + # Check shape + shape_match = shape1 == shape2 + if not shape_match: + shape_mismatches.append((key_str, shape1, shape2)) + all_match = False + + # Check dtype + dtype_match = str(dtype1) == str(dtype2) + if not dtype_match: + dtype_mismatches.append((key_str, dtype1, dtype2)) + all_match = False + + # Check values if requested and shapes match + if compare_values and shape_match and hasattr(x, "shape") and hasattr(y, "shape"): + try: + x_arr = np.asarray(x) + y_arr = np.asarray(y) + is_close = bool(np.allclose(x_arr, y_arr, atol=atol, rtol=rtol)) + + if is_close: + value_matches += 1 + if verbose: + print(f" [✓] {key_str} | Shape: {shape1} | Values match") + 
else: + diff = np.abs(x_arr - y_arr) + mean_diff = float(np.mean(diff)) + max_diff = float(np.max(diff)) + value_mismatches.append((key_str, mean_diff, max_diff)) + all_match = False + if verbose: + print(f" [✗] {key_str} | Shape: {shape1} | Mean diff: {mean_diff:.2e}, Max diff: {max_diff:.2e}") + except Exception as e: # pylint: disable=broad-exception-caught + value_mismatches.append((key_str, f"Error: {e}", "")) + all_match = False + elif verbose and not compare_values: + print(f" {key_str} | Shape: {shape1} | Dtype: {dtype1}") + + tree_map_with_path(_compare_leaf, params1, params2) + + # Print summary + print("\n--- Summary ---") + print(f"Total parameters: {num_params}") + + if shape_mismatches: + print(f"\n[✗] Shape mismatches ({len(shape_mismatches)}):") + for key_str, s1, s2 in shape_mismatches: + print(f" {key_str}: {name1}={s1}, {name2}={s2}") + else: + print("[✓] All shapes match.") + + if dtype_mismatches: + print(f"\n[✗] Dtype mismatches ({len(dtype_mismatches)}):") + for key_str, d1, d2 in dtype_mismatches: + print(f" {key_str}: {name1}={d1}, {name2}={d2}") + else: + print("[✓] All dtypes match.") + + if compare_values: + if value_mismatches: + print(f"\n[✗] Value mismatches ({len(value_mismatches)}):") + for item in value_mismatches[:20]: # Show first 20 + if len(item) == 3: + key_str, mean_diff, max_diff = item + if isinstance(mean_diff, float): + print(f" {key_str}: mean_diff={mean_diff:.2e}, max_diff={max_diff:.2e}") + else: + print(f" {key_str}: {mean_diff}") + if len(value_mismatches) > 20: + print(f" ... 
and {len(value_mismatches) - 20} more (use --verbose to see all)") + else: + print(f"[✓] All values match (atol={atol}, rtol={rtol}).") + print(f" Values matching: {value_matches}/{num_params}") + + return all_match + + +def _extract_params(state: dict, fmt: str) -> dict: + """Extract params from a checkpoint state based on its detected format.""" + if fmt == "linen": + return state.get("params", {}) + else: + # NNX format: params are in 'model' key + return state.get("model", state.get("params", {})) + + +def _normalize_params(params: dict, fmt: str) -> dict: + """Normalize params based on detected format.""" + if fmt == "linen": + return _normalize_linen_params(params) + else: + return _normalize_nnx_params(params) + + +def main(argv: Sequence[str]): + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + ckpt_path_1 = FLAGS.ckpt_path_1 + ckpt_path_2 = FLAGS.ckpt_path_2 + + print("=" * 80) + print("Checkpoint Comparator") + print("=" * 80) + + print(f"\nCheckpoint 1: {ckpt_path_1}") + print(f"Checkpoint 2: {ckpt_path_2}") + print(f"Transpose NNX layers: {FLAGS.transpose_nnx_layers}") + print(f"Ignore RNGs: {FLAGS.ignore_rngs}") + print(f"Compare values: {FLAGS.compare_values}") + if FLAGS.compare_values: + print(f" Tolerance: atol={FLAGS.atol}, rtol={FLAGS.rtol}") + + # Load checkpoints — use metadata-only when not comparing values to avoid + # downloading tensor data (which can be 100+ GiB and cause XPK timeouts). 
+ metadata_only = not FLAGS.compare_values + print("\n" + "-" * 40) + state_1 = load_checkpoint(ckpt_path_1, metadata_only=metadata_only) + state_2 = load_checkpoint(ckpt_path_2, metadata_only=metadata_only) + + # Detect formats + format_1 = detect_format(state_1) + format_2 = detect_format(state_2) + log(f"Detected checkpoint 1 format: {format_1}") + log(f"Detected checkpoint 2 format: {format_2}") + + is_cross_format = format_1 != format_2 + name_1 = f"Ckpt1({format_1})" + name_2 = f"Ckpt2({format_2})" + + # Extract and normalize params + print("\n" + "-" * 40) + log("Normalizing parameters...") + + if FLAGS.compare_only == "params": + params_1 = _extract_params(state_1, format_1) + params_2 = _extract_params(state_2, format_2) + else: + params_1 = state_1 + params_2 = state_2 + + params_1 = _normalize_params(params_1, format_1) + log(f" Checkpoint 1 ({format_1}): normalized") + params_2 = _normalize_params(params_2, format_2) + log(f" Checkpoint 2 ({format_2}): normalized") + + # Filter out RNG paths if requested + if FLAGS.ignore_rngs: + print("\n" + "-" * 40) + log("Filtering out RNG-related paths...") + params_1 = filter_rngs(params_1) + params_2 = filter_rngs(params_2) + + # Transform NNX params for cross-format comparison (transpose layer dimensions) + # Only apply when comparing Linen vs NNX, not for same-format comparisons + if FLAGS.transpose_nnx_layers and is_cross_format: + print("\n" + "-" * 40) + if format_1 == "nnx": + params_1 = transform_nnx_params_for_comparison(params_1) + if format_2 == "nnx": + params_2 = transform_nnx_params_for_comparison(params_2) + + # Compare + print("\n" + "-" * 40) + log("Comparing parameters...") + + success = compare_params( + params_1, + params_2, + verbose=FLAGS.verbose, + compare_values=FLAGS.compare_values, + atol=FLAGS.atol, + rtol=FLAGS.rtol, + name1=name_1, + name2=name_2, + ) + + # Final verdict + print("\n" + "=" * 80) + if success: + print("CHECKPOINTS MATCH") + if FLAGS.compare_values: + print(" Tree 
structure, shapes, and values are identical!") + else: + print(" Tree structure and all shapes are identical!") + else: + print("CHECKPOINTS DIFFER") + print(" See details above for mismatches.") + print("=" * 80) + + return 0 if success else 1 + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxtext/checkpoint_conversion/linen_nnx_converter.py b/src/maxtext/checkpoint_conversion/linen_nnx_converter.py new file mode 100644 index 0000000000..015d3b5a56 --- /dev/null +++ b/src/maxtext/checkpoint_conversion/linen_nnx_converter.py @@ -0,0 +1,581 @@ +# Copyright 2023-2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bidirectional conversion between Linen and NNX checkpoint formats. 
+ +Top-level key mapping: + Linen → NNX: + params/params/ → model/ (remove double-nesting, rename, add {value:} wrappers) + opt_state → optimizer/opt_state (remove 'params' level from mu/nu) + step → optimizer/step (move inside optimizer) + + NNX → Linen: + model/ → params/params/ (strip {value:} wrappers, add double-nesting) + optimizer/opt_state → opt_state (add 'params' level to mu/nu) + optimizer/step → step (move to top level) + +Layer structure (--scan_layers): + linen_to_nnx: + scan_layers=True (default): stack layers_N arrays → 'layers' tensor with layer dim at axis 1 + scan_layers=False: rename layers_N → integer-keyed 'layers/{N}' + + nnx_to_linen (auto-detected): + Stacked 'layers' tensor → unstack along axis 1 → layers_N per-layer arrays + Integer-keyed layers/{N} → rename to layers_N + +Usage: + python linen_nnx_converter.py \\ + --source_path="gs://bucket/checkpoint/0/items" \\ + --target_path="gs://bucket/converted/" \\ + --direction=auto +""" + +import argparse +import os +import re +import time +from typing import Any + +# MUST set before importing JAX to force CPU-only mode +os.environ["JAX_PLATFORMS"] = "cpu" + +import jax +import numpy as np +from etils import epath +import orbax.checkpoint as ocp + + +def log(message: str) -> None: + print(f"[linen_nnx_converter] {message}") + + +# ── Format detection ─────────────────────────────────────────────────────────── + + +def detect_format(state: dict) -> str: + """Detects checkpoint format ('linen' or 'nnx') from top-level keys.""" + # NNX: uses 'model' as the top-level params key + if "model" in state: + return "nnx" + + if "params" not in state: + raise ValueError(f"Cannot detect checkpoint format: no 'model' or 'params' key. 
" f"Found: {list(state.keys())}") + + params = state["params"] + + # Linen: double-nested params/params/decoder + if isinstance(params, dict) and "params" in params: + inner = params["params"] + if isinstance(inner, dict) and ("decoder" in inner or "encoder" in inner): + return "linen" + + # Old NNX format: params/decoder (single-nested with value wrappers) + if isinstance(params, dict) and ("decoder" in params or "encoder" in params): + if _has_value_wrappers(params): + return "nnx" + + if "optimizer" in state: + return "nnx" + if "opt_state" in state: + return "linen" + + raise ValueError( + f"Could not detect checkpoint format. Keys: {list(state.keys())}, " + f"params keys: {list(params.keys()) if isinstance(params, dict) else type(params)}" + ) + + +# ── Value wrapper helpers ────────────────────────────────────────────────────── + + +def _has_value_wrappers(tree: Any) -> bool: + """Returns True if tree contains {value: array} wrappers (NNX style).""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return True + for v in tree.values(): + if _has_value_wrappers(v): + return True + return False + + +def _strip_value_wrappers(tree: Any) -> Any: + """Recursively strips {value: array} wrappers from a tree.""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return inner + return {k: _strip_value_wrappers(v) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_strip_value_wrappers(item) for item in tree) + else: + return tree + + +def _add_value_wrappers(tree: Any) -> Any: + """Recursively wraps leaf arrays in {value: array} (NNX nnx.Param format).""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return tree # Already 
wrapped + return {k: _add_value_wrappers(v) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_add_value_wrappers(item) for item in tree) + elif hasattr(tree, "shape") or isinstance(tree, np.ndarray): + return {"value": tree} + else: + return tree + + +# ── Layer structure helpers ──────────────────────────────────────────────────── + + +def _stack_layers(decoder: dict) -> tuple[dict, bool]: + """Stacks per-layer parameters (layers_N) into a single 'layers' dict at axis 0. + + Returns (result_dict, was_stacked). + """ + layer_pattern = re.compile(r"^layers_(\d+)$") + layer_indices = {} + other_keys = {} + + for key, value in decoder.items(): + match = layer_pattern.match(key) + if match: + layer_indices[int(match.group(1))] = value + else: + other_keys[key] = value + + if not layer_indices: + return decoder, False + + sorted_indices = sorted(layer_indices.keys()) + num_layers = len(sorted_indices) + log(f" Found {num_layers} individual layers, stacking into 'layers'") + + def stack_arrays(layers_data: list) -> Any: + first = layers_data[0] + if hasattr(first, "shape") or isinstance(first, np.ndarray): + return np.stack([np.asarray(layers_data[i]) for i in range(len(layers_data))], axis=0) + elif isinstance(first, dict): + result = {} + for key in first.keys(): + child_data = [layers_data[i].get(key) for i in range(len(layers_data))] + if all(c is not None for c in child_data): + result[key] = stack_arrays(child_data) + return result + else: + return first + + layers_data = [layer_indices[i] for i in sorted_indices] + stacked = stack_arrays(layers_data) + + result = dict(other_keys) + result["layers"] = stacked + return result, True + + +def _rename_layers_to_integer_keys(decoder: dict) -> dict: + """Converts layers_N keys to integer-keyed dict under 'layers' (no stacking). + + Converts {layers_0: {...}, layers_1: {...}} → {layers: {'0': {...}, '1': {...}}}. + Used for scan_layers=False linen→nnx conversion (Pattern C). 
+ """ + layer_pattern = re.compile(r"^layers_(\d+)$") + layer_indices = {} + other_keys = {} + + for key, value in decoder.items(): + match = layer_pattern.match(key) + if match: + layer_indices[int(match.group(1))] = value + else: + other_keys[key] = value + + if not layer_indices: + return decoder + + sorted_indices = sorted(layer_indices.keys()) + log(f" Found {len(sorted_indices)} individual layers, renaming to integer-keyed 'layers/N'") + result = dict(other_keys) + result["layers"] = {str(i): layer_indices[i] for i in sorted_indices} + return result + + +def _transpose_layers_axes(tree: Any, src_axis: int, dst_axis: int) -> Any: + """Transposes the layers dimension in arrays within a tree (src_axis ↔ dst_axis).""" + if src_axis == dst_axis: + return tree + if isinstance(tree, dict): + return {k: _transpose_layers_axes(v, src_axis, dst_axis) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_transpose_layers_axes(item, src_axis, dst_axis) for item in tree) + elif hasattr(tree, "shape") and len(tree.shape) >= 2: + axes = list(range(len(tree.shape))) + axes[src_axis], axes[dst_axis] = axes[dst_axis], axes[src_axis] + result = np.transpose(np.asarray(tree), axes=axes) + log(f" Transposed: {tree.shape} → {result.shape}") + return result + else: + return tree + + +def _detect_num_layers(tree: Any, scan_axis: int) -> int | None: + """Detects num_layers from the first array with ndim > scan_axis.""" + if hasattr(tree, "shape") or isinstance(tree, np.ndarray): + shape = getattr(tree, "shape", None) or np.asarray(tree).shape + if len(shape) > scan_axis: + return shape[scan_axis] + return None + if isinstance(tree, dict): + for v in tree.values(): + result = _detect_num_layers(v, scan_axis) + if result is not None: + return result + return None + + +def _unstack_single_layer(tree: Any, idx: int, scan_axis: int) -> Any: + """Extracts a single layer by indexing at scan_axis.""" + if hasattr(tree, "shape") or isinstance(tree, 
np.ndarray): + arr = np.asarray(tree) + if arr.ndim > scan_axis: + return np.take(arr, idx, axis=scan_axis) + return arr + if isinstance(tree, dict): + return {k: _unstack_single_layer(v, idx, scan_axis) for k, v in tree.items()} + if isinstance(tree, (list, tuple)): + return type(tree)(_unstack_single_layer(v, idx, scan_axis) for v in tree) + return tree + + +def _convert_layers_to_linen_format(decoder: dict) -> dict: + """Converts NNX 'layers' back to Linen's layers_N format (auto-detects NNX style). + + Handles: + - Stacked tensor (Pattern B): layers/ + → layers_0, layers_1, ... (unstack along axis 1) + - Integer-keyed (Pattern C): layers/0, layers/1, ... + → layers_0, layers_1, ... (rename) + """ + if "layers" not in decoder: + return decoder + + layers_val = decoder["layers"] + other_keys = {k: v for k, v in decoder.items() if k != "layers"} + + if not isinstance(layers_val, dict): + # Already a non-dict (shouldn't happen normally), keep as-is + return decoder + + # Pattern C: integer-keyed per-layer dict → rename + if all(k.isdigit() for k in layers_val.keys()): + result = dict(other_keys) + for idx_str, layer_data in sorted(layers_val.items(), key=lambda x: int(x[0])): + result[f"layers_{idx_str}"] = layer_data + log(f" Renamed integer-keyed layers/N → layers_N ({len(layers_val)} layers)") + return result + + # Pattern B: stacked tensor (layer dim at axis 1) → unstack + num_layers = _detect_num_layers(layers_val, scan_axis=1) + if num_layers is None: + log(" WARNING: Could not detect num_layers for unstacking, keeping 'layers' as-is") + result = dict(other_keys) + result["layers"] = layers_val + return result + + result = dict(other_keys) + for i in range(num_layers): + result[f"layers_{i}"] = _unstack_single_layer(layers_val, idx=i, scan_axis=1) + log(f" Unstacked scanned 'layers' → layers_N ({num_layers} layers at axis 1)") + return result + + +# ── Optimizer state helpers ──────────────────────────────────────────────────── + + +def 
_convert_opt_state_linen_to_nnx(opt_state: Any) -> Any: + """Removes 'params' nesting from mu/nu in linen opt_state. + + NNX optimizer state has plain arrays (no {value:} wrappers). + Linen opt_state mirrors the params structure (params/decoder/...), + so we remove the 'params' level to get decoder/... directly. + """ + if isinstance(opt_state, dict): + result = {} + for k, v in opt_state.items(): + if k == "params": + # Remove this level by merging its contents up + converted = _convert_opt_state_linen_to_nnx(v) + if isinstance(converted, dict): + result.update(converted) + else: + result[k] = converted + else: + result[k] = _convert_opt_state_linen_to_nnx(v) + return result + elif isinstance(opt_state, (list, tuple)): + return type(opt_state)(_convert_opt_state_linen_to_nnx(item) for item in opt_state) + else: + return opt_state # Plain array or scalar — no value wrapper for opt_state + + +def _convert_opt_state_nnx_to_linen(opt_state: Any, depth: int = 0) -> Any: + """Adds 'params' nesting to mu/nu, removes any stray {value:} wrappers. + + NNX optimizer mu/nu contains decoder/... directly. + Linen expects mu/params/decoder/... (one 'params' level mirroring the params structure). 
+ """ + if isinstance(opt_state, dict): + # Strip any {value:} wrappers in opt_state (shouldn't be there but handle gracefully) + if set(opt_state.keys()) == {"value"}: + inner = opt_state["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return inner + + result = {} + for k, v in opt_state.items(): + converted = _convert_opt_state_nnx_to_linen(v, depth + 1) + # Add one 'params' level after mu/nu (mirrors linen's params structure) + if k in ("mu", "nu") and isinstance(converted, dict): + result[k] = {"params": converted} + else: + result[k] = converted + return result + elif isinstance(opt_state, (list, tuple)): + return type(opt_state)(_convert_opt_state_nnx_to_linen(item, depth + 1) for item in opt_state) + else: + return opt_state + + +# ── Main conversion functions ────────────────────────────────────────────────── + + +def convert_linen_to_nnx(state: dict, scan_layers: bool = True) -> dict: + """Converts Linen checkpoint to NNX format. + + Args: + state: Linen checkpoint dict with keys ['params', 'opt_state', 'step']. + scan_layers: If True (default), stack per-layer arrays and insert layer + dim at axis 1 (for NNX with scan_layers=True). + If False, rename layers_N → integer-keyed layers/N + (for NNX with scan_layers=False). 
+ """ + result = {} + + if "params" in state: + linen_params = state["params"] + # Remove double 'params' nesting: params/params/decoder → decoder + if isinstance(linen_params, dict) and "params" in linen_params: + nnx_params = linen_params["params"] + log(" params: Removed double 'params' nesting (params/params → model)") + else: + nnx_params = linen_params + log(" params: No double nesting found") + + stripped = _strip_value_wrappers(nnx_params) + + for component in ("decoder", "encoder"): + if component in stripped and isinstance(stripped[component], dict): + if scan_layers: + stripped[component], was_stacked = _stack_layers(stripped[component]) + if was_stacked and "layers" in stripped[component]: + log(f" {component}/layers: Transposing stacked (layers, ...) → (..., layers, ...) at axis 1") + stripped[component]["layers"] = _transpose_layers_axes(stripped[component]["layers"], src_axis=0, dst_axis=1) + else: + stripped[component] = _rename_layers_to_integer_keys(stripped[component]) + + result["model"] = _add_value_wrappers(stripped) + log(" model: Saved with {value:} wrappers under 'model' key") + + # optimizer: move step inside, keep opt_state + optimizer_dict = {} + if "step" in state: + optimizer_dict["step"] = state["step"] + log(f" optimizer/step: Moved from top-level (step={state['step']})") + if "opt_state" in state: + optimizer_dict["opt_state"] = _convert_opt_state_linen_to_nnx(state["opt_state"]) + log(" optimizer/opt_state: Removed 'params' nesting from mu/nu") + if optimizer_dict: + result["optimizer"] = optimizer_dict + + return result + + +def convert_nnx_to_linen(state: dict) -> dict: + """Converts NNX checkpoint to Linen format. + + Reads from 'model'/'optimizer' keys (or falls back to old 'params'/'opt_state' format). + Layer structure is auto-detected (stacked vs integer-keyed). 
+ """ + result = {} + + model_key = "model" if "model" in state else "params" + if model_key in state: + nnx_params = state[model_key] + stripped = _strip_value_wrappers(nnx_params) + log(f" {model_key}: Removed {{value:}} wrappers") + + for component in ("decoder", "encoder"): + if component in stripped and isinstance(stripped[component], dict): + stripped[component] = _convert_layers_to_linen_format(stripped[component]) + + # Add double 'params' nesting: decoder → params/params/decoder + result["params"] = {"params": stripped} + log(" params: Added double 'params' nesting (model → params/params)") + + # optimizer: extract step and opt_state back to top level + if "optimizer" in state: + optimizer = state["optimizer"] + if "step" in optimizer: + result["step"] = optimizer["step"] + log(" step: Extracted from optimizer/step to top level") + if "opt_state" in optimizer: + result["opt_state"] = _convert_opt_state_nnx_to_linen(optimizer["opt_state"]) + log(" opt_state: Added 'params' nesting to mu/nu") + elif "opt_state" in state: + # Backward compat: old format with opt_state at top level + result["opt_state"] = _convert_opt_state_nnx_to_linen(state["opt_state"]) + log(" opt_state: Converted from top-level opt_state (old format)") + + if "step" in state and "step" not in result: + result["step"] = state["step"] + + return result + + +# ── Checkpoint I/O ───────────────────────────────────────────────────────────── + + +def load_checkpoint(checkpoint_path: str) -> dict: + """Loads checkpoint from local or GCS path.""" + log(f"Loading checkpoint from: {checkpoint_path}") + + checkpoint_dir = epath.Path(checkpoint_path) + ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler()) + metadata = ckptr.metadata(checkpoint_dir) + + devices = np.array(jax.devices()).reshape((-1,)) + single_device_mesh = jax.sharding.Mesh(devices, ("x",)) + unsharded = jax.sharding.NamedSharding(single_device_mesh, jax.sharding.PartitionSpec()) + + restore_args = jax.tree_util.tree_map( + lambda 
x: ocp.ArrayRestoreArgs(sharding=unsharded) if hasattr(x, "shape") else None, + metadata.item_metadata.tree, + is_leaf=lambda x: hasattr(x, "shape"), + ) + + state = ckptr.restore(checkpoint_dir, restore_args=restore_args) + log(f" Loaded keys: {list(state.keys())}") + return state + + +def save_checkpoint(state: dict, output_path: str) -> None: + """Saves checkpoint to local or GCS path.""" + log(f"Saving checkpoint to: {output_path}") + + output_dir = epath.Path(output_path) + output_dir.mkdir(exist_ok=True, parents=True) + + ckptr = ocp.PyTreeCheckpointer() + ckptr.save(output_dir, state, force=True) + log(" Checkpoint saved successfully") + + +# ── CLI ──────────────────────────────────────────────────────────────────────── + + +def main(): + parser = argparse.ArgumentParser( + description="Convert between Linen and NNX checkpoint formats.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--source_path", + type=str, + required=True, + help="Path to source checkpoint items directory (e.g. gs://bucket/ckpt/0/items).", + ) + parser.add_argument( + "--target_path", + type=str, + required=True, + help="Path to save converted checkpoint.", + ) + parser.add_argument( + "--direction", + type=str, + choices=["auto", "linen_to_nnx", "nnx_to_linen"], + default="auto", + help="Conversion direction. 'auto' detects from source format.", + ) + parser.add_argument( + "--scan_layers", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "For linen_to_nnx only: if True (default), stack per-layer arrays into a " + "scanned 'layers' tensor with layer dim at axis 1 (for NNX with scan_layers=True). " + "If False, rename layers_N to integer-keyed layers/N without stacking " + "(for NNX with scan_layers=False)." 
+ ), + ) + + args = parser.parse_args() + + print("=" * 80) + print("Linen <-> NNX Checkpoint Converter") + print("=" * 80) + + start_time = time.time() + + state = load_checkpoint(args.source_path) + + if args.direction == "auto": + source_format = detect_format(state) + target_format = "nnx" if source_format == "linen" else "linen" + log(f"Auto-detected: {source_format} → {target_format}") + else: + source_format = args.direction.split("_to_")[0] + target_format = args.direction.split("_to_")[1] + log(f"Using specified direction: {source_format} → {target_format}") + + log(f"Converting: {source_format} → {target_format}") + if source_format == "linen": + log(f"scan_layers={args.scan_layers}") + + if source_format == "linen" and target_format == "nnx": + converted_state = convert_linen_to_nnx(state, scan_layers=args.scan_layers) + elif source_format == "nnx" and target_format == "linen": + converted_state = convert_nnx_to_linen(state) + else: + raise ValueError(f"Invalid conversion: {source_format} → {target_format}") + + save_checkpoint(converted_state, args.target_path) + + elapsed = time.time() - start_time + print("\n" + "=" * 80) + print(f"Conversion complete in {elapsed:.2f} seconds") + print(f" Source: {args.source_path}") + print(f" Target: {args.target_path}") + print(f" Direction: {source_format} → {target_format}") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 3ff1c33153..52d143852e 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -1113,8 +1113,8 @@ position_id_per_seconds: 25 subslice_shape: "" # NNX -enable_nnx: false -pure_nnx_decoder: false +enable_nnx: True +pure_nnx_decoder: True ################################## Qwen3-Next Specific Configs ################################## # Kernel size for the 1D convolution in the Gated Delta Net diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index 
813cb33014..824b7590eb 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -534,14 +534,14 @@ def __init__( elif self.is_qwen3_next: self.query_norm = Qwen3NextRMSNorm( num_features=self.config.head_dim, - eps=self.config.normalization_layer_epsilon, + epsilon=self.config.normalization_layer_epsilon, dtype=self.config.dtype, weight_dtype=self.config.weight_dtype, rngs=self.rngs, ) self.key_norm = Qwen3NextRMSNorm( num_features=self.config.head_dim, - eps=self.config.normalization_layer_epsilon, + epsilon=self.config.normalization_layer_epsilon, dtype=self.config.dtype, weight_dtype=self.config.weight_dtype, rngs=self.rngs, diff --git a/src/maxtext/layers/initializers.py b/src/maxtext/layers/initializers.py index 20baf9a633..e7ea2094db 100644 --- a/src/maxtext/layers/initializers.py +++ b/src/maxtext/layers/initializers.py @@ -94,6 +94,16 @@ def variable_to_logically_partitioned(variable: nnx.VariableState): out_sharding = metadata["sharding"] if out_sharding is not None: + if nnx.PARTITION_NAME in metadata: + partition_name = metadata[nnx.PARTITION_NAME] + scan_axis = metadata.get("param_scan_axis", 0) if variable.type == nnx.Param else 0 + + sharding_list = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding) + if partition_name not in sharding_list: + sharding_list.insert(scan_axis, partition_name) + + out_sharding = tuple(sharding_list) + return nn.LogicallyPartitioned( # type: ignore[wrong-keyword-args] variable.value, out_sharding, # type: ignore[arg-type] diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index c96ec08c8d..20805d7bc5 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -71,7 +71,7 @@ class NNXDecoderLayer(nnx.Module): """ - Transformer decoder layer converted to NNX. 
+ Transformer decoder layer converted to NNX """ def __init__( @@ -307,11 +307,12 @@ def __init__( dense_cls, moe_cls = decoder_block_classes num_dense = config.first_num_dense_layers - self.dense_layers = self._create_scanned_layers(dense_cls, length=num_dense, rngs=rngs) - + self.dense_layers = self._create_scanned_layers( + dense_cls, length=num_dense, metadata_axis_name="dense_layers", rngs=rngs + ) num_moe = config.num_decoder_layers - config.first_num_dense_layers + self.moe_layers = self._create_scanned_layers(moe_cls, length=num_moe, metadata_axis_name="moe_layers", rngs=rngs) - self.moe_layer = self._create_scanned_layers(moe_cls, length=num_moe, rngs=rngs) elif self.is_gemma3: attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN) scan_length = config.num_decoder_layers // attention_pattern_length @@ -323,7 +324,9 @@ def __init__( RemattedGemma3Block = gemma3.Gemma3ScannableBlock if scan_length > 0: - self.layers = self._create_scanned_layers(RemattedGemma3Block, length=scan_length, rngs=rngs, **layer_kwargs) + self.layers = self._create_scanned_layers( + RemattedGemma3Block, length=scan_length, metadata_axis_name="layers", rngs=rngs, **layer_kwargs + ) self.layers_remainder = RemattedGemma3Block( config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs ) # pytype: disable=wrong-keyword-args @@ -337,7 +340,13 @@ def __init__( "interleave_moe_layer_step": self.config.interleave_moe_layer_step, } - self.layers = self._create_scanned_layers(layer_cls, length=num_layers, rngs=rngs, **layer_kwargs) + if num_layers > 0: + self.layers = self._create_scanned_layers( + layer_cls, length=num_layers, metadata_axis_name="layers", rngs=rngs, **layer_kwargs + ) + else: + self.layers = nnx.List([]) + else: self.layers = nnx.List([]) @@ -386,34 +395,86 @@ def _create_single_layer(self, decoder_layer_class, rngs, **kwargs): ) return nnx_wrappers.ToNNX(layer_linen, rngs=rngs) - def _create_scanned_layers(self, 
decoder_layer_class, length: int, rngs: nnx.Rngs, **layer_kwargs): - """Creates a VMapped stack of layers, forcing parameter init for Compact modules.""" - - def create_layer_fn(rng): - layer = decoder_layer_class( - config=self.config, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, rngs=rng, **layer_kwargs - ) + def _create_scanned_layers( + self, decoder_layer_class, length: int, metadata_axis_name: str, rngs: nnx.Rngs, **layer_kwargs + ): + """Creates a scanned stack of layers using jax.lax.scan for memory-efficient initialization. - return layer + Uses jax.lax.scan instead of nnx.vmap to reduce peak memory during initialization. + With vmap, all layers' parameters are created simultaneously (O(N) peak memory). + With scan, parameters are created one layer at a time (O(1) peak intermediate memory), + which prevents OOM on memory-constrained devices like TPU v6e-4. + """ + scan_axis = self.config.param_scan_axis - # Workaround for Deepseek MTP test failure. - # TODO: Handle this properly. + # Fork rngs to get per-layer RNG states for scanning try: forked_rngs = rngs.fork(split=length) - except: # pylint: disable=bare-except pass - out_axes = nnx.StateAxes({nnx.Param: self.config.param_scan_axis, ...: 0}) - layers_vmapped = nnx.vmap( - create_layer_fn, - in_axes=0, - out_axes=out_axes, - axis_name="layers", - transform_metadata={nnx.PARTITION_NAME: "layers"}, - )(forked_rngs) + rngs_graphdef, rngs_state = nnx.split(forked_rngs) + + # Create a reference layer to capture the module graph structure (graphdef). + # This layer's params are discarded — only the structure is kept. + # Must use the first slice of the forked rngs (not a dummy Rngs(0)) so the + # graphdef has the same number of RNG state leaves as the scan-created layers. 
+ first_rng_state = jax.tree.map(lambda x: x[0], rngs_state) + ref_rngs = nnx.merge(rngs_graphdef, first_rng_state) + ref_layer = decoder_layer_class( + config=self.config, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, rngs=ref_rngs, **layer_kwargs + ) + layer_graphdef, _, _ = nnx.split(ref_layer, nnx.Param, ...) + del ref_layer + + # Sequentially create each layer's parameters via jax.lax.scan. + # The scan body is traced once; XLA executes it N times with different RNG keys, + # keeping only one layer's intermediate state alive at a time. + def scan_body(carry, rng_state_slice): + layer_rngs = nnx.merge(rngs_graphdef, rng_state_slice) + layer = decoder_layer_class( + config=self.config, + mesh=self.mesh, + quant=self.quant, + model_mode=self.model_mode, + rngs=layer_rngs, + **layer_kwargs, + ) + _, params, rest = nnx.split(layer, nnx.Param, ...) + return carry, (params, rest) + + _, (stacked_params, stacked_rest) = jax.lax.scan(scan_body, None, rngs_state) - return layers_vmapped + # jax.lax.scan stacks outputs along axis 0. Move params to the configured scan axis. + if scan_axis != 0: + stacked_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), stacked_params) + + # Add partition metadata that nnx.vmap's transform_metadata would normally set. + # This metadata is read by variable_to_logically_partitioned() in initializers.py + # and by nnx.get_partition_spec() (via the updated out_sharding) to produce + # correct sharding specs that include the scan axis dimension. + def _add_scan_metadata(state, axis): + def _update_leaf(leaf): + if isinstance(leaf, nnx.VariableState): + metadata = leaf.get_metadata() + metadata[nnx.PARTITION_NAME] = metadata_axis_name + metadata["param_scan_axis"] = axis + # Insert the scan axis name into out_sharding so that + # nnx.get_partition_spec returns specs matching the actual tensor rank. + # Without this, scanned params are 3D but specs remain 2D. 
+ if "out_sharding" in metadata and metadata["out_sharding"]: + out_sharding = list(metadata["out_sharding"]) + out_sharding.insert(axis, metadata_axis_name) + metadata["out_sharding"] = tuple(out_sharding) + return leaf.replace(**metadata) + return leaf + + return jax.tree.map(_update_leaf, state, is_leaf=lambda x: isinstance(x, nnx.VariableState)) + + stacked_params = _add_scan_metadata(stacked_params, scan_axis) + stacked_rest = _add_scan_metadata(stacked_rest, 0) + + return nnx.merge(layer_graphdef, stacked_params, stacked_rest) def _apply_layer_with_remat(self, layer: nnx.Module, y: jax.Array, policy: Any, prevent_cse: bool, **kwargs): """Helper to cleanly apply jax.checkpoint to a single unscanned layer or block.""" @@ -435,54 +496,54 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs) """Runs the layer stack using nnx.scan.""" policy = self.get_remat_policy() prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config) - graphdef, params, state = nnx.split( - layers, nnx.Param, ... - ) # state: the mutable state we carry (KV cache, RNGs, etc.) + graphdef, params, state = nnx.split(layers, nnx.Param, ...) 
scan_axis = self.config.param_scan_axis if scan_axis != 0: - # Move scan_axis to 0 so scan can iterate over it params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), params) layer_cls = layers.__class__ sig = inspect.signature(layer_cls.__call__) valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters} - layer_cls = layers.__class__ # Access the underlying class - sig = inspect.signature(layer_cls.__call__) - # Filter kwargs to only include keys that exist in the layer's signature - valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters} + def _extract_matching_state(template, full): + if isinstance(template, nnx.State): + return nnx.State({k: _extract_matching_state(v, full[k]) for k, v in template.items()}) + elif isinstance(template, dict): + return {k: _extract_matching_state(v, full[k]) for k, v in template.items()} + return full def layer_fn(carry, scanned_vars): - # Unpack the sliced variables for THIS layer current_params, current_state = scanned_vars if self.config.parameter_memory_host_offload: current_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), current_params) - # Merge using the SLICED state layer = nnx.merge(graphdef, current_params, current_state) - - # Run the layer (Filter kwargs if using the solution from previous turn) layer_out = layer(carry, *args, **valid_kwargs) - new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out - # Extract the updated state to return it - # _, new_current_state = nnx.split(layer, nnx.Param, ...) 
- new_current_state = nnx.state(layer) + new_full_state = nnx.state(layer) + new_current_state = _extract_matching_state(current_state, new_full_state) + + # ONLY return non-param state to prevent memory duplication of weights return new_carry, new_current_state layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) - final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state)) + final_carry, scanned_other = jax.lax.scan(layer_fn, x_in, (params, state)) if scan_axis != 0: - scanned_params, scanned_other = scanned_state.split(nnx.Param, ...) - scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params) - scanned_state = nnx.State.merge(scanned_params, scanned_other) + params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), params) - return final_carry, nnx.merge(graphdef, scanned_state) + scanned_state = nnx.State.merge(params, scanned_other) + # Update the existing module in-place rather than creating a new one. + # Creating a new module via nnx.merge and reassigning (self.layers = new_module) + # would replace a child node in the NNX graph, which is detected as a graph + # structure mutation when the parent module is inside a JAX transformation + # (e.g., nnx.jit in PeftTrainer). In-place update preserves object identity. + nnx.update(layers, scanned_state) + return final_carry, layers def get_decoder_layers(self): """Retrieves decoder layer classes based on config using a dictionary lookup.""" @@ -829,10 +890,19 @@ def _find_next_boundary(self, current_idx, end_idx, engram_indices): def _apply_single_engram_layer(self, y, current_idx, layer_stack, *args, **kwargs): """Applies a single, unscanned Engram layer by dynamically slicing the NNX state.""" graphdef, state = nnx.split(layer_stack) + params, rest = state.split(nnx.Param, ...) 
+ scan_axis = self.config.param_scan_axis + + # Helper to generate N-dimensional basic slices (e.g., x[:, idx, :]) + def _extract_slice(x, idx, axis): + slices = tuple(idx if i == axis else slice(None) for i in range(x.ndim)) + return x[slices] - # Slice the parameters for the current index (assuming scan axis is 0) - sliced_state = jax.tree.map(lambda x: x[current_idx], state) - single_layer = nnx.merge(graphdef, sliced_state) + # Slice using native indexing instead of jnp.take + sliced_params = jax.tree.map(lambda x: _extract_slice(x, current_idx, scan_axis), params) + sliced_rest = jax.tree.map(lambda x: _extract_slice(x, current_idx, 0), rest) + + single_layer = nnx.merge(graphdef, sliced_params, sliced_rest) # Run the single layer out = single_layer( @@ -841,14 +911,23 @@ def _apply_single_engram_layer(self, y, current_idx, layer_stack, *args, **kwarg y = out[0] if isinstance(out, tuple) else out # Re-merge the updated state back into the specific slice of the stack - new_single_state = nnx.state(single_layer) - updated_state = jax.tree.map( + new_state = nnx.state(single_layer) + new_params, new_rest = new_state.split(nnx.Param, ...) + + updated_params = jax.tree.map( + lambda s, new_s: jax.lax.dynamic_update_slice_in_dim( + s, jnp.expand_dims(new_s, axis=scan_axis), current_idx, axis=scan_axis + ), + params, + new_params, + ) + updated_rest = jax.tree.map( lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, jnp.expand_dims(new_s, axis=0), current_idx, axis=0), - state, - new_single_state, + rest, + new_rest, ) - nnx.update(layer_stack, updated_state) + nnx.update(layer_stack, updated_params, updated_rest) return y def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_stack, *args, **kwargs): @@ -856,10 +935,15 @@ def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_stack, *args scan_length = next_boundary - current_idx if scan_length > 0: graphdef, state = nnx.split(layer_stack) + params, rest = state.split(nnx.Param, ...) 
+ scan_axis = self.config.param_scan_axis - # Slice the chunk state - chunk_state = jax.tree.map(lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=0), state) - chunk_stack = nnx.merge(graphdef, chunk_state) + # Slice the chunk state along the correct axes + chunk_params = jax.tree.map( + lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=scan_axis), params + ) + chunk_rest = jax.tree.map(lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=0), rest) + chunk_stack = nnx.merge(graphdef, chunk_params, chunk_rest) # Apply sequentially y, chunk_stack = self._apply_layers_sequentially( @@ -867,11 +951,17 @@ def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_stack, *args ) # Update the original stack state - new_chunk_state = nnx.state(chunk_stack) - updated_state = jax.tree.map( - lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=0), state, new_chunk_state + new_state = nnx.state(chunk_stack) + new_params, new_rest = new_state.split(nnx.Param, ...) 
+ + updated_params = jax.tree.map( + lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=scan_axis), params, new_params + ) + updated_rest = jax.tree.map( + lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=0), rest, new_rest ) - nnx.update(layer_stack, updated_state) + + nnx.update(layer_stack, updated_params, updated_rest) return y @@ -961,7 +1051,7 @@ def __call__( y = self._apply_interleaved_scanned_layers( y, - self.moe_layer, + self.moe_layers, 0, (cfg.num_decoder_layers - cfg.first_num_dense_layers), [e - cfg.first_num_dense_layers for e in cfg.engram_layers], @@ -978,7 +1068,7 @@ def __call__( if cfg.use_batch_split_schedule: policy = self.get_remat_policy() - mock_params = self._build_linen_params(self.moe_layer) + mock_params = self._build_linen_params(self.moe_layers) y = deepseek_batchsplit.scan_batch_split_layers( y, @@ -992,8 +1082,8 @@ def __call__( policy=policy, ) else: - y, self.moe_layer = self._apply_layers_sequentially( - self.moe_layer, y, *layer_args, length=num_moe, **layer_kwargs + y, self.moe_layers = self._apply_layers_sequentially( + self.moe_layers, y, *layer_args, length=num_moe, **layer_kwargs ) elif self.is_gemma3: y = self._apply_gemma3_scanned_blocks( @@ -1009,7 +1099,10 @@ def __call__( ) else: scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) - y, self.layers = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs) + if scan_length > 0: + y, self.layers = self._apply_layers_sequentially( + self.layers, y, *layer_args, length=scan_length, **layer_kwargs + ) else: prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg) @@ -1027,7 +1120,16 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in): for lyr, layer in enumerate(self.layers): graphdef, state = nnx.split(layer) - kv_cache = kv_caches[lyr] if kv_caches is not None else None + if kv_caches is not None: + if cfg.decoder_block == 
DecoderBlockType.QWEN3_NEXT: + if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0: + kv_cache = (kv_caches["key_cache"][lyr], kv_caches["value_cache"][lyr]) + else: + kv_cache = None + else: + kv_cache = kv_caches[lyr] + else: + kv_cache = None input_tokens = decoder_input_tokens if cfg.engram_layers else None if input_tokens is not None: @@ -1037,7 +1139,12 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in): nnx.update(layer, new_state) if kv_caches is not None and kv_cache is not None: - kv_caches[lyr] = kv_cache + if cfg.decoder_block == DecoderBlockType.QWEN3_NEXT: + if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0: + kv_caches["key_cache"][lyr] = kv_cache[0] + kv_caches["value_cache"][lyr] = kv_cache[1] + else: + kv_caches[lyr] = kv_cache if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds): visual_embeds = deepstack_visual_embeds[lyr] @@ -1059,7 +1166,7 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in): # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory # Instead, we keep track on the hidden states, which has smaller size compared to full logits - if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: + elif cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: logits = None self.sow(nnx.Intermediate, "hidden_states", hidden_state) @@ -1124,7 +1231,7 @@ def decoder_as_linen( model_mode: str, quant: None | Quant = None, ): - """Creates a Decoder module.""" + """Creates a Decoder module""" module = nnx_wrappers.to_linen( NNXDecoder, config=config, diff --git a/src/maxtext/layers/normalizations.py b/src/maxtext/layers/normalizations.py index 195d5bcc14..be6f56c8a4 100644 --- a/src/maxtext/layers/normalizations.py +++ b/src/maxtext/layers/normalizations.py @@ -102,7 +102,17 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> return y_flat.reshape(input_shape) -def Qwen3NextRMSNorm(num_features: int, eps: float, 
dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs): +def Qwen3NextRMSNorm( + num_features: int, + epsilon: float, + dtype: DType, + weight_dtype: DType, + shard_mode: ShardMode = ShardMode.AUTO, + kernel_axes: tuple[None | str, ...] = (), + parameter_memory_host_offload: bool = False, + *, + rngs: nnx.Rngs, +): """ Used for input and post attention layernorms in Qwen3NextDecoderLayer. @@ -115,10 +125,13 @@ def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: return nnx.data( RMSNorm( num_features=num_features, - epsilon=eps, + epsilon=epsilon, dtype=dtype, weight_dtype=weight_dtype, + shard_mode=shard_mode, + kernel_axes=kernel_axes, scale_init=linen_initializers.zeros, + parameter_memory_host_offload=parameter_memory_host_offload, scale_offset=1.0, rngs=rngs, ) diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index 0d1fcab700..bd6324e607 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -33,7 +33,7 @@ from maxtext.layers.decoders import Decoder from maxtext.layers.embeddings import Embed, embed_as_linen from maxtext.layers.encoders import AudioEncoder, VisionEncoder, audio_encoder_as_linen, vision_encoder_as_linen -from maxtext.layers.multi_token_prediction import multi_token_prediction_block_as_linen +from maxtext.layers.multi_token_prediction import MultiTokenPredictionBlock, multi_token_prediction_block_as_linen from maxtext.layers.quantizations import AqtQuantization as Quant from maxtext.multimodal import processor as mm_processor from maxtext.utils import max_utils @@ -376,25 +376,12 @@ def __init__( # For MTP, we use the DecoderLayer blueprint to ensure architectural consistency. # By convention, this is the last layer in the list. 
mtp_layer = layer_types[-1] - mtp_block_linen = multi_token_prediction_block_as_linen( + self.mtp_block = MultiTokenPredictionBlock( config=self.config, mesh=self.mesh, transformer_layer_module=mtp_layer, decoder=self.decoder, rngs=rngs, - name="mtp_block", - ) - self.mtp_block = nnx_wrappers.ToNNX(mtp_block_linen, rngs=rngs) - - self.mtp_block.lazy_init( - shared_embedding=self.token_embedder, - main_hidden_state=jnp.ones((1, 1, self.config.emb_dim), dtype=self.config.dtype), - input_ids=jnp.ones((1, 1), dtype=jnp.int32), - target_ids=jnp.ones((1, 1), dtype=jnp.int32), - target_mask=jnp.ones((1, 1), dtype=jnp.int32), - position_ids=jnp.ones((1, 1), dtype=jnp.int32), - decoder_segment_ids=jnp.ones((1, 1), dtype=jnp.int32), - deterministic=True, ) def no_op(self, *args, **kwargs): diff --git a/src/maxtext/models/qwen3.py b/src/maxtext/models/qwen3.py index eb15747fc2..5ba630adc3 100644 --- a/src/maxtext/models/qwen3.py +++ b/src/maxtext/models/qwen3.py @@ -962,7 +962,7 @@ def __init__( # First LayerNorm, applied before the attention block. self.input_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, @@ -987,7 +987,7 @@ def __init__( # Second LayerNorm, applied before the MoE block. self.post_attention_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, diff --git a/src/maxtext/trainers/pre_train/nnx_train.py b/src/maxtext/trainers/pre_train/nnx_train.py new file mode 100644 index 0000000000..50e8a2d264 --- /dev/null +++ b/src/maxtext/trainers/pre_train/nnx_train.py @@ -0,0 +1,883 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""NNX-native pre-training loop for MaxText. + +This module implements a pre-training loop that uses the Flax NNX API throughout, +in contrast to train.py which wraps NNX models inside Linen's TrainState. + + + Architecture + + ┌─────────────────────────────────┬──────────────────────────────────────────────────────────────────────────┐ + │function │ What it does │ + ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ + │ loss_fn │ Forward-pass + cross-entropy; for both train and eval; │ + │ │ called directly on an nnx.Module │ + ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ + │ train_step │ Functional step — merges (graphdef, opt_state) → runs nnx.value_and_grad │ + │ │ → updates optimizer → returns new nnx.State │ + ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ + │ eval_step │ Same merge pattern, forward-only, no grads │ + ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ + │ _create_and_shard_optimizer │ Wraps model + optax tx in nnx.Optimizer, derives partition specs via │ + │ │ nnx.get_partition_spec, shards state with jax.jit(out_shardings=…) │ + ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ + │ _build_jit_steps │ Partially applies static (graphdef, config) then wraps with │ + │ │ jax.jit(in_shardings, out_shardings, donate_argnums=(0,1)) 
│ + ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ + │ _maybe_restore_checkpoint / │ Orbax round-trip using the NNX {"value": array} wire format │ + │ _maybe_save_checkpoint │ │ + ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ + │ train_loop │ Full loop: model → optimizer → data → checkpoint → JIT compile → step → │ + │ │ eval → log │ + ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ + │ main / initialize / run │ Entry-point boilerplate matching train.py conventions │ + └─────────────────────────────────┴──────────────────────────────────────────────────────────────────────────┘ + + Key differences from train.py + + - No Linen TrainState — state lives in nnx.Optimizer (model params + optax state + step counter). + - Gradient computation uses nnx.value_and_grad, which is NNX-graph-aware. It differentiates only through + nnx.Param variables; non-differentiable NNX variables (RNGs, cache, …) are untouched. + - Gradient clipping uses optax.clip_by_global_norm directly, avoiding the Linen-TrainState coupling in + apply_gradient_clipping. + - JIT boundary: graphdef is a Python-static closure; only opt_state (a plain pytree of arrays) crosses the JIT + boundary with donate_argnums=(0,1) + - The JIT boundary uses split/merge so that graphdef is static and state is + donated as a pytree, preserving full sharding control via jax.jit shardings. + - Checkpointing saves/restores the raw nnx.State pytree via Orbax. 
+ +Entry point: + python -m maxtext.trainers.pre_train.nnx_train [overrides…] +""" + +import contextlib +import datetime +import functools +import os +from typing import Any, Sequence + +import jax +import jax.numpy as jnp +import numpy as np +import optax +from absl import app +from flax import linen as nn +from flax import nnx +from flax.linen import partitioning as nn_partitioning +from jax.sharding import Mesh + +from maxtext.common import checkpointing, profiler +from maxtext.common.common_types import ShardMode +from maxtext.common.data_loader import create_dataloader +from maxtext.common.gcloud_stub import cloud_diagnostics as _cloud_diag +from maxtext.common.gcloud_stub import is_decoupled, vertex_tensorboard_modules +from maxtext.common.goodput import ( + RECORD_JOB_END_TIME, + RECORD_JOB_START_TIME, + GoodputEvent, + create_goodput_recorder, + maybe_monitor_goodput, + maybe_record_goodput, + record_goodput, +) +from maxtext.common.metric_logger import MetricLogger, record_activation_metrics +from maxtext.configs import pyconfig +from maxtext.input_pipeline.input_pipeline_interface import create_data_iterator +from maxtext.layers.multi_token_prediction import calculate_mtp_acceptance_rate, calculate_mtp_loss +from maxtext.optimizers import optimizers +from maxtext.utils import exceptions, max_logging, max_utils, maxtext_utils, model_creation_utils, sharding +from maxtext.utils.globals import EPS +from maxtext.utils.gradient_accumulation import nnx_gradient_accumulation_loss_and_grad +from maxtext.utils.rampup_batch import create_rampup_manager + +_diag_modules = _cloud_diag() +diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = _diag_modules +VertexTensorboardManager, _vertex_tb_is_stub = vertex_tensorboard_modules() + + +# --------------------------------------------------------------------------- +# Loss computation for both train and eval +# --------------------------------------------------------------------------- + 


def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng: jax.Array, is_train: bool = True):
  """Compute cross-entropy loss for one batch using an NNX model.

  Args:
    model: The NNX Transformer (or compatible) model. Called in-place; no
      explicit params argument is needed because the NNX module carries state.
    config: MaxText Config object.
    data: Batch dict with keys "inputs", "inputs_position", "inputs_segmentation",
      "targets", "targets_segmentation" (plus "images"/"image_masks" when
      config.use_multimodal is set).
    dropout_rng: PRNG key used to seed dropout layers.
    is_train: True for train_step and False for eval_step.

  Returns:
    (loss, aux) where loss is a scalar and aux is a dict of auxiliary metrics.
  """
  # rng1, aqt_rng = jax.random.split(dropout_rng)

  # Trim to micro-batch size (handles per_device_batch_size < 1 cases)
  # decimate proportion of data when per_device_batch_size<1
  if is_train:
    batch = {k: v[: config.micro_batch_size_to_train_on, :] for k, v in data.items()}
  else:
    batch = {k: v[: config.micro_batch_size_to_eval_on, :] for k, v in data.items()}

  # Flax NNX model
  logits = model(
      decoder_input_tokens=batch["inputs"],
      decoder_positions=batch["inputs_position"],
      decoder_segment_ids=batch["inputs_segmentation"],
      encoder_images=batch["images"] if config.use_multimodal else None,
      encoder_image_masks=batch["image_masks"] if config.use_multimodal and "image_masks" in batch else None,
      enable_dropout=config.enable_dropout if is_train else False,
      decoder_target_tokens=batch["targets"],
      decoder_target_mask=batch["targets_segmentation"],
  )
  # NOTE(review): intermediate_outputs starts empty here — unlike the Linen
  # path, no sowed intermediates are collected from the model call. As a
  # result, the MTP / indexer / MoE-load-balance loss lookups below always see
  # an empty tree and fall back to their defaults. Confirm whether intermediate
  # capture is pending NNX support or should be wired up.
  intermediate_outputs = {}
  one_hot_targets = jax.nn.one_hot(batch["targets"], config.vocab_size)
  xent, z_loss = max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=config.z_loss_multiplier)

  xent = nn.with_logical_constraint(xent, ("activation_embed_and_logits_batch", "activation_length"))
  z_loss = nn.with_logical_constraint(z_loss, ("activation_embed_and_logits_batch", "activation_length"))

  # Mask out paddings at the end of each example.
  xent = xent * (batch["targets_segmentation"] != 0)
  z_loss = z_loss * (batch["targets_segmentation"] != 0)

  total_loss = jnp.sum(xent)
  total_z_loss = jnp.sum(z_loss)

  # Number of non-padding target tokens; the denominator for normalization.
  total_weights = jnp.sum(batch["targets_segmentation"] != 0)
  # If gradient accumulation is enabled, we don't need to divide total_loss
  # by total_weights and then multiply the computed gradient by total_weights,
  # since it's equivalent to computing the gradient from total_loss.
  # This simplification reduces the number of operations and makes it easier
  # for XLA to move all-reduce out of the gradient accumulation loop when use
  # Zero1+GA to reduce communication overhead.
  # EPS was used to avoid division by zero, but it's not needed when gradient
  # accumulation is enabled since there's no division.
  if config.gradient_accumulation_steps > 1 and not config.use_tunix_gradient_accumulation:
    loss = total_loss
  else:
    # When using Tunix gradient accumulation, we revert to standard normalization.
    # Unlike the manual accumulation path above, Tunix (via optax.MultiSteps) expects
    # a normalized loss for each step. It handles the accumulation state
    # updates and scaling internally.
    loss = total_loss / (total_weights + EPS)

  # We keep z-loss normalized by total_weights.
  total_z_loss = total_z_loss / (total_weights + EPS)

  # Calculate and Add MTP Loss
  mtp_loss = 0.0
  if config.mtp_num_layers > 0 and is_train:
    mtp_loss = calculate_mtp_loss(intermediate_outputs, config)
    loss += mtp_loss

  # get indexer loss
  indexer_loss = 0.0
  if config.use_indexer and config.indexer_loss_scaling_factor > 0.0:
    indexer_losses = []
    # Extract 'indexer_loss' from model intermediates.
    # We check for paths ending in ('self_attention', 'indexer_loss').
    # This handles varying paths caused by different layer names.
    for path, val in jax.tree_util.tree_leaves_with_path(intermediate_outputs):
      path_keys = tuple(k.key for k in path if hasattr(k, "key"))
      if path_keys[-2:] == ("self_attention", "indexer_loss"):
        indexer_losses.append(jnp.ravel(val))

    if indexer_losses:
      indexer_loss = jnp.mean(jnp.concatenate(indexer_losses))
      loss += indexer_loss
    else:
      max_logging.debug("No indexer loss found.")

  # get MoE load balance loss
  moe_lb_loss = 0.0
  if config.num_experts > 1:
    # Note: the key is affected by the model implementation
    possible_keys = [
        ("intermediates", "decoder", "layers", "moe_lb_loss"),
        ("intermediates", "decoder", "moe_layers", "moe_lb_loss"),
    ]

    total_moe_lb_loss = 0.0
    found_loss = False
    for nested_key in possible_keys:
      total_moe_lb_loss = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, 0.0)
      if total_moe_lb_loss != 0.0:
        found_loss = True
        break

    if not found_loss:
      max_logging.debug("\nNo MoE load balance loss found. Defaulting to 0.0.")

    moe_lb_loss = jnp.mean(jnp.array(total_moe_lb_loss))
    loss += moe_lb_loss

  # get MoE routed bias term updates
  moe_bias_updates = None
  if config.routed_bias and config.routed_bias_update_rate > 0.0:
    nested_key = ("intermediates", "decoder", "moe_layers", "moe_bias_updates")
    moe_bias_updates = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, None)

  # Add the model's primary output to the intermediates dict so it can be used
  # by the acceptance rate calculation in eval_step.
  intermediate_outputs["logits"] = logits

  aux = {
      "intermediate_outputs": intermediate_outputs,
      "total_loss": total_loss,
      "z_loss": total_z_loss,
      "total_weights": total_weights,
      "moe_lb_loss": moe_lb_loss,
      "indexer_loss": indexer_loss,
      "moe_bias_updates": moe_bias_updates,
      "mtp_loss": mtp_loss,
  }
  return loss, aux


# ---------------------------------------------------------------------------
# Train / eval steps (purely functional, JIT-able)
# ---------------------------------------------------------------------------


def train_step(
    model_graphdef: nnx.graph.NodeDef,
    opt_graphdef: nnx.graph.NodeDef,
    model_state: nnx.State,
    opt_state: nnx.State,
    data: dict[str, jax.Array],
    dropout_rng: jax.Array,
    config,
):
  """One training step: forward + backward + optimizer update.

  Args:
    model_graphdef: Static NNX graph definition for the model (JIT closure).
    opt_graphdef: Static NNX graph definition for the optimizer (JIT closure).
    model_state: Mutable model parameter pytree (donated).
    opt_state: Mutable optimizer state pytree (donated).
    data: Batch of token IDs and metadata.
    dropout_rng: PRNG key for dropout.
    config: MaxText Config.

  Returns:
    (new_model_state, new_opt_state): Updated pytrees.
    metrics: Dict of scalar training metrics.
  """
  # Rehydrate live objects from static graphdefs + donated state pytrees.
  model: nnx.Module = nnx.merge(model_graphdef, model_state)
  optimizer: nnx.Optimizer = nnx.merge(opt_graphdef, opt_state)
  if config.use_dpo:
    # Need impl on NNX
    pass
    # state, reference_params = _split_dpo_state(state)
    # state_mesh_shardings, reference_params_sharding = _split_dpo_state(state_mesh_shardings)
    # extra_dpo_args = [reference_params]
    # loss_fn = dpo_loss_fn

  # Compute loss and gradients w.r.t. model parameters.
  # nnx.value_and_grad differentiates only through nnx.Param variables,
  # keeping non-differentiable state (RNGs, cache, etc.) frozen.
  if config.gradient_accumulation_steps > 1:
    loss, aux, raw_grads = nnx_gradient_accumulation_loss_and_grad(loss_fn, model, config, data, dropout_rng)
  else:
    if config.optimizer_memory_host_offload:
      # Need impl on NNX
      pass
      # if config.use_dpo:
      #   reference_params = jax.device_put(
      #       reference_params,
      #       max_utils.with_memory_kind(reference_params_sharding, "device"),
      #   )
      #   extra_dpo_args = [reference_params]
    if config.shard_optimizer_over_data:
      # Need impl on NNX
      pass
      # params = jax.tree.map(
      #     functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode),
      #     params,
      #     params_shardings,
      # )
    grad_fn = nnx.value_and_grad(loss_fn, argnums=0, has_aux=True)
    (loss, aux), raw_grads = grad_fn(model, config, data, dropout_rng, is_train=True)

  # Cast gradients to configured dtype before clipping / accumulation
  raw_grads = jax.tree.map(
      lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x,
      raw_grads,
  )
  intermediate_outputs = aux["intermediate_outputs"]
  total_weights = aux["total_weights"]
  moe_lb_loss = aux["moe_lb_loss"]
  indexer_loss = aux["indexer_loss"]
  z_loss = aux["z_loss"]
  moe_bias_updates = aux["moe_bias_updates"]
  mtp_loss = aux["mtp_loss"]

  # Gradient clipping (implemented directly to avoid Linen TrainState dependency)
  # clip_by_global_norm is stateless, so re-initializing it every step is
  # harmless (init returns an empty state).
  if config.gradient_clipping_threshold > 0:
    clip_tx = optax.clip_by_global_norm(config.gradient_clipping_threshold)
    grads, _ = clip_tx.update(raw_grads, clip_tx.init(raw_grads), None)
  else:
    grads = raw_grads
  if config.optimizer_memory_host_offload:
    # Need impl on NNX
    pass
    # state = state.replace(
    #     opt_state=jax.device_put(
    #         state.opt_state,
    #         jax.tree_util.tree_map(
    #             lambda x: x.with_memory_kind(kind="device"),
    #             state_mesh_shardings.opt_state,
    #         ),
    #     )
    # )
  # Move all parameters to device before optimizer update
  if config.parameter_memory_host_offload:
    max_logging.log("\nMoving all parameters to device before optimizer update")
    # Need impl on NNX
    # def move(path, value):
    #   max_logging.log(f"train.py: Moving f{path} to device")
    #   return value.with_memory_kind(kind="device")

    # state = state.replace(
    #     params=jax.device_put(
    #         state.params,
    #         jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params),
    #     )
    # )

  # NNX 0.11+: update takes (model, grads) explicitly.
  optimizer.update(model, grads)

  # Re-split the mutated live objects back into pure state pytrees for return.
  new_model_state = nnx.state(model)
  new_opt_state = nnx.state(optimizer)

  # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family
  if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None:
    # Need impl on NNX
    pass
    # target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias")
    # Flax 'sow' returns a tuple, so we take the first element [0].
    # Updates the shape to be aligned with state.
    # moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose()
    # new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates)

  scalar_metrics = {
      "learning/loss": loss,
      "learning/z_loss": z_loss,
      "learning/moe_lb_loss": moe_lb_loss,
      "learning/indexer_loss": indexer_loss,
      "learning/mtp_loss": mtp_loss,
      "learning/total_weights": total_weights,
  }
  if config.use_qk_clip:
    # Apply QK-Clip
    # Need impl on NNX
    pass
    # new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config)

    # Report max_logits metric
    # global_max_logit = qk_clip_utils.calculate_max_logit_metric(intermediate_outputs)
    # if global_max_logit is not None:
    #   scalar_metrics["learning/max_logits"] = global_max_logit

  # Norm metrics are skipped under host offload (grads may live on host).
  if not config.optimizer_memory_host_offload:
    scalar_metrics["learning/grad_norm"] = max_utils.l2norm_pytree(grads)
    scalar_metrics["learning/raw_grad_norm"] = max_utils.l2norm_pytree(raw_grads)
    scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(nnx.state(model, nnx.Param))
  if config.use_dpo:
    scalar_metrics["learning/dpo_reward_accuracy"] = aux["reward_accuracy"]
  metrics = {
      "scalar": scalar_metrics,
      "scalars": {},
  }

  if config.record_internal_nn_metrics:
    record_activation_metrics(metrics, intermediate_outputs, config)

  if config.use_dpo:
    # Need impl on NNX
    pass
    # new_state = _merge_dpo_state(new_state, reference_params)
  return (new_model_state, new_opt_state), metrics


def eval_step(
    model_graphdef: nnx.graph.NodeDef,
    model_state: nnx.State,
    data: dict[str, jax.Array],
    dropout_rng: jax.Array,
    config,
):
  """One evaluation step: forward only, no gradient computation.

  Args:
    model_graphdef: Static NNX graph definition for the model.
    model_state: Current model parameter pytree (read-only).
    data: Batch of token IDs and metadata.
    dropout_rng: PRNG key (dropout disabled for eval, but kept for API symmetry).
    config: MaxText Config.

  Returns:
    metrics: Dict of scalar evaluation metrics.
  """
  model: nnx.Module = nnx.merge(model_graphdef, model_state)
  loss, aux = loss_fn(model, config, data, dropout_rng, is_train=False)

  mtp_acceptance_rate = 0.0
  if config.mtp_eval_target_module > 0:
    mtp_acceptance_rate = calculate_mtp_acceptance_rate(aux["intermediate_outputs"], config)

  total_loss = aux["total_loss"]
  z_loss = aux["z_loss"]
  total_weights = aux["total_weights"]
  moe_lb_loss = aux["moe_lb_loss"]
  indexer_loss = aux["indexer_loss"]
  mtp_loss = aux["mtp_loss"]
  metrics = {
      "scalar": {
          "evaluation/loss": loss,
          "evaluation/z_loss": z_loss,
          "evaluation/total_loss": total_loss,
          "evaluation/total_weights": total_weights,
          "evaluation/moe_lb_loss": moe_lb_loss,
          "evaluation/indexer_loss": indexer_loss,
          "evaluation/mtp_loss": mtp_loss,
          "evaluation/mtp_acceptance_rate_percent": mtp_acceptance_rate,
      },
  }
  if config.use_dpo:
    metrics["scalar"]["evaluation/dpo_reward_accuracy"] = aux["reward_accuracy"]

  return metrics


# ---------------------------------------------------------------------------
# Training-loop setup
# ---------------------------------------------------------------------------


def _create_and_shard_optimizer(model: nnx.Module, config, mesh: Mesh):
  """Creates an nnx.Optimizer and returns sharded model + optimizer states.

  In NNX 0.11+, the optimizer does not hold a model reference, so model and
  optimizer are kept as independent objects with separate graphdefs, state
  pytrees, and sharding specs throughout the training loop.

  Args:
    model: Sharded NNX model (already placed on devices).
    config: MaxText Config.
    mesh: JAX device mesh.

  Returns:
    model_graphdef: Static NNX graph definition for the model.
    opt_graphdef: Static NNX graph definition for the optimizer.
    model_state: Sharded model parameter pytree (donated to JIT steps).
    opt_state: Sharded optimizer state pytree (donated to JIT steps).
    model_shardings: Partition specs for model_state.
    opt_shardings: Partition specs for opt_state.
    learning_rate_schedule: Learning-rate schedule function.
  """
  learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
  tx = optimizers.get_optimizer(config, learning_rate_schedule, model)
  # NNX 0.11+: wrt is mandatory; optimizer does not store a model reference.
  optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

  # Derive separate partition specs for model and optimizer.
  # eval_shape gives abstract (shape-only) states so no arrays are copied here.
  model_graphdef, abstract_model_state = nnx.split(nnx.eval_shape(lambda: model))
  opt_graphdef, abstract_opt_state = nnx.split(nnx.eval_shape(lambda: optimizer))

  with nn.logical_axis_rules(config.logical_axis_rules):
    model_shardings = nn.logical_to_mesh_sharding(
        nnx.get_partition_spec(abstract_model_state), mesh, config.logical_axis_rules
    )
    opt_shardings = nn.logical_to_mesh_sharding(
        nnx.get_partition_spec(abstract_opt_state), mesh, config.logical_axis_rules
    )

  _, model_state = nnx.split(model)
  _, opt_state = nnx.split(optimizer)

  # Identity function jitted with out_shardings: forces device placement of
  # both state pytrees according to the derived partition specs.
  @functools.partial(jax.jit, out_shardings=(model_shardings, opt_shardings))
  def shard_states(mshard, oshard):
    return mshard, oshard

  with mesh:
    model_state, opt_state = shard_states(model_state, opt_state)

  return model_graphdef, opt_graphdef, model_state, opt_state, model_shardings, opt_shardings, learning_rate_schedule


def _get_first_step(opt_state: nnx.State) -> int:
  """Extracts the current step counter from the optimizer state."""
  # nnx.Optimizer stores step as an nnx.Variable; its value is a scalar.
  # NOTE(review): assumes flat_state() exposes an .items() mapping and that
  # int() accepts the step leaf directly — confirm against the installed
  # flax version. Matches on any key path containing "step".
  step_leaves = [v for k, v in opt_state.flat_state().items() if "step" in str(k)]
  if step_leaves:
    return int(step_leaves[0])
  return 0


def _build_jit_steps(
    config,
    model_graphdef: nnx.graph.NodeDef,
    opt_graphdef: nnx.graph.NodeDef,
    mesh: Mesh,
    model_shardings: Any,
    opt_shardings: Any,
    eval_data_iterator,
):
  """JIT-compiles the train and eval step functions with sharding annotations.

  Returns:
    p_train_step: JIT-compiled train step.
    p_eval_step: JIT-compiled eval step (None if no eval data).
  """
  data_sharding = sharding.get_input_data_sharding(config, mesh)

  # Partial application captures static graphdefs and config outside JIT.
  _train_fn = functools.partial(train_step, model_graphdef, opt_graphdef, config=config)
  _train_fn.__name__ = "nnx_train_step"

  p_train_step = jax.jit(
      _train_fn,
      in_shardings=(model_shardings, opt_shardings, data_sharding, None),
      out_shardings=((model_shardings, opt_shardings), None),
      donate_argnums=(0, 1),  # donate both model_state and opt_state buffers
  )

  p_eval_step = None
  if eval_data_iterator is not None:
    # Eval only needs the model; optimizer state is not required.
    _eval_fn = functools.partial(eval_step, model_graphdef, config=config)
    _eval_fn.__name__ = "nnx_eval_step"
    p_eval_step = jax.jit(
        _eval_fn,
        in_shardings=(model_shardings, data_sharding, None),
        out_shardings=None,
        donate_argnums=(),
    )

  return p_train_step, p_eval_step


def _wrap_state(state: nnx.State):
  """Wraps each leaf in {"value": ...} to match the NNX checkpoint format."""
  return jax.tree.map(lambda v: {"value": v}, state, is_leaf=lambda n: isinstance(n, nnx.Variable))


def _unwrap_state(raw):
  """Unwraps {"value": ...} leaves back to plain arrays."""
  return jax.tree.map(lambda v: v["value"], raw, is_leaf=lambda x: isinstance(x, dict) and "value" in x)


def _maybe_restore_checkpoint(checkpoint_manager, model_state: nnx.State, opt_state: nnx.State, config, data_iterator):
  """Restores model and optimizer states from an Orbax checkpoint if one exists.

  Checkpoint layout: {"model": , "optimizer": },
  with every leaf wrapped as {"value": }.

  Returns:
    (model_state, opt_state, data_iterator, start_step)
  """
  if checkpoint_manager is None:
    return model_state, opt_state, data_iterator, 0

  try:
    import orbax.checkpoint as ocp  # pylint: disable=import-outside-toplevel

    latest = checkpoint_manager.latest_step()
    if latest is None:
      max_logging.log("No existing checkpoint found; starting from scratch.")
      return model_state, opt_state, data_iterator, 0

    max_logging.log(f"Restoring NNX checkpoint from step {latest}.")
    ckptr = ocp.Checkpointer(
        ocp.PyTreeCheckpointHandler(
            restore_concurrent_gb=config.checkpoint_storage_concurrent_gb,
            save_concurrent_gb=config.checkpoint_storage_concurrent_gb,
            use_ocdbt=config.checkpoint_storage_use_ocdbt,
            use_zarr3=config.checkpoint_storage_use_zarr3,
        )
    )

    # Build an abstract target with the same {"value": ...} wire format used
    # at save time so Orbax can restore with matching structure/shardings.
    target = {"model": _wrap_state(model_state), "optimizer": _wrap_state(opt_state)}
    restore_args = ocp.checkpoint_utils.construct_restore_args(target)
    checkpoint_dir = checkpoint_manager.directory / str(latest)
    restored_raw = ckptr.restore(checkpoint_dir, item=target, restore_args=restore_args)

    restored_model_state = _unwrap_state(restored_raw["model"])
    restored_opt_state = _unwrap_state(restored_raw["optimizer"])
    return restored_model_state, restored_opt_state, data_iterator, int(latest)

  except Exception as e:  # pylint: disable=broad-exception-caught
    # NOTE(review): deliberately best-effort — any restore failure falls back
    # to a fresh start rather than crashing the job. Confirm this is desired
    # for production runs, where silently ignoring a corrupt checkpoint may
    # mask data loss.
    max_logging.log(f"Checkpoint restore failed ({e}); starting from scratch.")
    return model_state, opt_state, data_iterator, 0


def _maybe_save_checkpoint(
    checkpoint_manager, model_state: nnx.State, opt_state: nnx.State, config, data_iterator, step: int
):
  """Saves model and optimizer states to an Orbax checkpoint."""
  if checkpoint_manager is None:
    return
  state_to_save = {"model": _wrap_state(model_state), "optimizer": _wrap_state(opt_state)}
  checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step)


# ---------------------------------------------------------------------------
# Main training loop
# ---------------------------------------------------------------------------


def train_loop(config, recorder, state=None):
  """NNX pre-training loop.

  Args:
    config: MaxText Config.
    recorder: Goodput recorder (may be None).
    state: Unused; present for API symmetry with train.py.

  Returns:
    Final optimizer state pytree.
  """
  # ---- Model ----------------------------------------------------------------
  with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT):
    with nn.logical_axis_rules(config.logical_axis_rules):
      model, mesh = model_creation_utils.create_nnx_model(config)

  # ---- Optimizer + sharding -------------------------------------------------
  with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION):
    model_graphdef, opt_graphdef, model_state, opt_state, model_shardings, opt_shardings, learning_rate_schedule = (
        _create_and_shard_optimizer(model, config, mesh)
    )

  # ---- Data ---------------------------------------------------------------
  with jax.set_mesh(mesh):
    data_iterator, eval_data_iterator = create_data_iterator(config, mesh)
    rampup_manager = create_rampup_manager(config, checkpoint_manager=None)
    data_loader = create_dataloader(config, mesh, data_iterator, recorder, rampup_manager)

  # ---- Checkpointing -------------------------------------------------------
  logger = checkpointing.setup_checkpoint_logger(config)
  checkpoint_dir = config.checkpoint_dir if config.enable_checkpointing else ""
  checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
      checkpoint_dir,
      config.enable_checkpointing,
      config.async_checkpointing,
      config.checkpoint_period,
      config.dataset_type,
      logger,
      config.checkpoint_storage_use_ocdbt,
      config.checkpoint_storage_use_zarr3,
      config.enable_continuous_checkpointing,
      config.max_num_checkpoints_to_keep,
      config.checkpoint_storage_concurrent_gb,
      config.enable_single_controller,
      config.colocated_python_checkpointing,
      config.enable_single_replica_ckpt_restoring,
  )

  model_state, opt_state, data_iterator, start_step = _maybe_restore_checkpoint(
      checkpoint_manager, model_state, opt_state, config, data_iterator
  )

  # ---- JIT-compile steps ----------------------------------------------------
  with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
    p_train_step, p_eval_step = _build_jit_steps(
        config, model_graphdef, opt_graphdef, mesh, model_shardings, opt_shardings, eval_data_iterator
    )

  # Trigger AOT compilation and print memory stats
  with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
    shaped_batch = maxtext_utils.get_shaped_batch(config)
    init_rng = jax.random.PRNGKey(config.init_weights_seed)
    example_rng = jax.jit(jax.random.fold_in)(init_rng, 0)
    # Need imple below func on NNX
    # maxtext_utils.maybe_dump_jaxpr(config, p_train_step, (model_state, opt_state, shaped_batch, example_rng))
    if config.compiled_trainstep_file == "":  # compile only when there is no pre-compiled file loaded
      compiled = p_train_step.lower(model_state, opt_state, shaped_batch, example_rng).compile()
      compiled_stats = compiled.memory_analysis()
      max_utils.print_compiled_memory_stats(compiled_stats)

  # ---- Profiler / logger ----------------------------------------------------
  prof = profiler.Profiler(config, offset_step=start_step)
  metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

  # Write train config params, num model params, and XLA flags to tensorboard
  metric_logger.write_setup_info_to_tensorboard(model_state)

  # ---- Main loop ------------------------------------------------------------
  _job_completed_gracefully = False
  try:
    last_step_completion = datetime.datetime.now()
    max_logging.info(f"Entering train loop from start_step={start_step}")

    for step in np.arange(start_step, config.steps):
      prof.maybe_activate_profiler(step, opt_state)

      with jax.profiler.StepTraceAnnotation("train", step_num=step):
        example_batch = data_loader.load_next_batch(rampup_manager=rampup_manager)
        # Deterministic per-step RNG derived from the init seed.
        nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
        with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
          with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
            (model_state, opt_state), metrics = p_train_step(model_state, opt_state, example_batch, nextrng)

      step_time_delta = datetime.datetime.now() - last_step_completion
      last_step_completion = datetime.datetime.now()

      _maybe_save_checkpoint(checkpoint_manager, model_state, opt_state, config, data_iterator, step)

      # ---- Optional eval -------------------------------------------------------
      if (
          p_eval_step is not None
          and config.eval_interval > 0
          and step > start_step
          and (step + 1) % config.eval_interval == 0
      ):
        assert eval_data_iterator
        # Explicitly reset the eval iterator and counters before starting the eval loop
        eval_data_iterator.reset()
        metric_logger.reset_eval_metrics()

        eval_step_count = 0
        for eval_batch in eval_data_iterator:
          if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
            break
          with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
            # NOTE(review): nextrng from the last train step is reused for every
            # eval batch; dropout is disabled in eval so this is benign, but
            # confirm it matches train.py's convention.
            eval_metrics = p_eval_step(model_state, eval_batch, nextrng)
          metric_logger.record_eval_metrics(step, metrics=eval_metrics)
          max_logging.log(f"Completed eval step {eval_step_count}")
          eval_step_count += 1

        metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count)
        if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss:
          prof.deactivate()
          raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} achieved.")

      prof.maybe_deactivate_profiler(step, opt_state)

      if step == start_step:
        max_utils.print_mem_stats("After first step")

      metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta)

    # Final checkpoint on loop completion
    if config.save_checkpoint_on_completion:
      _maybe_save_checkpoint(
          checkpoint_manager, model_state, opt_state, config, data_iterator, step=int(config.steps - 1)
      )
      if checkpoint_manager is not None:
        # in case the last checkpoint_period checkpoint is still in progress
        checkpoint_manager.wait_until_finished()

    _job_completed_gracefully = True

  except exceptions.StopTraining as e:
    max_logging.log(f"Training stopped: {str(e)}")
    _job_completed_gracefully = True

  finally:
    if _job_completed_gracefully:
      record_goodput(recorder, RECORD_JOB_END_TIME)
    metric_logger.flush_metrics_and_cleanup()

  return opt_state


# ---------------------------------------------------------------------------
# Entry-point helpers
# ---------------------------------------------------------------------------


def initialize(argv: Sequence[str]):
  """Initialise hyperparameters and utility objects."""
  jax.config.update("jax_default_prng_impl", "unsafe_rbg")

  import tensorflow as tf  # pylint: disable=import-outside-toplevel

  tf.config.set_visible_devices([], "GPU")

  if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""):
    os.environ["LIBTPU_INIT_ARGS"] = (
        os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
    )

  config = pyconfig.initialize(argv)
  max_utils.print_system_information()

  if not config.enable_nnx:
    # NOTE(review): the message says "Forcing it on" but config is not
    # actually modified here — confirm whether pyconfig should be updated or
    # the wording changed.
    max_logging.log("WARNING: nnx_train.py requires enable_nnx=True. Forcing it on.")

  if config.shard_mode == ShardMode.EXPLICIT:
    jax.config.update("jax_remove_size_one_mesh_axis_from_type", True)

  os.environ["TFDS_DATA_DIR"] = config.dataset_path or ""

  vertex_tensorboard_manager = VertexTensorboardManager()
  if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"):
    vertex_tensorboard_manager.configure_vertex_tensorboard(config)

  # Create the Goodput recorder
  recorder = create_goodput_recorder(config)

  # Stack traces configurations
  debug_config = debug_configuration.DebugConfig(
      stack_trace_config=stack_trace_configuration.StackTraceConfig(
          collect_stack_trace=config.collect_stack_trace,
          stack_trace_to_cloud=config.stack_trace_to_cloud,
          stack_trace_interval_seconds=config.stack_trace_interval_seconds,
      )
  )
  diagnostic_config = diagnostic_configuration.DiagnosticConfig(debug_config)
  return config, recorder, diagnostic_config


def run(config, recorder, diagnostic_config):
  """Run the NNX training job.

  In decoupled mode (DECOUPLE_GCLOUD=TRUE) cloud diagnostics may be stubbed; if so, skip wrapping.
  """
  # Use nullcontext when diagnostics are stubbed or in decoupled mode
  diagnostics_context = (
      contextlib.nullcontext()
      if is_decoupled() or getattr(diagnostic, "__class__", None).__name__ == "_StubDiag"
      else diagnostic.diagnose(diagnostic_config)
  )

  if is_decoupled() or getattr(diagnostic, "__class__", None).__name__ == "_StubDiag":
    max_logging.log("[DECOUPLED NO-OP] skipping cloud diagnostics wrapper.")

  with (
      diagnostics_context,
      max_utils.maybe_get_transformer_engine_context(config),
  ):
    train_loop(config, recorder)


def main(argv: Sequence[str]) -> None:
  """Absl entry point: initialize, record job start, and run under Goodput monitoring."""
  config, recorder, diagnostic_config = initialize(argv)
  record_goodput(recorder, RECORD_JOB_START_TIME)
  with maybe_monitor_goodput(config):
    run(config, recorder, diagnostic_config)


if __name__ == "__main__":
  app.run(main)
diff --git a/src/maxtext/trainers/pre_train/nnx_train_compile.py b/src/maxtext/trainers/pre_train/nnx_train_compile.py
new file mode 100644
index 0000000000..d02a945c90
--- /dev/null
+++ b/src/maxtext/trainers/pre_train/nnx_train_compile.py
@@ -0,0 +1,265 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Save a Cross Ahead of Time Compiled (XAOT) version of nnx_train.py's train step.

Mirrors train_compile.py but uses the Flax NNX API throughout, in contrast to
train_compile.py which relies on Linen's TrainState.
+ +Key differences from train_compile.py +-------------------------------------- +- No Linen TrainState. State lives in two separate pytrees: + model_state – nnx.State for the model parameters + opt_state – nnx.State for the optimizer (optax state + step counter) +- nnx.eval_shape creates abstract shapes without materialising parameters, so the + whole compilation is done without ever touching real hardware memory. +- Graphdefs (model_graphdef, opt_graphdef) are baked into the partial and are + Python-static across the JIT boundary; they are therefore not listed in + static_argnums. +- in_shardings / out_shardings follow the NNX train_step signature: + in: (model_state, opt_state, batch, rng) + out: ((model_state, opt_state), metrics) + +Entry point: + python -m maxtext.trainers.pre_train.nnx_train_compile [overrides…] +""" + +import functools +import os +from typing import Callable, Sequence + +import jax +from absl import app +from flax import linen as nn +from flax import nnx +from flax.linen import partitioning as nn_partitioning + +from maxtext.common.common_types import MODEL_MODE_TRAIN +from maxtext.configs import pyconfig +from maxtext.optimizers import optimizers +from maxtext.trainers.pre_train import nnx_train +from maxtext.trainers.pre_train.train_compile import get_topology_mesh, jit_and_compile, save_compiled, validate_config +from maxtext.utils import gcs_utils, max_utils, maxtext_utils, model_creation_utils, sharding + + +def create_nnx_rngs( + config: pyconfig.HyperParameters, is_training: bool = True, rng_key: jax.Array | None = None +) -> nnx.Rngs: + """ + Create NNX Rngs + + Args: + config: the configuration + is_training: if the Rngs are for training + rng_key: the Rng key + + Returns: + The NNX Rngs + """ + if rng_key is None: + rng_key = jax.random.PRNGKey(config.init_weights_seed) + + if is_training: + return nnx.Rngs( + params=jax.random.fold_in(rng_key, 0), dropout=jax.random.fold_in(rng_key, 1), aqt=jax.random.fold_in(rng_key, 2) + ) + 
return nnx.Rngs(params=rng_key) # disable dropout RNG and aqt for inference + + +# --------------------------------------------------------------------------- +# Shaped inputs (NNX version) +# --------------------------------------------------------------------------- + + +def get_shaped_inputs_nnx(topology_mesh, config): + """Build abstract (shape-only) versions of nnx_train.train_step's inputs. + + Uses nnx.eval_shape to trace through model and optimizer construction so that + no actual parameters are allocated. The returned abstract states have + ShapeDtypeStruct leaves and can be passed directly to jax.jit.lower(). + + Returns: + model_graphdef: Static NNX graph definition for the model. + opt_graphdef: Static NNX graph definition for the optimizer. + abstract_model_state: Abstract model parameter pytree. + abstract_opt_state: Abstract optimizer state pytree. + model_shardings: Partition specs mapped to mesh shardings for model_state. + opt_shardings: Partition specs mapped to mesh shardings for opt_state. + data_sharding: Input-batch sharding. + shaped_batch: Shaped batch dict (ShapeDtypeStruct leaves). + shaped_rng: Shaped RNG key. + learning_rate_schedule: LR schedule (baked into the compiled object). + """ + # rng_key = jax.random.PRNGKey(config.init_weights_seed) + # rngs = nnx.Rngs(params=rng_key, dropout=1) + + # ------------------------------------------------------------------ + # 1. Abstract model via nnx.eval_shape — no parameters materialised. 
+ # ------------------------------------------------------------------ + + def get_nnx_create_model_fn(config, mesh=None, devices=None) -> Callable: + """Creates the function for NNX model creation.""" + + def _create_model(): + # is_training = model_mode == MODEL_MODE_TRAIN + # rngs = maxtext_utils_nnx.create_nnx_rngs(config, is_training=is_training, rng_key=rng_key) + rng_key = jax.random.PRNGKey(config.init_weights_seed) + rngs = create_nnx_rngs(config, True, rng_key) + return model_creation_utils.from_config(config, devices, mesh, rngs=rngs, model_mode=MODEL_MODE_TRAIN) + + return _create_model + + with nn.logical_axis_rules(config.logical_axis_rules): + create_model_fn = get_nnx_create_model_fn(config, topology_mesh) + abstract_model = nnx.eval_shape(create_model_fn) + model_graphdef, abstract_model_state = nnx.split(abstract_model) + + # ------------------------------------------------------------------ + # 2. Abstract optimizer via nnx.eval_shape. + # ------------------------------------------------------------------ + learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) + # get_optimizer may inspect the model structure (e.g. for Muon); the abstract + # model has the same tree structure as the real one, so this is safe. + tx = optimizers.get_optimizer(config, learning_rate_schedule, abstract_model) + + def _build_optimizer(): + return nnx.Optimizer(abstract_model, tx, wrt=nnx.Param) + + abstract_optimizer = nnx.eval_shape(_build_optimizer) + opt_graphdef, abstract_opt_state = nnx.split(abstract_optimizer) + + # ------------------------------------------------------------------ + # 3. Partition specs → mesh shardings. 
+ # ------------------------------------------------------------------ + with nn.logical_axis_rules(config.logical_axis_rules): + model_shardings = nn.logical_to_mesh_sharding( + nnx.get_partition_spec(abstract_model_state), topology_mesh, config.logical_axis_rules + ) + opt_shardings = nn.logical_to_mesh_sharding( + nnx.get_partition_spec(abstract_opt_state), topology_mesh, config.logical_axis_rules + ) + + # ------------------------------------------------------------------ + # 4. Shaped batch and RNG. + # ------------------------------------------------------------------ + data_sharding = sharding.get_input_data_sharding(config, topology_mesh) + shaped_batch = maxtext_utils.get_shaped_batch(config) + + _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2) + shaped_rng = jax.ShapeDtypeStruct(example_rng.shape, example_rng.dtype) + + return ( + model_graphdef, + opt_graphdef, + abstract_model_state, + abstract_opt_state, + model_shardings, + opt_shardings, + data_sharding, + shaped_batch, + shaped_rng, + learning_rate_schedule, + ) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(argv: Sequence[str]) -> None: + jax.config.update("jax_default_prng_impl", "unsafe_rbg") + os.environ["LIBTPU_INIT_ARGS"] = ( + os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" + ) + print("Starting nnx_train_compile.py...", flush=True) + + # Parse and validate configuration + config = pyconfig.initialize(argv) + validate_config(config) + + # Create target mesh + topology_mesh = get_topology_mesh(config) + + # Print system information after building the compile topology to avoid + # prematurely initialising the backend. 
+ max_utils.print_system_information() + + # Get shaped inputs + ( + model_graphdef, + opt_graphdef, + abstract_model_state, + abstract_opt_state, + model_shardings, + opt_shardings, + data_sharding, + shaped_batch, + shaped_rng, + _, # _learning_rate_schedule, + ) = get_shaped_inputs_nnx(topology_mesh, config) + + # Build the partial that matches what _build_jit_steps produces in nnx_train. + # graphdefs are static (captured in the Python closure) so they do not appear + # in static_argnums. + func_to_compile = functools.partial(nnx_train.train_step, model_graphdef, opt_graphdef, config=config) + func_to_compile.__name__ = "nnx_train_step" + + shaped_train_args = (abstract_model_state, abstract_opt_state, shaped_batch, shaped_rng) + shaped_train_kwargs = {} + + in_shard = (model_shardings, opt_shardings, data_sharding, None) + out_shard = ((model_shardings, opt_shardings), None) + static_argnums = () + donate_argnums = (0, 1) + + # Compile + print("Jitting and compiling NNX train step...", flush=True) + compiled = jit_and_compile( + func_to_compile, + shaped_train_args, + shaped_train_kwargs, + topology_mesh, + in_shard, + out_shard, + static_argnums, + donate_argnums, + config, + nn_partitioning.axis_rules(config.logical_axis_rules), + ) + print("Jitting and compilation complete!", flush=True) + + # Serialize and save the compiled object + if config.compiled_trainstep_file != "": + print("Saving compiled object...") + save_compiled(compiled, config.compiled_trainstep_file) + print(f"Successfully saved compiled object as {config.compiled_trainstep_file}") + print("Finished nnx_train_compile.py successfully!", flush=True) + print(f"Cost analysis: {compiled.cost_analysis()}") + print(f"Memory analysis: {compiled.memory_analysis()}") + + # Dump HLO if requested + if config.dump_hlo: + gcs_utils.upload_dump( + config.dump_hlo_local_dir, + config.dump_hlo_gcs_dir, + module_name=config.dump_hlo_module_name, + delete_local_after=config.dump_hlo_delete_local_after, + 
all_host_upload=config.dump_hlo_upload_all,
+    )
+
+
+if __name__ == "__main__":
+  app.run(main)
diff --git a/src/maxtext/utils/gradient_accumulation.py b/src/maxtext/utils/gradient_accumulation.py
index e4cad14906..5c68cbe27c 100644
--- a/src/maxtext/utils/gradient_accumulation.py
+++ b/src/maxtext/utils/gradient_accumulation.py
@@ -16,6 +16,7 @@
 
 import jax
 import jax.numpy as jnp
+from flax import nnx
 from jax.sharding import NamedSharding
 
 from maxtext.common.common_types import ShardMode
@@ -137,6 +138,127 @@ def reshape_to_microbatch_accumulations(batch_arr):
 
   return loss, aux, raw_grads
 
 
+# ---------------------------------------------------------------------------
+# Gradient accumulation helper for NNX
+# ---------------------------------------------------------------------------
+
+
+def nnx_gradient_accumulation_loss_and_grad(_loss_fn, model, config, data, dropout_rng):
+  """
+  Calculates gradients using gradient accumulation.
+
+  This function computes the gradient of `_loss_fn` over multiple microbatches
+  and accumulates them before returning a single, averaged gradient. It uses
+  `jax.lax.scan` for efficient accumulation on device.
+
+  It also supports a `shard_optimizer_over_data` mode (e.g., ZeRO-1) where
+  parameters are cast to bf16 and sharded *before* the accumulation loop
+  to perform the all-gather in lower precision.
+
+  Args:
+    _loss_fn: The loss function to differentiate. Its signature is expected
+      to be: `(model, config, data, dropout_rng, is_train=True)`.
+    model: The NNX model module to differentiate with respect to.
+    config: Model and training configuration object. Must contain
+      `gradient_accumulation_steps` and `shard_optimizer_over_data`.
+    data: A PyTree of batched data. The leading dimension is assumed
+      to be the total batch size (microbatch_size * num_accumulations).
+    dropout_rng: JAX PRNGKey for dropout.
+ + Returns: + A tuple containing: + - total_loss (Array): The mean loss, averaged over all microbatches. + - final_aux (PyTree): Auxiliary outputs, summed across microbatches. + - raw_grads (PyTree): The accumulated and averaged gradients. + """ + + # For more efficient DP/ZeRO-1 + GA + # if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1: + # ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings) + # grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings) + # else: + # ga_params_shardings = grad_shardings = params_shardings + + graphdef, params, rest = nnx.split(model, nnx.Param, ...) + + # When using Zero-1 optimizer sharding, cast params to lower precision and apply sharding constraints + # so that all-gather is done once in the lower precision before the gradient accumulation loop + if config.shard_optimizer_over_data: + + def convert_to_bf16(param): + if param.dtype == jnp.float32: + return param.astype(jnp.bfloat16) + return param + + ga_params = jax.tree.map(convert_to_bf16, params) + else: + ga_params = params + + # ga_params = jax.tree.map(_maybe_shard_with_name, ga_params, ga_params_shardings) + grad_func = nnx.value_and_grad(_loss_fn, argnums=0, has_aux=True) + + def accumulate_gradient(acc_grad_and_loss, data): + ga_params = acc_grad_and_loss["ga_params"] + # Reconstruct the model using the fixed parameters (ga_params) + # and the advancing non-parameter state (RNGs) from the carry. + + # as ga_params will change during train_step, always create a local_model + local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"]) + (_, aux), cur_batch_gradient = grad_func(local_model, config, data, dropout_rng, is_train=True) + _, _, next_rest_state = nnx.split(local_model, nnx.Param, ...) 
+ + acc_grad_and_loss["rest_state"] = next_rest_state + acc_grad_and_loss["loss"] += aux["total_loss"] + acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"] + acc_grad_and_loss["mtp_loss"] += aux["mtp_loss"] + acc_grad_and_loss["grad"] = jax.tree.map(lambda x, y: x + y, cur_batch_gradient, acc_grad_and_loss["grad"]) + acc_grad_and_loss["total_weights"] += aux["total_weights"] + return acc_grad_and_loss, aux + + def reshape_to_microbatch_accumulations(batch_arr): + """Reshape [B, ...] → [num_microbatches, B//num_microbatches, ...].""" + num_microbatches = config.gradient_accumulation_steps + microbatch_shape = (num_microbatches, batch_arr.shape[0] // num_microbatches) + batch_arr.shape[1:] + return jnp.reshape(batch_arr, microbatch_shape) + + # def reshape_to_microbatch_accumulations(batch_arr): + # """Reshape global batch to microbatches, assuming batch axis is leading.""" + # num_microbatches = config.gradient_accumulation_steps + # microbatch_shape = (batch_arr.shape[0] // num_microbatches, num_microbatches) + batch_arr.shape[1:] + # reshaped_batch_arr = jnp.reshape(batch_arr, microbatch_shape) + # return jnp.swapaxes(reshaped_batch_arr, 0, 1) + + data = jax.tree.map(reshape_to_microbatch_accumulations, data) + init_grad = jax.tree.map(jnp.zeros_like, ga_params) + # init_grad = jax.tree.map(_maybe_shard_with_name, init_grad, grad_shardings) + init_grad_and_loss = { + "loss": 0.0, + "grad": init_grad, + "total_weights": 0, + "moe_lb_loss": 0.0, + "mtp_loss": 0.0, + "ga_params": ga_params, + } + init_grad_and_loss["rest_state"] = rest + + grad_and_loss, aux = jax.lax.scan( + accumulate_gradient, init_grad_and_loss, data, length=config.gradient_accumulation_steps + ) + loss = ( + grad_and_loss["loss"] / grad_and_loss["total_weights"] + + grad_and_loss["moe_lb_loss"] / config.gradient_accumulation_steps + + grad_and_loss["mtp_loss"] / config.gradient_accumulation_steps + ) + raw_grads = grad_and_loss["grad"] + raw_grads = jax.tree.map(lambda arr: arr / 
grad_and_loss["total_weights"], raw_grads) + aux = jax.tree.map(lambda x: jnp.sum(x, axis=0), aux) # pytype: disable=module-attr + + nnx.update(model, grad_and_loss["rest_state"]) + + return loss, aux, raw_grads + + # GA helper functions def update_sharding_for_reduced(sharding: NamedSharding) -> NamedSharding: """ diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 4449aa04f2..eab927b59c 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -20,6 +20,7 @@ import os from flax import linen as nn +from flax import nnx from flax.linen import partitioning as nn_partitioning from flax.training import train_state @@ -1030,7 +1031,7 @@ def init_initial_state(model, tx, config, is_training, key): return init_decode_state(model.apply, model_vars) -def get_abstract_param(model, config): +def get_abstract_param(model: nn.Module | nnx.Module, config): """Get abstract model structure (name, shape) without materializing the weights to save memory""" with model.mesh, nn_partitioning.axis_rules(config.logical_axis_rules): key = jax.random.PRNGKey(0) @@ -1039,14 +1040,17 @@ def get_abstract_param(model, config): config.model_name, batch_size=config.micro_batch_size_to_train_on ) audio_shape = mm_processor.get_dummy_audio_shape_for_init(config) - abstract_vars = jax.eval_shape( - model.init, - {"params": key, "dropout": key, "aqt": key}, - jnp.ones(input_shape, dtype=jnp.int32), - jnp.ones(input_shape, dtype=jnp.int32), - encoder_images=np.ones(image_shape, dtype=jnp.int32) if config.use_multimodal else None, - encoder_audios=np.ones(audio_shape, dtype=jnp.float32) if config.use_audio else None, - ) + if isinstance(model, nn.Module): + abstract_vars = jax.eval_shape( + model.init, + {"params": key, "dropout": key, "aqt": key}, + jnp.ones(input_shape, dtype=jnp.int32), + jnp.ones(input_shape, dtype=jnp.int32), + encoder_images=np.ones(image_shape, dtype=jnp.int32) if config.use_multimodal else None, + 
encoder_audios=np.ones(audio_shape, dtype=jnp.float32) if config.use_audio else None, + ) + else: # nnx.Module + _, abstract_vars = nnx.split(nnx.eval_shape(lambda: model)) return abstract_vars diff --git a/tests/unit/nnx_train_compile_test.py b/tests/unit/nnx_train_compile_test.py new file mode 100644 index 0000000000..f1b5cdfe8e --- /dev/null +++ b/tests/unit/nnx_train_compile_test.py @@ -0,0 +1,355 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the Ahead-of-Time (AOT) compilation script using the NNX API. + +This module contains unit tests for `nnx_train_compile.py`, ensuring that +various model configurations and parallelism strategies can be successfully +compiled for different hardware topologies using the Flax NNX API. 
+""" + +import os.path +import unittest +from tempfile import gettempdir + +import pytest + +from maxtext.trainers.pre_train.nnx_train_compile import main as nnx_train_compile_main +from tests.utils.test_helpers import get_test_config_path + + +@pytest.mark.tpu_backend +class NnxTrainCompile(unittest.TestCase): + """Tests for the Ahead of Time Compilation functionality, nnx_train_compile.py""" + + @pytest.mark.cpu_only + def test_save_compiled_v4(self): + temp_dir = gettempdir() + compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_compiled_v4.pickle") + nnx_train_compile_main( + ( + "", + get_test_config_path(), + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v4-8", + "compile_topology_num_slices=1", + "base_emb_dim=256", + "base_mlp_dim=256", + "base_num_decoder_layers=2", + "enable_nnx=True", + "pure_nnx_decoder=True", + ) + ) + + @pytest.mark.cpu_only + def test_save_compiled_v5e(self): + temp_dir = gettempdir() + compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_compiled_v5e.pickle") + nnx_train_compile_main( + ( + "", + get_test_config_path(), + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-16", + "compile_topology_num_slices=1", + "base_emb_dim=256", + "base_mlp_dim=256", + "base_num_decoder_layers=2", + "enable_nnx=True", + "pure_nnx_decoder=True", + ) + ) + + @pytest.mark.cpu_only + def test_save_compiled_v5p_two_slices(self): + temp_dir = gettempdir() + compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_compiled_v5p_two_slices.pickle") + nnx_train_compile_main( + ( + "", + get_test_config_path(), + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5p-8", + "compile_topology_num_slices=2", + "base_emb_dim=256", + "base_mlp_dim=256", + "base_num_decoder_layers=2", + "enable_nnx=True", + "pure_nnx_decoder=True", + ) + ) + + @pytest.mark.cpu_only + def test_save_compiled_v6e(self): + temp_dir = gettempdir() + compiled_trainstep_file = 
os.path.join(temp_dir, "nnx_test_compiled_v6e.pickle")
+    nnx_train_compile_main(
+        (
+            "",
+            get_test_config_path(),
+            f"compiled_trainstep_file={compiled_trainstep_file}",
+            "compile_topology=v6e-16",
+            "compile_topology_num_slices=1",
+            "base_emb_dim=256",
+            "base_mlp_dim=256",
+            "base_num_decoder_layers=2",
+            "enable_nnx=True",
+            "pure_nnx_decoder=True",
+        )
+    )
+
+  @pytest.mark.cpu_only
+  def test_save_compiled_tpu7x(self):
+    temp_dir = gettempdir()
+    compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_compiled_tpu7x.pickle")
+    nnx_train_compile_main(
+        (
+            # Program-name slot must be a string like every other test here;
+            # argv parsing expects str entries, not None.
+            "",
+            get_test_config_path(),
+            f"compiled_trainstep_file={compiled_trainstep_file}",
+            "compile_topology=tpu7x-16",
+            "compile_topology_num_slices=1",
+            "ici_fsdp_parallelism=16",
+            "base_emb_dim=256",
+            "base_mlp_dim=256",
+            "base_num_decoder_layers=2",
+            "enable_nnx=True",
+            "pure_nnx_decoder=True",
+        )
+    )
+
+  @pytest.mark.cpu_only
+  def test_save_compiled_tpu7x_two_slices(self):
+    temp_dir = gettempdir()
+    compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_compiled_tpu7x_two_slices.pickle")
+    nnx_train_compile_main(
+        (
+            # Same fix as above: use "" (not None) for the program-name slot.
+            "",
+            get_test_config_path(),
+            f"compiled_trainstep_file={compiled_trainstep_file}",
+            "compile_topology=tpu7x-8",
+            "compile_topology_num_slices=2",
+            "ici_fsdp_parallelism=4",
+            "ici_tensor_parallelism=2",
+            "dcn_data_parallelism=2",
+            "base_emb_dim=256",
+            "base_mlp_dim=256",
+            "base_num_decoder_layers=2",
+            "enable_nnx=True",
+            "pure_nnx_decoder=True",
+        )
+    )
+
+  @pytest.mark.cpu_only
+  def test_sequence_parallelism(self):
+    temp_dir = gettempdir()
+    compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_compiled_sequence_parallelism.pickle")
+    nnx_train_compile_main(
+        (
+            "",
+            get_test_config_path(),
+            f"compiled_trainstep_file={compiled_trainstep_file}",
+            "compile_topology=v5p-64",
+            "use_iota_embed=true",
+            "compile_topology_num_slices=1",
+            "ici_sequence_parallelism=16",
+            "global_parameter_scale=32",
+            "per_device_batch_size=0.0625",
"max_target_length=65536", + "enable_nnx=True", + "pure_nnx_decoder=True", + ) + ) + + @pytest.mark.cpu_only + def test_remat_full(self): + temp_dir = gettempdir() + compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_remat_full.pickle") + nnx_train_compile_main( + ( + "", + get_test_config_path(), + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v6e-256", + "compile_topology_num_slices=1", + "per_device_batch_size=1", + "ici_fsdp_parallelism=16", + "ici_tensor_parallelism=16", + "max_target_length=1024", + "fused_qkv=true", + "fused_mlp=true", + "remat_policy=full", + "use_iota_embed=true", + "global_parameter_scale=128", + "enable_nnx=True", + "pure_nnx_decoder=True", + ) + ) + + @pytest.mark.cpu_only + def test_save_flash(self): + compiled_trainstep_file = "/tmp/nnx_test_save_flash" + nnx_train_compile_main( + ( + "", + get_test_config_path(), + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5p-256", + "compile_topology_num_slices=1", + "per_device_batch_size=1", + "remat_policy=custom", + "context=device", + "enable_nnx=True", + "pure_nnx_decoder=True", + ) + ) + + @pytest.mark.cpu_only + def test_gpt3_6b(self): + compiled_trainstep_file = "/tmp/nnx_test_gpt3_6b" + nnx_train_compile_main( + ( + "", + get_test_config_path(), + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5p-256", + "compile_topology_num_slices=1", + "model_name=gpt3-6b", + "per_device_batch_size=1", + "enable_nnx=True", + "pure_nnx_decoder=True", + ) + ) + + @pytest.mark.cpu_only + def test_moe_dropping_bf16(self): + temp_dir = gettempdir() + compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_moe_dropping_bf16.pickle") + nnx_train_compile_main( + ( + "", + get_test_config_path(), + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5p-64", + "use_iota_embed=true", + "compile_topology_num_slices=1", + "model_name=mixtral-8x7b", + "sparse_matmul=False", + "capacity_factor=1", 
+ "per_device_batch_size=4", + "max_target_length=1024", + "attention=flash", + "dtype=bfloat16", + "enable_nnx=True", + "pure_nnx_decoder=True", + ) + ) + + @pytest.mark.cpu_only + def test_moe_megablox_bf16(self): + temp_dir = gettempdir() + compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_moe_megablox_bf16.pickle") + nnx_train_compile_main( + ( + "", + get_test_config_path(), + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v6e-256", + "use_iota_embed=true", + "compile_topology_num_slices=1", + "model_name=mixtral-8x7b", + "sparse_matmul=True", + "megablox=True", + "per_device_batch_size=4", + "max_target_length=1024", + "attention=flash", + "dtype=bfloat16", + "enable_nnx=True", + "pure_nnx_decoder=True", + ) + ) + + @pytest.mark.cpu_only + def test_moe_deepseek_scanned_bf16(self): + temp_dir = gettempdir() + compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_moe_deepseek_scanned_bf16.pickle") + nnx_train_compile_main( + ( + "", + get_test_config_path(), + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5p-64", + "use_iota_embed=true", + "compile_topology_num_slices=1", + "model_name=deepseek3-test", + "sparse_matmul=True", + "megablox=False", + "per_device_batch_size=2", + "max_target_length=1024", + "attention=flash", + "dtype=bfloat16", + "weight_dtype=bfloat16", + "scan_layers=True", + "enable_nnx=True", + "pure_nnx_decoder=True", + ) + ) + + @pytest.mark.cpu_only + def test_moe_megablox_ring_ep_random(self): + temp_dir = gettempdir() + compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_moe_megablox_ring_ep_random.pickle") + nnx_train_compile_main( + ( + "", + get_test_config_path(), + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5p-16", + "use_iota_embed=true", + "compile_topology_num_slices=1", + "model_name=deepseek3-test", + "sparse_matmul=True", + "megablox=True", + "per_device_batch_size=4", + "max_target_length=128", + 
"use_ring_of_experts=True", + "use_random_routing=True", + "attention=flash", + "dtype=bfloat16", + "enable_nnx=True", + "pure_nnx_decoder=True", + ) + ) + + @pytest.mark.cpu_only + def test_pipeline_subset(self): + compiled_trainstep_file = "/tmp/nnx_test_pipeline_subset.pickle" + nnx_train_compile_main( + ( + "", + get_test_config_path(), + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5p-128", + "compile_topology_num_slices=8", + "use_iota_embed=true", + "per_device_batch_size=1", + "max_target_length=1024", + "pipeline_parallel_layers=56", + "base_num_decoder_layers=61", + "ici_expert_parallelism=16", + "dcn_pipeline_parallelism=8", + "enable_nnx=True", + "pure_nnx_decoder=True", + ) + ) diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py index cb291e13bd..10474239c6 100644 --- a/tests/unit/train_compile_test.py +++ b/tests/unit/train_compile_test.py @@ -636,6 +636,8 @@ def test_moe_deepseek_pipeline_subset(self): "pipeline_parallel_layers=56", "ici_expert_parallelism=16", "dcn_pipeline_parallelism=8", + "first_num_dense_layers=8", + "base_num_decoder_layers=72", ) ) @@ -653,7 +655,7 @@ def test_pipeline_subset(self): "per_device_batch_size=1", "max_target_length=1024", "pipeline_parallel_layers=56", - "base_num_decoder_layers=61", # Remainder of 5 will fail when sharded incorrectly. + "base_num_decoder_layers=64", # Must be divisible by dcn_pipeline_parallelism=8 in NNX scan path. "ici_expert_parallelism=16", "dcn_pipeline_parallelism=8", )