 [Mode 1: HF/PyTorch]
   python src/maxtext/checkpoint_conversion/inspect_checkpoint.py hf --path <local_hf_path> --format <safetensors | pth>
 [Mode 2: MaxText Arch]
-  python src/maxtext/checkpoint_conversion/inspect_checkpoint.py maxtext --model_name <maxtext_model_name> --scan_layers <True | False>
+  python src/maxtext/checkpoint_conversion/inspect_checkpoint.py maxtext model_name=<maxtext_model_name> scan_layers=<True | False>
 [Mode 3: Orbax]
   python src/maxtext/checkpoint_conversion/inspect_checkpoint.py orbax --path <local_orbax_path | gcs_orbax_path>
 """
@@ -43,8 +43,7 @@ def natural_sort_key(s: str):
 def print_structure(data_dict):
   """Utility to print sorted keys and shapes from a flat dictionary."""
   for key in sorted(data_dict.keys(), key=natural_sort_key):
-    shape = data_dict[key]
-    print(f"key: {key} | shape: {shape}")
+    print(f"key: {key} | {data_dict[key]}")


 # ==============================================================================
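
Note: `natural_sort_key` is called above but sits outside this diff. A minimal sketch of a typical implementation consistent with the call site (the repo's actual helper may differ) shows why it matters for layer keys:

```python
import re

def natural_sort_key(s: str):
  # Split into digit/non-digit runs so "layers_2" sorts before "layers_10".
  return [int(p) if p.isdigit() else p for p in re.split(r"(\d+)", s)]

keys = ["params-decoder-layers_10-kernel", "params-decoder-layers_2-kernel"]
print(sorted(keys, key=natural_sort_key))
# ['params-decoder-layers_2-kernel', 'params-decoder-layers_10-kernel']
```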
@@ -53,17 +52,11 @@ def print_structure(data_dict):
 def inspect_hf(args):
   print(f"\n--- Inspecting {args.format} files in {args.path} ---")

-  # Lazy imports
-  try:
-    import torch
-  except ImportError:
-    sys.exit("Error: 'torch' is required for this mode. `pip install torch`")
-
   ckpt_paths = sorted(pathlib.Path(args.path).glob(f"[!.]*.{args.format}"))
   if not ckpt_paths:
     sys.exit(f"No files with extension .{args.format} found in {args.path}")

-  chkpt_vars_raw = {}
+  param_dict = {}

   if args.format == "safetensors":
     try:
@@ -76,31 +69,34 @@ def inspect_hf(args):
       with safe_open(ckpt_path, framework="pt") as f:
         for k in f.keys():
           # Storing shape directly to save memory, rather than the full tensor
-          chkpt_vars_raw[k] = f.get_tensor(k).shape
+          shape = f.get_tensor(k).shape
+          param_dict[k] = f"shape: {shape}"

   elif args.format == "pth":
+    try:
+      import torch
+    except ImportError:
+      sys.exit("Error: 'torch' is required for this mode. `pip install torch`")
+
     for i, ckpt_path in enumerate(ckpt_paths):
       print(f"Loading {ckpt_path.name} ({i + 1}/{len(ckpt_paths)})...")
       checkpoint = torch.load(ckpt_path, map_location="cpu")
       # Flatten logic might be needed depending on pth structure,
       # here we assume standard state_dict or handle the wrapper keys manually if needed.
       if isinstance(checkpoint, dict):
         for k, v in checkpoint.items():
-          if hasattr(v, "shape"):
-            chkpt_vars_raw[k] = v.shape
-          else:
-            # Handle nested state dicts or wrapper keys if common in your workflow
-            chkpt_vars_raw[k] = "Non-tensor found"
+          # Handle nested state dicts or wrapper keys if common in your workflow
+          shape = v.shape if hasattr(v, "shape") else "Non-tensor found"
+          param_dict[k] = f"shape: {shape}"

   print("\n=== Structure ===")
-  print_structure(chkpt_vars_raw)
+  print_structure(param_dict)


 # ==============================================================================
 # Mode 2: MaxText Architecture (On-the-fly)
 # ==============================================================================
-def inspect_maxtext(args):
-  print(f"\n--- Inspecting MaxText Architecture: {args.model_name} (Scan: {args.scan_layers}) ---")
+def inspect_maxtext(args, remaining_args):

   # Lazy imports
   import jax
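
One observation on the safetensors branch above: despite the "save memory" comment, `f.get_tensor(k)` still materializes the full tensor before `.shape` is read. safetensors also exposes `get_slice`, which pulls shape and dtype from the file header without loading tensor data; a sketch (file name is illustrative):

```python
from safetensors import safe_open

param_dict = {}
with safe_open("model-00001-of-00002.safetensors", framework="pt") as f:
  for k in f.keys():
    sl = f.get_slice(k)  # lazy handle: header metadata only, no tensor read
    param_dict[k] = f"shape: {tuple(sl.get_shape())} | dtype: {sl.get_dtype()}"
```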
@@ -113,17 +109,17 @@ def inspect_maxtext(args):
   Transformer = models.transformer_as_linen

   # Setup config
-  argv = [
-      "",  # First arg is usually script name in pyconfig
-      os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
-      f"model_name={args.model_name}",
-      f"scan_layers={args.scan_layers}",
-      "attention=dot_product",
-      "skip_jax_distributed_system=true",
-  ]
+  argv = (
+      # First arg is usually script name in pyconfig
+      [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")]
+      + remaining_args
+      + ["attention=dot_product", "skip_jax_distributed_system=true"]
+  )
+  print(argv)

   # Initialize without heavyweight runtime
   config = pyconfig.initialize(argv)
+  print(f"\n--- Inspecting MaxText Architecture: {config.model_name} (Scan: {config.scan_layers}) ---")
   devices_array = maxtext_utils.create_device_mesh(config)
   mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)
   quant = quantizations.configure_quantization(config)
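
For concreteness, here is what the new argv construction yields for a hypothetical `maxtext model_name=deepseek3-671b scan_layers=false` invocation; any extra `key=value` pairs from the command line flow through to pyconfig unchanged (MAXTEXT_PKG_DIR stubbed for illustration):

```python
import os

MAXTEXT_PKG_DIR = "/path/to/maxtext"  # stub; resolved by the package in practice
remaining_args = ["model_name=deepseek3-671b", "scan_layers=false"]  # from the CLI

argv = (
    [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")]
    + remaining_args
    + ["attention=dot_product", "skip_jax_distributed_system=true"]
)
print(argv)
# [None, '/path/to/maxtext/configs/base.yml', 'model_name=deepseek3-671b',
#  'scan_layers=false', 'attention=dot_product', 'skip_jax_distributed_system=true']
```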
@@ -133,19 +129,23 @@ def inspect_maxtext(args):
   abstract_param = maxtext_utils.get_abstract_param(model, config)
   num_params = max_utils.calculate_num_params_from_pytree(abstract_param)

-  print(f"\nTotal Parameters: {num_params} (~{num_params / 1e9:.2f}B)")
+  print(f"\nTotal Parameters: {num_params} (~{num_params / 1e9:.2f}B)")
   print("\n=== Structure ===")

   abstract_params_flat, _ = jax.tree_util.tree_flatten_with_path(abstract_param)

-  flat_shapes = {}
+  param_dict = {}
+  # e.g. abstract_leaf_value: ShapeDtypeStruct(shape=(128, 58), dtype=float32)
   for path_tuple, abstract_leaf_value in abstract_params_flat:
     key_parts = [k.key for k in path_tuple if hasattr(k, "key")]
     # Construct MaxText style parameter key
-    mt_param_key = "params-" + "-".join(key_parts)
-    flat_shapes[mt_param_key] = abstract_leaf_value.shape
+    param_key = "params-" + "-".join(key_parts)
+    shape = abstract_leaf_value.shape
+    param_dict[param_key] = f"shape: {shape}"
+    dtype = abstract_leaf_value.dtype
+    param_dict[param_key] += f" | dtype: {dtype}"

-  print_structure(flat_shapes)
+  print_structure(param_dict)


 # ==============================================================================
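
The loop above turns a JAX key path into a flat `params-...` name. A self-contained toy illustration of what `tree_flatten_with_path` yields (the nested dict stands in for a real abstract model pytree):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for the abstract parameter pytree.
abstract_param = {"decoder": {"layers_0": {"kernel": jax.ShapeDtypeStruct((128, 58), jnp.float32)}}}

flat, _ = jax.tree_util.tree_flatten_with_path(abstract_param)
for path_tuple, leaf in flat:
  key_parts = [k.key for k in path_tuple if hasattr(k, "key")]
  print("params-" + "-".join(key_parts), "|", leaf.shape, "|", leaf.dtype)
# params-decoder-layers_0-kernel | (128, 58) | float32
```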
@@ -163,36 +163,38 @@ def inspect_orbax(args):

   path = epath.Path(args.path)

-  try:
-    # Depending on Orbax version, metadata access might vary slightly.
-    # This aligns with StandardCheckpointer usage.
-    metadata = ocp.StandardCheckpointer().metadata(path)
-    if hasattr(metadata, "item_metadata"):
-      metadata = metadata.item_metadata
-  except Exception as e:
-    sys.exit(f"Error reading Orbax metadata: {e}")
+  # Depending on Orbax version, metadata access might vary slightly.
+  # This aligns with StandardCheckpointer usage.
+  metadata = ocp.StandardCheckpointer().metadata(path)
+  if hasattr(metadata, "item_metadata"):
+    metadata = metadata.item_metadata

   # Convert to flat dict
   dictionary = ocp.tree.to_flat_dict(metadata)

   # Filter for params only and clean up keys
-  flat_shapes = {}
+  param_dict = {}
   for k, v in dictionary.items():
     # k is a tuple, join it. v is metadata object with .shape
-    key_str = ".".join(k)
-    if key_str.startswith("params"):
-      flat_shapes[key_str] = v.shape
+    param_key = ".".join(k)
+    if not param_key.startswith("params"):
+      continue
+    shape = v.shape
+    param_dict[param_key] = f"shape: {shape}"
+    dtype = v.dtype
+    param_dict[param_key] += f" | dtype: {dtype}"

   print("\n=== Structure ===")
-  print_structure(flat_shapes)
+  print_structure(param_dict)


 # ==============================================================================
 # Main CLI Driver
 # ==============================================================================
 def main():
   parser = argparse.ArgumentParser(description="Consolidated Model Checkpoint Inspector")
-  subparsers = parser.add_subparsers(dest="mode", required=True, help="Inspection mode")
+  subparsers = parser.add_subparsers(dest="mode", required=True, help="Inspection mode: hf, maxtext, orbax")

   # Mode 1: HuggingFace / PyTorch
   parser_hf = subparsers.add_parser("hf", help="Inspect .safetensors or .pth files")
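
With the old try/except removed, Orbax errors now surface with their original tracebacks instead of a generic exit message. The same metadata read can be exercised standalone; a sketch (path is illustrative):

```python
import orbax.checkpoint as ocp
from etils import epath

path = epath.Path("gs://my-bucket/checkpoints/0/items")  # illustrative
metadata = ocp.StandardCheckpointer().metadata(path)
if hasattr(metadata, "item_metadata"):
  metadata = metadata.item_metadata

for key_tuple, leaf in ocp.tree.to_flat_dict(metadata).items():
  print(".".join(key_tuple), leaf.shape, leaf.dtype)
```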
@@ -203,26 +205,17 @@ def main():

   # Mode 2: MaxText Architecture
   parser_mt = subparsers.add_parser("maxtext", help="Inspect MaxText theoretical architecture")
-  parser_mt.add_argument("--model_name", type=str, required=True, help="e.g. deepseek3-671b")
-  parser_mt.add_argument(
-      "--scan_layers",
-      type=str,
-      required=False,
-      default="true",
-      choices=["true", "false", "True", "False"],
-      help="Simulate scanned or unscanned structure",
-  )

   # Mode 3: Orbax
   parser_orbax = subparsers.add_parser("orbax", help="Inspect saved Orbax checkpoint metadata")
   parser_orbax.add_argument("--path", type=str, required=True, help="Path to checkpoint items (local or GCS)")

-  args = parser.parse_args()
+  args, remaining_args = parser.parse_known_args()

   if args.mode == "hf":
     inspect_hf(args)
   elif args.mode == "maxtext":
-    inspect_maxtext(args)
+    inspect_maxtext(args, remaining_args)
   elif args.mode == "orbax":
     inspect_orbax(args)

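The switch to `parse_known_args` is what enables the pass-through: argparse consumes the subcommand and its declared flags, and anything it does not recognize is returned verbatim for pyconfig. A minimal repro of that behavior:

```python
import argparse

parser = argparse.ArgumentParser()
sub = parser.add_subparsers(dest="mode", required=True)
sub.add_parser("maxtext")

args, remaining = parser.parse_known_args(["maxtext", "model_name=deepseek3-671b", "scan_layers=false"])
print(args.mode)   # maxtext
print(remaining)   # ['model_name=deepseek3-671b', 'scan_layers=false']
```

One consequence worth noting: MaxText overrides must now use pyconfig's `key=value` form; the old `--model_name`/`--scan_layers` flags are no longer accepted.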