|
| 1 | +import jax |
| 2 | +import jax.numpy as jnp |
| 3 | +import orbax.checkpoint as ocp |
| 4 | +import numpy as np |
| 5 | +from typing import Any, Dict, Sequence, Tuple |
| 6 | +from jax.tree_util import tree_flatten_with_path, keystr, tree_structure, tree_map_with_path |
| 7 | +from absl import app |
| 8 | +from absl import flags |
| 9 | + |
| 10 | + |
# Required command-line flags, one per checkpoint being compared. Both default
# to None and are marked required, so a value must be supplied for each.
_LINEN_CKPT_PATH = flags.DEFINE_string(
  name="linen_ckpt_path",
  default=None,
  help="Path to the Linen model checkpoint items directory.",
  required=True,
)
_NNX_CKPT_PATH = flags.DEFINE_string(
  name="nnx_ckpt_path",
  default=None,
  help="Path to the NNX model checkpoint items directory.",
  required=True,
)
| 17 | + |
| 18 | + |
def load_checkpoint_params(path: str) -> Dict[str, Any]:
  """Restore an Orbax checkpoint and return its parameter tree.

  Args:
    path: Location of the checkpoint to restore.

  Returns:
    The value stored under the "params" key when the restored state is a dict
    that contains it; otherwise the restored state itself.

  Raises:
    ValueError: If the checkpointer returns None for `path`.
  """
  print(f"Loading checkpoint from: {path}")
  state = ocp.PyTreeCheckpointer().restore(path)
  if state is None:
    raise ValueError(f"Failed to restore checkpoint from {path}")

  # Some checkpoints wrap the parameters in a top-level "params" entry.
  has_params_entry = isinstance(state, dict) and "params" in state
  return state["params"] if has_params_entry else state
| 29 | + |
| 30 | + |
def transform_nnx_params(nnx_params: Dict[str, Any]) -> Dict[str, Any]:
  """Swap the first two axes of multi-dimensional leaves found under "layers".

  A leaf is transposed only when its key path contains "layers" and it has at
  least two dimensions; any further axes keep their order. Every other leaf is
  returned untouched. One log line is printed per leaf.

  Args:
    nnx_params: Parameter tree to transform.

  Returns:
    A tree with the same structure as `nnx_params`.
  """

  def _swap_leading_axes(path, leaf: jax.Array) -> jax.Array:
    name = keystr(path)
    rank = getattr(leaf, "ndim", None)

    if "layers" in name and rank is not None and rank >= 2:
      print(f"TRANSPOSING: {name} with shape {leaf.shape}")
      # Permutation (1, 0, 2, 3, ...): only the two leading axes trade places.
      permutation = (1, 0, *range(2, rank))
      return jnp.transpose(leaf, axes=permutation)

    if "token_embedder" in name:
      print(f"SKIPPING Transpose: {name} because it is token_embedder")
      return leaf

    # NOTE: this branch is also reached by leaves with ndim >= 2 whose path
    # does not contain "layers"; the message text is kept as-is.
    leaf_shape = getattr(leaf, "shape", "N/A")
    print(f"SKIPPING Transpose: {name} with shape {leaf_shape} (ndim < 2)")
    return leaf

  print("Applying transformations to NNX params...")
  return tree_map_with_path(_swap_leading_axes, nnx_params)
| 51 | + |
| 52 | + |
def get_tree_structure_info(tree: Dict[str, Any]):
  """Summarise every leaf of `tree` as a (shape, dtype name) pair.

  Args:
    tree: Pytree to describe.

  Returns:
    A dict keyed by the leaf's key-path string. The shape is "N/A" for leaves
    without a `shape` attribute, and the dtype name falls back to the leaf's
    Python type name when it has no `dtype` attribute.
  """
  leaves_with_paths = tree_flatten_with_path(tree)[0]

  summary = {}
  for path, leaf in leaves_with_paths:
    shape = getattr(leaf, "shape", "N/A")
    dtype_name = str(getattr(leaf, "dtype", type(leaf).__name__))
    summary[keystr(path)] = (shape, dtype_name)
  return summary
| 57 | + |
| 58 | + |
def print_structure_diff(params1, params2):
  """Print the leaf key paths that appear in only one of the two trees.

  Paths present only in `params2` are listed first as added, followed by paths
  present only in `params1` as missing; each group is sorted.

  Args:
    params1: Reference tree (reported from the Linen side).
    params2: Tree being compared against it (reported as NNX).
  """
  reference_paths = set(get_tree_structure_info(params1))
  candidate_paths = set(get_tree_structure_info(params2))

  only_in_candidate = sorted(candidate_paths - reference_paths)
  only_in_reference = sorted(reference_paths - candidate_paths)

  for path in only_in_candidate:
    print(f" + Added in NNX: {path}")
  for path in only_in_reference:
    print(f" - Missing in NNX: {path}")
| 69 | + |
| 70 | + |
def _compare_leaf_pair(key_str, x, y) -> bool:
  """Print a comparison report for one pair of leaves; return True on a match."""
  try:
    shape1 = getattr(x, "shape", "N/A")
    shape2 = getattr(y, "shape", "N/A")
    if shape1 != shape2:
      print(f"[{key_str}] SHAPE MISMATCH: {shape1} vs {shape2}")
      return False

    dtype1 = getattr(x, "dtype", type(x))
    dtype2 = getattr(y, "dtype", type(y))
    if dtype1 != dtype2:
      print(f"[{key_str}] DTYPE MISMATCH: {dtype1} vs {dtype2}")
      return False

    if getattr(x, "size", None) == 0:
      # Zero-size leaves with equal shape and dtype hold no values to compare.
      # jnp.max raises on an empty array, which would otherwise be reported as
      # a mismatch, so treat them as trivially equal.
      mean_diff, max_diff, is_close = 0.0, 0.0, True
    else:
      abs_diff = jnp.abs(x - y)
      mean_diff = float(jnp.mean(abs_diff))
      max_diff = float(jnp.max(abs_diff))
      is_close = bool(jnp.allclose(x, y))

    print(
      f"[{key_str}] "
      f"Shape(Linen/NNX): {shape1} / {shape2} — "
      f"Mean abs diff: {mean_diff:.2e}, "
      f"Max abs diff: {max_diff:.2e}, "
      f"AllClose: {is_close}"
    )
    return is_close

  except Exception as e:
    # Deliberate best-effort: a leaf that cannot be compared is reported and
    # counted as a mismatch so the remaining leaves are still checked.
    print(f"[{key_str}] Error during comparison: {e}")
    return False


def compare_params(params1: Dict[str, Any], params2: Dict[str, Any]) -> bool:
  """Compare two parameter trees leaf by leaf, printing a report for each leaf.

  Args:
    params1: Reference tree (labelled Linen in the output).
    params2: Tree to compare against it (labelled NNX in the output).

  Returns:
    True only when both trees have the same structure and every pair of leaves
    agrees in shape, dtype and values (per `jnp.allclose` with its default
    tolerances). A structure mismatch prints the differing key paths and
    returns False without comparing any values.
  """
  if tree_structure(params1) != tree_structure(params2):
    print("[] Tree structures differ.")
    print_structure_diff(params1, params2)
    return False

  print("[] Tree structures are the same.")

  # Identical structures flatten to leaves in the same order, so the two leaf
  # lists can be walked in lockstep.
  leaves1, _ = tree_flatten_with_path(params1)
  leaves2, _ = tree_flatten_with_path(params2)

  # Build the full list first (no short-circuit) so every leaf gets reported.
  results = [
    _compare_leaf_pair(keystr(path), x, y)
    for (path, x), (_, y) in zip(leaves1, leaves2)
  ]
  return all(results)
| 130 | + |
| 131 | + |
def main(argv: Sequence[str]):
  """Load both checkpoints, transform the NNX params and report the comparison.

  Args:
    argv: Command-line arguments left over after flag parsing; only the
      program name is accepted.

  Raises:
    app.UsageError: If extra positional arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  linen_path, nnx_path = _LINEN_CKPT_PATH.value, _NNX_CKPT_PATH.value
  print(f"Linen Checkpoint Path: {linen_path}")
  print(f"NNX Checkpoint Path: {nnx_path}")

  print("Loading Linen params...")
  linen_params = load_checkpoint_params(linen_path)
  print("Loading NNX params...")
  nnx_params = load_checkpoint_params(nnx_path)

  if linen_params is None or nnx_params is None:
    print("Failed to load params from one or both checkpoints.")
    return

  # The NNX tree is transformed before comparison against the Linen tree.
  transformed = transform_nnx_params(nnx_params)

  print("\nComparing Linen params with Transformed NNX params...")
  if not compare_params(linen_params, transformed):
    print("\nCheckpoints DIFFER after transformation.")
    return
  print("\nCheckpoints are considered the same (within np.allclose tolerance) after transformation!")
| 158 | + |
# When executed as a script, hand control to absl's app.run with main.
if __name__ == "__main__":
  app.run(main)
0 commit comments