|
| 1 | + |
| 2 | +""" |
| 3 | +A unified tool to inspect checkpoint structures for: |
| 4 | +1. HuggingFace/PyTorch source files (.safetensors, .pth) |
| 5 | +2. MaxText Model Architecture (on-the-fly, no weights loaded) |
| 6 | +3. Saved Orbax Checkpoints (metadata only) |
| 7 | +
|
| 8 | +Usage Examples: |
| 9 | +[Mode 1: HF/PyTorch] |
| 10 | + python inspect_checkpoint.py hf --path <local_hf_path> --format <safetensors | pth> |
| 11 | +[Mode 2: MaxText Arch] |
| 12 | + python inspect_checkpoint.py maxtext --model_name <maxtext_model_name> --scan_layers <True | False> |
| 13 | +[Mode 3: Orbax] |
| 14 | + python inspect_checkpoint.py orbax --path <local_orbax_path | gcs_orbax_path> |
| 15 | +
|
| 16 | +
|
| 17 | +cd ~/maxtext |
| 18 | +SCRIPT=~/maxtext/src/MaxText/utils/ckpt_conversion/inspect_checkpoint.py |
| 19 | +python $SCRIPT hf --path <local_hf_path> --format safetensors |
| 20 | +python $SCRIPT maxtext --model_name deepseek3.2-671b --scan_layers False |
| 21 | +python $SCRIPT maxtext --model_name deepseek3.2-671b --scan_layers True |
| 22 | +""" |
| 23 | + |
| 24 | +import argparse |
| 25 | +import sys |
| 26 | +import os |
| 27 | +import re |
| 28 | +import pathlib |
| 29 | + |
| 30 | + |
def natural_sort_key(s: str):
    """Build a sort key that orders embedded numbers numerically.

    With this key "layer2" sorts before "layer10" (1, 2, 10 instead of 1, 10, 2).

    Args:
        s: Value to build a key for; it is converted with ``str()`` first.

    Returns:
        A list alternating text chunks (even positions) and integer chunks
        (odd positions), e.g. ``"layer10" -> ["layer", 10, ""]``.
    """
    # Splitting on a capturing group keeps the digit runs: they always land at
    # the odd indices, with the (possibly empty) text around them at the even ones.
    chunks = re.split(r"(\d+)", str(s))
    # Convert by position rather than with str.isdigit(): isdigit() is also true
    # for characters such as superscript digits, which r"\d" does not match and
    # int() cannot parse, so a key like "x1²" used to raise ValueError.
    return [int(chunk) if index % 2 else chunk for index, chunk in enumerate(chunks)]
| 34 | + |
| 35 | + |
def print_structure(data_dict):
    """Print each entry of a flat ``{key: shape}`` mapping, one entry per line.

    Entries are emitted in natural key order, as defined by ``natural_sort_key``.
    """
    ordered_keys = list(data_dict.keys())
    ordered_keys.sort(key=natural_sort_key)
    for name in ordered_keys:
        print(f"key: {name} | shape: {data_dict[name]}")
| 41 | + |
| 42 | + |
| 43 | +# ============================================================================== |
| 44 | +# Mode 1: HuggingFace / PyTorch (.safetensors or .pth) |
| 45 | +# ============================================================================== |
def inspect_hf(args):
    """Print the tensor names and shapes stored in HuggingFace/PyTorch checkpoint files.

    Args:
        args: Parsed CLI namespace with ``path`` (directory holding the
            checkpoint files) and ``format`` (``"safetensors"`` or ``"pth"``).

    Exits through ``sys.exit`` with a message when ``torch`` or ``safetensors``
    cannot be imported, or when no file with the requested extension is found.
    """
    print(f"\n--- Inspecting {args.format} files in {args.path} ---")

    # Lazy imports
    try:
        import torch
    except ImportError:
        sys.exit("Error: 'torch' is required for this mode. `pip install torch`")

    # The "[!.]" prefix excludes file names that start with a dot.
    ckpt_paths = sorted(pathlib.Path(args.path).glob(f"[!.]*.{args.format}"))
    if not ckpt_paths:
        sys.exit(f"No files with extension .{args.format} found in {args.path}")

    chkpt_vars_raw = {}

    if args.format == "safetensors":
        try:
            from safetensors import safe_open
        except ImportError:
            sys.exit("Error: 'safetensors' is required. `pip install safetensors`")

        for i, ckpt_path in enumerate(ckpt_paths, start=1):
            print(f"Loading {ckpt_path.name} ({i}/{len(ckpt_paths)})...")
            with safe_open(ckpt_path, framework="pt") as f:
                for k in f.keys():
                    # Ask the slice handle for its shape instead of calling
                    # get_tensor(), which would load the whole tensor just to read
                    # its shape. torch.Size keeps the printed form identical to
                    # the shapes reported in the .pth mode.
                    chkpt_vars_raw[k] = torch.Size(f.get_slice(k).get_shape())

    elif args.format == "pth":
        for i, ckpt_path in enumerate(ckpt_paths, start=1):
            print(f"Loading {ckpt_path.name} ({i}/{len(ckpt_paths)})...")
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code. Only inspect checkpoints from a trusted source, or
            # consider passing weights_only=True if the checkpoints allow it.
            checkpoint = torch.load(ckpt_path, map_location="cpu")
            if not isinstance(checkpoint, dict):
                # Previously such files were skipped without any notice.
                print(
                    f"Warning: {ckpt_path.name} does not hold a dict "
                    f"(got {type(checkpoint).__name__}); skipping."
                )
                continue
            # Only a flat state_dict is handled; nested dicts or wrapper keys are
            # reported as non-tensors rather than being flattened.
            for k, v in checkpoint.items():
                chkpt_vars_raw[k] = v.shape if hasattr(v, "shape") else "Non-tensor found"
            # Drop the loaded shard before the next one is read.
            del checkpoint

    print("\n=== Structure ===")
    print_structure(chkpt_vars_raw)
| 90 | + |
| 91 | + |
| 92 | +# ============================================================================== |
| 93 | +# Mode 2: MaxText Architecture (On-the-fly) |
| 94 | +# ============================================================================== |
def inspect_maxtext(args):
    """Print the parameter names and shapes of a MaxText model architecture.

    Args:
        args: Parsed CLI namespace with ``model_name`` and ``scan_layers``.
    """
    print(f"\n--- Inspecting MaxText Architecture: {args.model_name} (Scan: {args.scan_layers}) ---")

    # Lazy imports: JAX and MaxText are only pulled in when this mode runs.
    import jax
    from maxtext.utils import max_utils, maxtext_utils
    from MaxText import pyconfig
    from MaxText.globals import MAXTEXT_PKG_DIR
    from MaxText.layers import models, quantizations

    build_model = models.transformer_as_linen

    # pyconfig takes an argv-style list: slot 0 stands in for the script name,
    # slot 1 is the base config file, and the rest are key=value overrides.
    base_config_path = os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")
    overrides = [
        f"model_name={args.model_name}",
        f"scan_layers={args.scan_layers}",
        "attention=dot_product",
        "skip_jax_distributed_system=true",
    ]
    config = pyconfig.initialize(["", base_config_path, *overrides])

    # Build the model object from the config, mesh and quantization settings.
    mesh = jax.sharding.Mesh(maxtext_utils.create_device_mesh(config), config.mesh_axes)
    quant = quantizations.configure_quantization(config)
    model = build_model(config, mesh=mesh, quant=quant)

    # Abstract parameter tree: used below only for counting and for leaf shapes.
    abstract_param = maxtext_utils.get_abstract_param(model, config)
    num_params = max_utils.calculate_num_params_from_pytree(abstract_param)

    print(f"\nTotal Parameters: {num_params} (~{num_params/1e9:.2f}B)")
    print("\n=== Structure ===")

    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(abstract_param)

    # MaxText-style key: "params-" followed by the dash-joined path entries
    # (path entries without a ``key`` attribute are left out).
    flat_shapes = {
        "params-" + "-".join([entry.key for entry in path if hasattr(entry, "key")]): leaf.shape
        for path, leaf in leaves_with_paths
    }

    print_structure(flat_shapes)
| 141 | + |
| 142 | + |
| 143 | +# ============================================================================== |
| 144 | +# Mode 3: Orbax Checkpoint (Saved) |
| 145 | +# ============================================================================== |
def inspect_orbax(args):
    """Print the parameter names and shapes recorded in an Orbax checkpoint's metadata.

    Args:
        args: Parsed CLI namespace with ``path`` (checkpoint location).

    Exits through ``sys.exit`` with a message when the Orbax/etils imports fail
    or when reading the metadata raises.
    """
    print(f"\n--- Inspecting Orbax Checkpoint: {args.path} ---")

    # Lazy imports
    try:
        import orbax.checkpoint as ocp
        from etils import epath
    except ImportError:
        sys.exit("Error: 'orbax-checkpoint' or 'etils' not found. `pip install orbax-checkpoint etils[epath]`")

    ckpt_location = epath.Path(args.path)

    try:
        # The metadata object differs between Orbax versions: when it carries an
        # ``item_metadata`` attribute, that attribute is used instead.
        metadata = ocp.StandardCheckpointer().metadata(ckpt_location)
        metadata = getattr(metadata, "item_metadata", metadata)
    except Exception as e:
        sys.exit(f"Error reading Orbax metadata: {e}")

    # Flatten the metadata tree, then keep only entries whose dotted key starts
    # with "params", mapping each one to the ``shape`` of its metadata leaf.
    flat_shapes = {}
    for key_tuple, leaf_meta in ocp.tree.to_flat_dict(metadata).items():
        dotted_key = ".".join(key_tuple)
        if not dotted_key.startswith("params"):
            continue
        flat_shapes[dotted_key] = leaf_meta.shape

    print("\n=== Structure ===")
    print_structure(flat_shapes)
| 180 | + |
| 181 | + |
| 182 | +# ============================================================================== |
| 183 | +# Main CLI Driver |
| 184 | +# ============================================================================== |
def main():
    """Parse the command line and run the inspector that matches the chosen mode."""
    parser = argparse.ArgumentParser(description="Consolidated Model Checkpoint Inspector")
    subparsers = parser.add_subparsers(dest="mode", required=True, help="Inspection mode")

    # "hf" sub-command: HuggingFace / PyTorch files
    hf_cmd = subparsers.add_parser("hf", help="Inspect .safetensors or .pth files")
    hf_cmd.add_argument(
        "--path",
        type=str,
        required=True,
        help="Directory containing checkpoint files",
    )
    hf_cmd.add_argument(
        "--format",
        type=str,
        required=False,
        choices=["safetensors", "pth"],
        default="safetensors",
        help="File format",
    )

    # "maxtext" sub-command: MaxText architecture
    maxtext_cmd = subparsers.add_parser("maxtext", help="Inspect MaxText theoretical architecture")
    maxtext_cmd.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="e.g. deepseek3-671b",
    )
    maxtext_cmd.add_argument(
        "--scan_layers",
        type=str,
        required=False,
        default="true",
        choices=["true", "false", "True", "False"],
        help="Simulate scanned or unscanned structure",
    )

    # "orbax" sub-command: saved Orbax checkpoint
    orbax_cmd = subparsers.add_parser("orbax", help="Inspect saved Orbax checkpoint metadata")
    orbax_cmd.add_argument(
        "--path",
        type=str,
        required=True,
        help="Path to checkpoint items (local or GCS)",
    )

    args = parser.parse_args()

    # Route the parsed arguments to the function for the selected mode.
    handlers = {
        "hf": inspect_hf,
        "maxtext": inspect_maxtext,
        "orbax": inspect_orbax,
    }
    handler = handlers.get(args.mode)
    if handler is not None:
        handler(args)
| 220 | + |
| 221 | + |
| 222 | +if __name__ == "__main__": |
| 223 | + main() |
0 commit comments