@@ -1563,26 +1563,41 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No
15631563 """
15641564 Print state shardings comparing Logical Definition vs Physical Result.
15651565 """
1566- if not hasattr (params , "params" ):
1567- params = {"params" : params }
1568- if not hasattr (params_sharding , "params" ):
1569- params_sharding = {"params" : params_sharding }
1570- if logical_annotations and not hasattr (logical_annotations , "params" ):
1571- logical_annotations = {"params" : logical_annotations }
1566+ if not isinstance (params , nnx .State ):
1567+ if not hasattr (params , "params" ):
1568+ params = {"params" : params }
1569+ if not hasattr (params_sharding , "params" ):
1570+ params_sharding = {"params" : params_sharding }
1571+ if logical_annotations and not hasattr (logical_annotations , "params" ):
1572+ logical_annotations = {"params" : logical_annotations }
15721573
15731574 leaves_params , _ = jax .tree_util .tree_flatten_with_path (params )
15741575 leaves_sharding , _ = jax .tree_util .tree_flatten_with_path (params_sharding )
1575- leaves_logical , _ = jax .tree_util .tree_flatten_with_path (logical_annotations )
15761576
1577- for (path , leaf_val ), (_ , leaf_sharding ), (_ , leaf_logical_val ) in zip (leaves_params , leaves_sharding , leaves_logical ):
1578- path_str = "/" .join (str (p .key if hasattr (p , "key" ) else p .name ) for p in path )
1579- shape = jax .typeof (leaf_val )
1580- pspec = sharding .remove_size_one_mesh_axis (leaf_sharding .spec , mesh )
1581- pspec_str = str (tuple (pspec ))
1582- logical_str = str (leaf_logical_val )
1583-
1584- message = f" { path_str } \n " f" Shape: { shape } \n " f" Logical: { logical_str } \n " f" Physical: { pspec_str } "
1585- max_logging .info (message )
1577+ if logical_annotations is not None :
1578+ leaves_logical , _ = jax .tree_util .tree_flatten_with_path (logical_annotations )
1579+ for (path , leaf_val ), (_ , leaf_sharding ), (_ , leaf_logical_val ) in zip (
1580+ leaves_params , leaves_sharding , leaves_logical
1581+ ):
1582+ path_str = "/" .join (str (p .key if hasattr (p , "key" ) else p .name ) for p in path )
1583+ shape = jax .typeof (leaf_val )
1584+ pspec = sharding .remove_size_one_mesh_axis (leaf_sharding .spec , mesh )
1585+ pspec_str = str (tuple (pspec ))
1586+ logical_str = str (leaf_logical_val )
1587+
1588+ message = (
1589+ f" { path_str } \n " f" Shape: { shape } \n " f" Logical: { logical_str } \n " f" Physical: { pspec_str } "
1590+ )
1591+ max_logging .info (message )
1592+ else :
1593+ for (path , leaf_val ), (_ , leaf_sharding ) in zip (leaves_params , leaves_sharding ):
1594+ path_str = "/" .join (str (p .key if hasattr (p , "key" ) else p .name ) for p in path )
1595+ shape = jax .typeof (leaf_val )
1596+ pspec = sharding .remove_size_one_mesh_axis (leaf_sharding .spec , mesh )
1597+ pspec_str = str (tuple (pspec ))
1598+
1599+ message = f" { path_str } \n " f" Shape: { shape } \n " f" Physical: { pspec_str } "
1600+ max_logging .info (message )
15861601
15871602 print (flush = True )
15881603
0 commit comments