We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents 95ed966 + 8546ebf commit fa4e1e7Copy full SHA for fa4e1e7
1 file changed
src/MaxText/sharding.py
@@ -78,7 +78,7 @@ def maybe_shard_with_logical(
78
named_sharding = create_sharding(mesh, logical_axes, rules=rules)
79
80
if debug_sharding and isinstance(inputs, Tracer):
81
- log_key = (str(jax.typeof(inputs)), logical_axes, extra_stack_level)
+ log_key = (str(jax.typeof(inputs)), tuple(logical_axes), extra_stack_level)
82
83
if log_key not in _LOGGED_LOGICAL_AXES:
84
max_logging.info(f"Logical: {log_key[0]:.<60} {log_key[1]}", stacklevel=3 + extra_stack_level)
0 commit comments