@@ -1653,26 +1653,41 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No
16531653 """
16541654 Print state shardings comparing Logical Definition vs Physical Result.
16551655 """
1656- if not hasattr (params , "params" ):
1657- params = {"params" : params }
1658- if not hasattr (params_sharding , "params" ):
1659- params_sharding = {"params" : params_sharding }
1660- if logical_annotations and not hasattr (logical_annotations , "params" ):
1661- logical_annotations = {"params" : logical_annotations }
1656+ if not isinstance (params , nnx .State ):
1657+ if not hasattr (params , "params" ):
1658+ params = {"params" : params }
1659+ if not hasattr (params_sharding , "params" ):
1660+ params_sharding = {"params" : params_sharding }
1661+ if logical_annotations and not hasattr (logical_annotations , "params" ):
1662+ logical_annotations = {"params" : logical_annotations }
16621663
16631664 leaves_params , _ = jax .tree_util .tree_flatten_with_path (params )
16641665 leaves_sharding , _ = jax .tree_util .tree_flatten_with_path (params_sharding )
1665- leaves_logical , _ = jax .tree_util .tree_flatten_with_path (logical_annotations )
16661666
1667- for (path , leaf_val ), (_ , leaf_sharding ), (_ , leaf_logical_val ) in zip (leaves_params , leaves_sharding , leaves_logical ):
1668- path_str = "/" .join (str (p .key if hasattr (p , "key" ) else p .name ) for p in path )
1669- shape = jax .typeof (leaf_val )
1670- pspec = sharding .remove_size_one_mesh_axis (leaf_sharding .spec , mesh )
1671- pspec_str = str (tuple (pspec ))
1672- logical_str = str (leaf_logical_val )
1673-
1674- message = f" { path_str } \n " f" Shape: { shape } \n " f" Logical: { logical_str } \n " f" Physical: { pspec_str } "
1675- max_logging .info (message )
1667+ if logical_annotations is not None :
1668+ leaves_logical , _ = jax .tree_util .tree_flatten_with_path (logical_annotations )
1669+ for (path , leaf_val ), (_ , leaf_sharding ), (_ , leaf_logical_val ) in zip (
1670+ leaves_params , leaves_sharding , leaves_logical
1671+ ):
1672+ path_str = "/" .join (str (p .key if hasattr (p , "key" ) else p .name ) for p in path )
1673+ shape = jax .typeof (leaf_val )
1674+ pspec = sharding .remove_size_one_mesh_axis (leaf_sharding .spec , mesh )
1675+ pspec_str = str (tuple (pspec ))
1676+ logical_str = str (leaf_logical_val )
1677+
1678+ message = (
1679+ f" { path_str } \n " f" Shape: { shape } \n " f" Logical: { logical_str } \n " f" Physical: { pspec_str } "
1680+ )
1681+ max_logging .info (message )
1682+ else :
1683+ for (path , leaf_val ), (_ , leaf_sharding ) in zip (leaves_params , leaves_sharding ):
1684+ path_str = "/" .join (str (p .key if hasattr (p , "key" ) else p .name ) for p in path )
1685+ shape = jax .typeof (leaf_val )
1686+ pspec = sharding .remove_size_one_mesh_axis (leaf_sharding .spec , mesh )
1687+ pspec_str = str (tuple (pspec ))
1688+
1689+ message = f" { path_str } \n " f" Shape: { shape } \n " f" Physical: { pspec_str } "
1690+ max_logging .info (message )
16761691
16771692 print (flush = True )
16781693
0 commit comments