@@ -1207,40 +1207,32 @@ def schedule(step):
12071207 return optax .join_schedules (pieces , boundaries )
12081208
12091209
1210- def print_state_mesh_shardings_params (
1211- state , state_sharding , state_logical_annotations , mesh , logical_axis_rules
1212- ):
1210+ def print_state_mesh_shardings_params (state , state_sharding , state_logical_annotations , mesh , logical_axis_rules ):
12131211 """Print state shardings."""
1214- if (not hasattr (state , 'params' ) or
1215- not hasattr (state_sharding , 'params' ) or
1216- not hasattr (state_logical_annotations , 'params' )):
1212+ if (
1213+ not hasattr (state , "params" )
1214+ or not hasattr (state_sharding , "params" )
1215+ or not hasattr (state_logical_annotations , "params" )
1216+ ):
12171217 max_logging .warning (
1218- "Warning: 'params' attribute missing in one of the inputs to "
1219- "print_state_mesh_shardings_params."
1218+ "Warning: 'params' attribute missing in one of the inputs to " "print_state_mesh_shardings_params."
12201219 )
12211220 return
12221221
12231222 leaves_params , _ = jax .tree_util .tree_flatten_with_path (state .params )
12241223 leaves_sharding , _ = jax .tree_util .tree_flatten_with_path (state_sharding .params )
1225- leaves_rule_values , _ = jax .tree_util .tree_flatten_with_path (
1226- state_logical_annotations .params
1227- )
1224+ leaves_rule_values , _ = jax .tree_util .tree_flatten_with_path (state_logical_annotations .params )
12281225
12291226 if not len (leaves_params ) == len (leaves_sharding ) == len (leaves_rule_values ):
12301227 max_logging .warning (
1231- "Warning: Parameter tree structure mismatch between state, shardings,"
1232- " and logical annotations."
1228+ "Warning: Parameter tree structure mismatch between state, shardings," " and logical annotations."
12331229 )
12341230 return
12351231
12361232 # Build a reverse map (Potential Physical Axes Tuple -> List of Semantic Names)
12371233 rule_value_to_semantic = defaultdict (list )
12381234 if logical_axis_rules :
1239- rules_iter = (
1240- logical_axis_rules .items ()
1241- if isinstance (logical_axis_rules , dict )
1242- else logical_axis_rules
1243- )
1235+ rules_iter = logical_axis_rules .items () if isinstance (logical_axis_rules , dict ) else logical_axis_rules
12441236 for name , potentials in rules_iter :
12451237 # name: LHS for example 'embed/activation_batch
12461238 # potentials: RHS for example 'data', 'model', None, ['data', 'model']
@@ -1289,7 +1281,7 @@ def get_semantic_names(rule_val_item, rmap):
12891281 names = rmap .get (key )
12901282
12911283 if names :
1292- return "{" + " | " .join (sorted (list (set (names )))) + "}"
1284+ return "{" + " | " .join (sorted (list (set (names )))) + "}"
12931285 else :
12941286 # Show rule value if unmapped.
12951287 return f"'{ str (key )} '"
@@ -1302,8 +1294,8 @@ def get_semantic_names(rule_val_item, rmap):
13021294 semantic_parts .append (str (name_str ))
13031295 semantic_str = "Partitionspec(" + ", " .join (semantic_parts ) + ")"
13041296 elif leaf_rule_value is None :
1305- semantic_str = "Partitionspec(None)"
1306- else : # Should not be common
1297+ semantic_str = "Partitionspec(None)"
1298+ else : # Should not be common
13071299 semantic_str = str (leaf_rule_value )
13081300
13091301 # Multi-line logging for each parameter
0 commit comments