@@ -1253,34 +1253,17 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No
12531253 leaves_sharding , _ = jax .tree_util .tree_flatten_with_path (params_sharding )
12541254 leaves_logical , _ = jax .tree_util .tree_flatten_with_path (logical_annotations .params )
12551255
1256- for i , (( path , leaf_val ), (_ , leaf_sharding )) in enumerate ( zip (leaves_params , leaves_sharding ) ):
1256+ for ( path , leaf_val ), (_ , leaf_sharding ), ( _ , leaf_logical_val ) in zip (leaves_params , leaves_sharding , leaves_logical ):
12571257 path_str = "/" .join (str (p .key if hasattr (p , "key" ) else p .name ) for p in path )
12581258 shape = jax .typeof (leaf_val )
12591259 pspec = sharding .remove_size_one_mesh_axis (leaf_sharding .spec , mesh )
12601260 pspec_str = str (tuple (pspec ))
1261+ logical_str = str (leaf_logical_val )
12611262
1262- if not has_logical :
1263- leaves_logical = [( None , None )] * len ( leaves_params )
1263+ message = f" { path_str } \n " f" Shape: { shape } \n " f" Logical: { logical_str } \n " f" Physical: { pspec_str } "
1264+ max_logging . info ( message )
12641265
1265- if len (leaves_params ) != len (leaves_sharding ):
1266- max_logging .warning ("Warning: Params and Sharding tree mismatch." )
1267- return
1268-
1269- for i , (path , leaf_val ) in enumerate (leaves_params ):
1270- _ , leaf_sharding = leaves_sharding [i ]
1271- leaf_logical_val = leaves_logical [i ][1 ] if has_logical else None
1272-
1273- path_str = "/" .join (str (p .key if hasattr (p , "key" ) else getattr (p , "name" , "?" )) for p in path )
1274-
1275- shape = str (jax .typeof (leaf_val ))
1276-
1277- for (path , leaf_val ), (_ , leaf_sharding ), (_ , leaf_rule_value ) in zip (
1278- leaves_params , leaves_sharding , leaves_rule_values
1279- ):
1280- path_str = "/" .join (str (p .key if hasattr (p , "key" ) else p .name ) for p in path )
1281- shape = jax .typeof (leaf_val )
1282- pspec = sharding .remove_size_one_mesh_axis (leaf_sharding .spec , mesh )
1283- max_logging .log (f"{ path_str :.<80} { shape } { tuple (pspec )} " )
1266+ print (flush = True )
12841267
12851268
12861269def maybe_dump_jaxpr (config , p_train_step , train_step_inputs ):
0 commit comments