Skip to content

Commit 490c8a1

Browse files
committed
update
1 parent 6bf98a2 commit 490c8a1

1 file changed

Lines changed: 50 additions & 57 deletions

File tree

src/maxtext/checkpoint_conversion/inspect_checkpoint.py

Lines changed: 50 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
[Mode 1: HF/PyTorch]
2424
python src/maxtext/checkpoint_conversion/inspect_checkpoint.py hf --path <local_hf_path> --format <safetensors | pth>
2525
[Mode 2: MaxText Arch]
26-
python src/maxtext/checkpoint_conversion/inspect_checkpoint.py maxtext --model_name <maxtext_model_name> --scan_layers <True | False>
26+
python src/maxtext/checkpoint_conversion/inspect_checkpoint.py maxtext model_name=<maxtext_model_name> scan_layers=<True | False>
2727
[Mode 3: Orbax]
2828
python src/maxtext/checkpoint_conversion/inspect_checkpoint.py orbax --path <local_orbax_path | gcs_orbax_path>
2929
"""
@@ -43,8 +43,7 @@ def natural_sort_key(s: str):
4343
def print_structure(data_dict):
4444
"""Utility to print sorted keys and shapes from a flat dictionary."""
4545
for key in sorted(data_dict.keys(), key=natural_sort_key):
46-
shape = data_dict[key]
47-
print(f"key: {key} | shape: {shape}")
46+
print(f"key: {key} | {data_dict[key]}")
4847

4948

5049
# ==============================================================================
@@ -53,17 +52,11 @@ def print_structure(data_dict):
5352
def inspect_hf(args):
5453
print(f"\n--- Inspecting {args.format} files in {args.path} ---")
5554

56-
# Lazy imports
57-
try:
58-
import torch
59-
except ImportError:
60-
sys.exit("Error: 'torch' is required for this mode. `pip install torch`")
61-
6255
ckpt_paths = sorted(pathlib.Path(args.path).glob(f"[!.]*.{args.format}"))
6356
if not ckpt_paths:
6457
sys.exit(f"No files with extension .{args.format} found in {args.path}")
6558

66-
chkpt_vars_raw = {}
59+
param_dict = {}
6760

6861
if args.format == "safetensors":
6962
try:
@@ -76,31 +69,34 @@ def inspect_hf(args):
7669
with safe_open(ckpt_path, framework="pt") as f:
7770
for k in f.keys():
7871
# Storing shape directly to save memory, rather than the full tensor
79-
chkpt_vars_raw[k] = f.get_tensor(k).shape
72+
shape = f.get_tensor(k).shape
73+
param_dict[k] = f"shape: {shape}"
8074

8175
elif args.format == "pth":
76+
try:
77+
import torch
78+
except ImportError:
79+
sys.exit("Error: 'torch' is required for this mode. `pip install torch`")
80+
8281
for i, ckpt_path in enumerate(ckpt_paths):
8382
print(f"Loading {ckpt_path.name} ({i+1}/{len(ckpt_paths)})...")
8483
checkpoint = torch.load(ckpt_path, map_location="cpu")
8584
# Flatten logic might be needed depending on pth structure,
8685
# here we assume standard state_dict or handle the wrapper keys manually if needed.
8786
if isinstance(checkpoint, dict):
8887
for k, v in checkpoint.items():
89-
if hasattr(v, "shape"):
90-
chkpt_vars_raw[k] = v.shape
91-
else:
92-
# Handle nested state dicts or wrapper keys if common in your workflow
93-
chkpt_vars_raw[k] = "Non-tensor found"
88+
# Handle nested state dicts or wrapper keys if common in your workflow
89+
shape = v.shape if hasattr(v, "shape") else "Non-tensor found"
90+
param_dict[k] = f"shape: {shape}"
9491

9592
print("\n=== Structure ===")
96-
print_structure(chkpt_vars_raw)
93+
print_structure(param_dict)
9794

9895

9996
# ==============================================================================
10097
# Mode 2: MaxText Architecture (On-the-fly)
10198
# ==============================================================================
102-
def inspect_maxtext(args):
103-
print(f"\n--- Inspecting MaxText Architecture: {args.model_name} (Scan: {args.scan_layers}) ---")
99+
def inspect_maxtext(args, remaining_args):
104100

105101
# Lazy imports
106102
import jax
@@ -113,17 +109,17 @@ def inspect_maxtext(args):
113109
Transformer = models.transformer_as_linen
114110

115111
# Setup config
116-
argv = [
117-
"", # First arg is usually script name in pyconfig
118-
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
119-
f"model_name={args.model_name}",
120-
f"scan_layers={args.scan_layers}",
121-
"attention=dot_product",
122-
"skip_jax_distributed_system=true",
123-
]
112+
argv = (
113+
# First arg is usually script name in pyconfig
114+
[None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")]
115+
+ remaining_args
116+
+ ["attention=dot_product", "skip_jax_distributed_system=true"]
117+
)
118+
print(argv)
124119

125120
# Initialize without heavyweight runtime
126121
config = pyconfig.initialize(argv)
122+
print(f"\n--- Inspecting MaxText Architecture: {config.model_name} (Scan: {config.scan_layers}) ---")
127123
devices_array = maxtext_utils.create_device_mesh(config)
128124
mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)
129125
quant = quantizations.configure_quantization(config)
@@ -133,19 +129,23 @@ def inspect_maxtext(args):
133129
abstract_param = maxtext_utils.get_abstract_param(model, config)
134130
num_params = max_utils.calculate_num_params_from_pytree(abstract_param)
135131

136-
print(f"\nTotal Parameters: {num_params} (~{num_params/1e9:.2f}B)")
132+
print(f"\nTotal Parameters: {num_params} (~{num_params/1e9:.2f} B)")
137133
print("\n=== Structure ===")
138134

139135
abstract_params_flat, _ = jax.tree_util.tree_flatten_with_path(abstract_param)
140136

141-
flat_shapes = {}
137+
param_dict = {}
138+
# abstract_leaf_value: ShapeDtypeStruct(shape=(128, 58), dtype=float32)
142139
for path_tuple, abstract_leaf_value in abstract_params_flat:
143140
key_parts = [k.key for k in path_tuple if hasattr(k, "key")]
144141
# Construct MaxText style parameter key
145-
mt_param_key = "params-" + "-".join(key_parts)
146-
flat_shapes[mt_param_key] = abstract_leaf_value.shape
142+
param_key = "params-" + "-".join(key_parts)
143+
shape = abstract_leaf_value.shape
144+
param_dict[param_key] = f"shape: {shape}"
145+
dtype = abstract_leaf_value.dtype
146+
param_dict[param_key] += f" | dtype: {dtype}"
147147

148-
print_structure(flat_shapes)
148+
print_structure(param_dict)
149149

150150

151151
# ==============================================================================
@@ -163,36 +163,38 @@ def inspect_orbax(args):
163163

164164
path = epath.Path(args.path)
165165

166-
try:
167-
# Depending on Orbax version, metadata access might vary slightly.
168-
# This aligns with StandardCheckpointer usage.
169-
metadata = ocp.StandardCheckpointer().metadata(path)
170-
if hasattr(metadata, "item_metadata"):
171-
metadata = metadata.item_metadata
172-
except Exception as e:
173-
sys.exit(f"Error reading Orbax metadata: {e}")
166+
# Depending on Orbax version, metadata access might vary slightly.
167+
# This aligns with StandardCheckpointer usage.
168+
metadata = ocp.StandardCheckpointer().metadata(path)
169+
if hasattr(metadata, "item_metadata"):
170+
metadata = metadata.item_metadata
174171

175172
# Convert to flat dict
176173
dictionary = ocp.tree.to_flat_dict(metadata)
177174

178175
# Filter for params only and clean up keys
179-
flat_shapes = {}
176+
param_dict = {}
180177
for k, v in dictionary.items():
181178
# k is a tuple, join it. v is metadata object with .shape
182-
key_str = ".".join(k)
183-
if key_str.startswith("params"):
184-
flat_shapes[key_str] = v.shape
179+
param_key = ".".join(k)
180+
if not param_key.startswith("params"):
181+
continue
182+
shape = v.shape
183+
param_dict[param_key] = f"shape: {shape}"
184+
dtype = v.dtype
185+
param_dict[param_key] += f" | dtype: {dtype}"
186+
print(v)
185187

186188
print("\n=== Structure ===")
187-
print_structure(flat_shapes)
189+
print_structure(param_dict)
188190

189191

190192
# ==============================================================================
191193
# Main CLI Driver
192194
# ==============================================================================
193195
def main():
194196
parser = argparse.ArgumentParser(description="Consolidated Model Checkpoint Inspector")
195-
subparsers = parser.add_subparsers(dest="mode", required=True, help="Inspection mode")
197+
subparsers = parser.add_subparsers(dest="mode", required=True, help="Inspection mode: hf, maxtext, orbax")
196198

197199
# Mode 1: HuggingFace / PyTorch
198200
parser_hf = subparsers.add_parser("hf", help="Inspect .safetensors or .pth files")
@@ -203,26 +205,17 @@ def main():
203205

204206
# Mode 2: MaxText Architecture
205207
parser_mt = subparsers.add_parser("maxtext", help="Inspect MaxText theoretical architecture")
206-
parser_mt.add_argument("--model_name", type=str, required=True, help="e.g. deepseek3-671b")
207-
parser_mt.add_argument(
208-
"--scan_layers",
209-
type=str,
210-
required=False,
211-
default="true",
212-
choices=["true", "false", "True", "False"],
213-
help="Simulate scanned or unscanned structure",
214-
)
215208

216209
# Mode 3: Orbax
217210
parser_orbax = subparsers.add_parser("orbax", help="Inspect saved Orbax checkpoint metadata")
218211
parser_orbax.add_argument("--path", type=str, required=True, help="Path to checkpoint items (local or GCS)")
219212

220-
args = parser.parse_args()
213+
args, remaining_args = parser.parse_known_args()
221214

222215
if args.mode == "hf":
223216
inspect_hf(args)
224217
elif args.mode == "maxtext":
225-
inspect_maxtext(args)
218+
inspect_maxtext(args, remaining_args)
226219
elif args.mode == "orbax":
227220
inspect_orbax(args)
228221

0 commit comments

Comments
 (0)