Skip to content

Commit e37a93d

Browse files
committed
Tool to inspect checkpoint structure
1 parent 98fb5cf commit e37a93d

1 file changed

Lines changed: 223 additions & 0 deletions

File tree

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
2+
"""
3+
A unified tool to inspect checkpoint structures for:
4+
1. HuggingFace/PyTorch source files (.safetensors, .pth)
5+
2. MaxText Model Architecture (on-the-fly, no weights loaded)
6+
3. Saved Orbax Checkpoints (metadata only)
7+
8+
Usage Examples:
9+
[Mode 1: HF/PyTorch]
10+
python inspect_checkpoint.py hf --path <local_hf_path> --format <safetensors | pth>
11+
[Mode 2: MaxText Arch]
12+
python inspect_checkpoint.py maxtext --model_name <maxtext_model_name> --scan_layers <True | False>
13+
[Mode 3: Orbax]
14+
python inspect_checkpoint.py orbax --path <local_orbax_path | gcs_orbax_path>
15+
16+
17+
cd ~/maxtext
18+
SCRIPT=~/maxtext/src/MaxText/utils/ckpt_conversion/inspect_checkpoint.py
19+
python $SCRIPT hf --path <local_hf_path> --format safetensors
20+
python $SCRIPT maxtext --model_name deepseek3.2-671b --scan_layers False
21+
python $SCRIPT maxtext --model_name deepseek3.2-671b --scan_layers True
22+
"""
23+
24+
import argparse
25+
import sys
26+
import os
27+
import re
28+
import pathlib
29+
30+
31+
def natural_sort_key(s: str):
    """Build a sort key that orders embedded numbers numerically.

    The input is passed through ``str()`` and split into alternating
    non-digit and digit runs. Digit runs are converted to ``int`` so that,
    for example, ``"layer2"`` sorts before ``"layer10"``.

    Args:
        s: Value to derive the key from (anything accepted by ``str()``).

    Returns:
        A list of ``str`` and ``int`` chunks usable as a ``key=`` for sorting.
    """
    # `isdecimal()` accepts exactly the characters `\d` splits on (Unicode
    # category Nd), all of which `int()` can parse. `isdigit()` is also true
    # for characters such as "²", which `int()` rejects with ValueError.
    return [int(chunk) if chunk.isdecimal() else chunk for chunk in re.split(r"(\d+)", str(s))]
34+
35+
36+
def print_structure(data_dict):
    """Print each entry of a flat ``{key: shape}`` mapping, one per line.

    Entries are emitted in the order produced by sorting the keys with
    ``natural_sort_key``. Each line has the form ``key: <key> | shape: <value>``.

    Args:
        data_dict: Flat mapping from parameter name to the value to display
            as its shape.
    """
    ordered_names = list(data_dict.keys())
    ordered_names.sort(key=natural_sort_key)
    for name in ordered_names:
        print(f"key: {name} | shape: {data_dict[name]}")
41+
42+
43+
# ==============================================================================
44+
# Mode 1: HuggingFace / PyTorch (.safetensors or .pth)
45+
# ==============================================================================
46+
def inspect_hf(args):
    """Print the name and shape of every tensor found in HF/PyTorch checkpoint files.

    Scans ``args.path`` for non-hidden files with the extension ``args.format``,
    collects ``{tensor_name: shape}`` across all of them, and prints the result
    via ``print_structure``.

    Args:
        args: Parsed CLI namespace. Reads ``args.path`` (directory to scan) and
            ``args.format`` (``"safetensors"`` or ``"pth"``; also used as the
            file extension to glob for).

    Raises:
        SystemExit: If ``torch`` or ``safetensors`` cannot be imported, or if
            no matching file is found in ``args.path``.
    """
    print(f"\n--- Inspecting {args.format} files in {args.path} ---")

    # Lazy imports
    try:
        import torch
    except ImportError:
        sys.exit("Error: 'torch' is required for this mode. `pip install torch`")

    # The "[!.]" prefix skips hidden files (names starting with a dot).
    ckpt_paths = sorted(pathlib.Path(args.path).glob(f"[!.]*.{args.format}"))
    if not ckpt_paths:
        sys.exit(f"No files with extension .{args.format} found in {args.path}")

    chkpt_vars_raw = {}

    if args.format == "safetensors":
        try:
            from safetensors import safe_open
        except ImportError:
            sys.exit("Error: 'safetensors' is required. `pip install safetensors`")

        for i, ckpt_path in enumerate(ckpt_paths):
            print(f"Loading {ckpt_path.name} ({i+1}/{len(ckpt_paths)})...")
            with safe_open(ckpt_path, framework="pt") as f:
                for k in f.keys():
                    # Ask the lazy slice for its shape instead of calling
                    # get_tensor(), which would materialize the whole tensor
                    # just to read `.shape`. Wrapped in torch.Size so the
                    # printed form is the same as `tensor.shape`.
                    chkpt_vars_raw[k] = torch.Size(f.get_slice(k).get_shape())

    elif args.format == "pth":
        for i, ckpt_path in enumerate(ckpt_paths):
            print(f"Loading {ckpt_path.name} ({i+1}/{len(ckpt_paths)})...")
            # NOTE(review): torch.load unpickles the file; only point this tool
            # at checkpoints from a trusted source.
            checkpoint = torch.load(ckpt_path, map_location="cpu")
            # Flatten logic might be needed depending on pth structure,
            # here we assume a standard (flat) state_dict.
            if isinstance(checkpoint, dict):
                for k, v in checkpoint.items():
                    if hasattr(v, "shape"):
                        chkpt_vars_raw[k] = v.shape
                    else:
                        # Nested state dicts / wrapper keys are reported, not expanded.
                        chkpt_vars_raw[k] = "Non-tensor found"
            else:
                # Report instead of silently dropping the file from the listing.
                print(f"Skipping {ckpt_path.name}: expected a dict, got {type(checkpoint).__name__}")

    print("\n=== Structure ===")
    print_structure(chkpt_vars_raw)
90+
91+
92+
# ==============================================================================
93+
# Mode 2: MaxText Architecture (On-the-fly)
94+
# ==============================================================================
95+
def inspect_maxtext(args):
    """Print the parameter names and shapes of a MaxText model built from config.

    Initializes a MaxText config from ``base.yml`` plus overrides derived from
    ``args``, constructs the model, obtains its abstract parameter tree, then
    prints the total parameter count followed by every ``params-...`` key and
    its shape.

    Args:
        args: Parsed CLI namespace. Reads ``args.model_name`` and
            ``args.scan_layers`` (forwarded verbatim as config overrides).
    """
    print(f"\n--- Inspecting MaxText Architecture: {args.model_name} (Scan: {args.scan_layers}) ---")

    # Imported lazily, only when this mode is selected.
    import jax
    from maxtext.utils import max_utils, maxtext_utils
    from MaxText import pyconfig
    from MaxText.globals import MAXTEXT_PKG_DIR
    from MaxText.layers import models, quantizations

    build_model = models.transformer_as_linen

    # pyconfig takes an argv-style list: slot 0 stands in for the script name,
    # slot 1 is the base config file, the rest are key=value overrides.
    base_config_path = os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")
    overrides = [
        f"model_name={args.model_name}",
        f"scan_layers={args.scan_layers}",
        "attention=dot_product",
        "skip_jax_distributed_system=true",
    ]
    config = pyconfig.initialize(["", base_config_path, *overrides])

    mesh = jax.sharding.Mesh(maxtext_utils.create_device_mesh(config), config.mesh_axes)
    quant = quantizations.configure_quantization(config)
    model = build_model(config, mesh=mesh, quant=quant)

    # Abstract parameter tree; only `.shape` of each leaf is read below.
    abstract_param = maxtext_utils.get_abstract_param(model, config)
    num_params = max_utils.calculate_num_params_from_pytree(abstract_param)

    print(f"\nTotal Parameters: {num_params} (~{num_params/1e9:.2f}B)")
    print("\n=== Structure ===")

    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(abstract_param)

    # MaxText-style key: "params-" followed by the dash-joined `.key` of each
    # path entry that has one (entries without `.key` are skipped).
    flat_shapes = {
        "params-" + "-".join([entry.key for entry in path if hasattr(entry, "key")]): leaf.shape
        for path, leaf in leaves_with_paths
    }

    print_structure(flat_shapes)
141+
142+
143+
# ==============================================================================
144+
# Mode 3: Orbax Checkpoint (Saved)
145+
# ==============================================================================
146+
def inspect_orbax(args):
    """Print the parameter names and shapes recorded in an Orbax checkpoint's metadata.

    Reads the checkpoint metadata via ``StandardCheckpointer().metadata``,
    flattens it, keeps entries whose dotted key starts with ``"params"``, and
    prints them via ``print_structure``.

    Args:
        args: Parsed CLI namespace. Reads ``args.path`` (checkpoint location).

    Raises:
        SystemExit: If ``orbax.checkpoint``/``etils`` cannot be imported, or if
            reading the metadata raises.
    """
    print(f"\n--- Inspecting Orbax Checkpoint: {args.path} ---")

    # Lazy imports
    try:
        import orbax.checkpoint as ocp
        from etils import epath
    except ImportError:
        sys.exit("Error: 'orbax-checkpoint' or 'etils' not found. `pip install orbax-checkpoint etils[epath]`")

    path = epath.Path(args.path)

    try:
        # Depending on Orbax version, metadata access might vary slightly.
        # This aligns with StandardCheckpointer usage.
        metadata = ocp.StandardCheckpointer().metadata(path)
        if hasattr(metadata, "item_metadata"):
            metadata = metadata.item_metadata
    except Exception as e:
        # Top-level CLI boundary: turn any failure into a readable exit message.
        sys.exit(f"Error reading Orbax metadata: {e}")

    # Convert to flat dict
    dictionary = ocp.tree.to_flat_dict(metadata)

    # Filter for params only and clean up keys
    flat_shapes = {}
    for k, v in dictionary.items():
        # k is a tuple of path components. Stringify each one before joining:
        # a non-string component (e.g. an integer index) would otherwise make
        # str.join raise TypeError, and since the join happens before the
        # "params" filter, a single such path would abort the whole listing.
        key_str = ".".join(str(part) for part in k)
        if key_str.startswith("params"):
            # v is a metadata object with .shape
            flat_shapes[key_str] = v.shape

    print("\n=== Structure ===")
    print_structure(flat_shapes)
180+
181+
182+
# ==============================================================================
183+
# Main CLI Driver
184+
# ==============================================================================
185+
def main():
    """Parse the command line and run the selected inspection mode.

    Defines three sub-commands (``hf``, ``maxtext``, ``orbax``), each with its
    own options, then hands the parsed namespace to the matching
    ``inspect_*`` function.
    """
    cli = argparse.ArgumentParser(description="Consolidated Model Checkpoint Inspector")
    modes = cli.add_subparsers(dest="mode", required=True, help="Inspection mode")

    # hf: HuggingFace / PyTorch files on disk
    hf_cmd = modes.add_parser("hf", help="Inspect .safetensors or .pth files")
    hf_cmd.add_argument("--path", type=str, required=True, help="Directory containing checkpoint files")
    hf_cmd.add_argument(
        "--format",
        type=str,
        required=False,
        choices=["safetensors", "pth"],
        default="safetensors",
        help="File format",
    )

    # maxtext: architecture derived from the model config
    mt_cmd = modes.add_parser("maxtext", help="Inspect MaxText theoretical architecture")
    mt_cmd.add_argument("--model_name", type=str, required=True, help="e.g. deepseek3-671b")
    mt_cmd.add_argument(
        "--scan_layers",
        type=str,
        required=False,
        default="true",
        choices=["true", "false", "True", "False"],
        help="Simulate scanned or unscanned structure",
    )

    # orbax: metadata of a saved checkpoint
    orbax_cmd = modes.add_parser("orbax", help="Inspect saved Orbax checkpoint metadata")
    orbax_cmd.add_argument("--path", type=str, required=True, help="Path to checkpoint items (local or GCS)")

    args = cli.parse_args()

    # Sub-command name -> handler; an unknown mode does nothing.
    handlers = {
        "hf": inspect_hf,
        "maxtext": inspect_maxtext,
        "orbax": inspect_orbax,
    }
    handler = handlers.get(args.mode)
    if handler is not None:
        handler(args)
220+
221+
222+
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)