|
18 | 18 | import functools |
19 | 19 | import pickle |
20 | 20 |
|
| 21 | +from collections import defaultdict |
| 22 | + |
21 | 23 | from flax import linen as nn |
22 | 24 | from flax.linen import partitioning as nn_partitioning |
23 | 25 | from flax.training import train_state |
|
26 | 28 |
|
27 | 29 | from jax.experimental import mesh_utils |
28 | 30 | from jax.experimental.serialize_executable import deserialize_and_load |
| 31 | +from jax.sharding import PartitionSpec as P |
29 | 32 |
|
30 | 33 | import jax |
31 | 34 | import jax.numpy as jnp |
@@ -1204,12 +1207,109 @@ def schedule(step): |
1204 | 1207 | return optax.join_schedules(pieces, boundaries) |
1205 | 1208 |
|
1206 | 1209 |
|
1207 | | -def print_state_mesh_shardings_params(state, state_sharding, mesh): |
| 1210 | +def print_state_mesh_shardings_params( |
| 1211 | + state, state_sharding, state_logical_annotations, mesh, logical_axis_rules |
| 1212 | +): |
1208 | 1213 | """Print state shardings.""" |
| 1214 | + if (not hasattr(state, 'params') or |
| 1215 | + not hasattr(state_sharding, 'params') or |
| 1216 | + not hasattr(state_logical_annotations, 'params')): |
| 1217 | + max_logging.warning( |
| 1218 | + "Warning: 'params' attribute missing in one of the inputs to " |
| 1219 | + "print_state_mesh_shardings_params." |
| 1220 | + ) |
| 1221 | + return |
| 1222 | + |
1209 | 1223 | leaves_params, _ = jax.tree_util.tree_flatten_with_path(state.params) |
1210 | 1224 | leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(state_sharding.params) |
1211 | | - for (path, leaf_val), (_, leaf_sharding) in zip(leaves_params, leaves_sharding): |
| 1225 | + leaves_rule_values, _ = jax.tree_util.tree_flatten_with_path( |
| 1226 | + state_logical_annotations.params |
| 1227 | + ) |
| 1228 | + |
| 1229 | + if not len(leaves_params) == len(leaves_sharding) == len(leaves_rule_values): |
| 1230 | + max_logging.warning( |
| 1231 | + "Warning: Parameter tree structure mismatch between state, shardings," |
| 1232 | + " and logical annotations." |
| 1233 | + ) |
| 1234 | + return |
| 1235 | + |
| 1236 | + # Build a reverse map (Potential Physical Axes Tuple -> List of Semantic Names) |
| 1237 | + rule_value_to_semantic = defaultdict(list) |
| 1238 | + if logical_axis_rules: |
| 1239 | + rules_iter = ( |
| 1240 | + logical_axis_rules.items() |
| 1241 | + if isinstance(logical_axis_rules, dict) |
| 1242 | + else logical_axis_rules |
| 1243 | + ) |
| 1244 | + for name, potentials in rules_iter: |
| 1245 | + # name: LHS, e.g. 'embed' or 'activation_batch' |
| 1246 | + # potentials: RHS, e.g. 'data', 'model', None, or ['data', 'model'] |
| 1247 | + if isinstance(potentials, str): |
| 1248 | + key = (potentials,) |
| 1249 | + elif potentials is None: |
| 1250 | + key = (None,) |
| 1251 | + elif isinstance(potentials, list): |
| 1252 | + key = tuple(potentials) |
| 1253 | + elif isinstance(potentials, tuple): |
| 1254 | + key = potentials |
| 1255 | + else: |
| 1256 | + key = (potentials,) |
| 1257 | + |
| 1258 | + key = tuple(p for p in key) |
| 1259 | + rule_value_to_semantic[key].append(name) |
| 1260 | + |
| 1261 | + # Header for the entire block |
| 1262 | + max_logging.info("Parameter Path") |
| 1263 | + max_logging.info("Shape") |
| 1264 | + max_logging.info("Logical Axes") |
| 1265 | + max_logging.info("Physical PartitionSpec") |
| 1266 | + max_logging.info("-" * 120) |
| 1267 | + |
| 1268 | + for (path, leaf_val), (_, leaf_sharding), (_, leaf_rule_value) in zip( |
| 1269 | + leaves_params, leaves_sharding, leaves_rule_values |
| 1270 | + ): |
1212 | 1271 | path_str = "/".join(str(p.key) for p in path) |
1213 | | - shape = jax.typeof(leaf_val) |
| 1272 | + shape = str(jax.typeof(leaf_val)) |
| 1273 | + |
| 1274 | + # Physical PartitionSpec from NamedSharding |
1214 | 1275 | pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) |
1215 | | - max_logging.log(f"{path_str:.<80} {shape} {tuple(pspec)}") |
| 1276 | + pspec_str = str(tuple(pspec)) |
| 1277 | + |
| 1278 | + def get_semantic_names(rule_val_item, rmap): |
| 1279 | + if rule_val_item is None: |
| 1280 | + key = (None,) |
| 1281 | + elif isinstance(rule_val_item, str): |
| 1282 | + key = (rule_val_item,) |
| 1283 | + elif isinstance(rule_val_item, tuple): |
| 1284 | + key = rule_val_item |
| 1285 | + else: |
| 1286 | + return f"'{str(rule_val_item)}'" |
| 1287 | + |
| 1288 | + key = tuple(p for p in key) |
| 1289 | + names = rmap.get(key) |
| 1290 | + |
| 1291 | + if names: |
| 1292 | + return "{" + " | ".join(sorted(list(set(names)))) + "}" |
| 1293 | + else: |
| 1294 | + # Show rule value if unmapped. |
| 1295 | + return f"'{str(key)}'" |
| 1296 | + |
| 1297 | + # Logical Axes string representation |
| 1298 | + if isinstance(leaf_rule_value, P): |
| 1299 | + semantic_parts = [] |
| 1300 | + for s in leaf_rule_value: |
| 1301 | + name_str = get_semantic_names(s, rule_value_to_semantic) |
| 1302 | + semantic_parts.append(str(name_str)) |
| 1303 | + semantic_str = "PartitionSpec(" + ", ".join(semantic_parts) + ")" |
| 1304 | + elif leaf_rule_value is None: |
| 1305 | + semantic_str = "PartitionSpec(None)" |
| 1306 | + else: # Should not be common |
| 1307 | + semantic_str = str(leaf_rule_value) |
| 1308 | + |
| 1309 | + # Multi-line logging for each parameter |
| 1310 | + max_logging.info(f"{path_str}") |
| 1311 | + max_logging.info(f"{shape}") |
| 1312 | + max_logging.info(f"{semantic_str}") |
| 1313 | + max_logging.info(f"{pspec_str}") |
| 1314 | + max_logging.info("-" * 120) |
| 1315 | + print(flush=True) |
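For context, here is a minimal, self-contained sketch of the reverse lookup the new helper builds: it inverts `logical_axis_rules` into a map from mesh-axes tuples to logical names, then labels each axis of a parameter's spec with the logical names that could have produced it. The rule values below ('embed', 'activation_batch', 'norm', 'model', 'data', 'fsdp') are hypothetical placeholders, not the project's actual config.

```python
from collections import defaultdict

# Hypothetical rules in (logical_name, mesh_axes) form; real values come from the config.
logical_axis_rules = [
    ("embed", "model"),
    ("activation_batch", ["data", "fsdp"]),
    ("norm", None),
]

# Reverse map: mesh-axes tuple -> logical names that resolve to it.
rule_value_to_semantic = defaultdict(list)
for name, potentials in logical_axis_rules:
    key = tuple(potentials) if isinstance(potentials, (list, tuple)) else (potentials,)
    rule_value_to_semantic[key].append(name)

# Label each axis; unmapped values fall back to their raw representation.
for axes in [("model",), ("data", "fsdp"), (None,), ("unmapped",)]:
    names = rule_value_to_semantic.get(axes)
    label = "{" + " | ".join(sorted(set(names))) + "}" if names else f"'{axes}'"
    print(axes, "->", label)
```

The same lookup drives the "Logical Axes" column in the new per-parameter log block, alongside the parameter path, shape, and physical PartitionSpec.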