Skip to content

Commit fa4e1e7

Browse files
Merge pull request #3148 from AI-Hypercomputer:chengnuojin-fix-debug
PiperOrigin-RevId: 871547910
2 parents 95ed966 + 8546ebf commit fa4e1e7

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/MaxText/sharding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def maybe_shard_with_logical(
7878
named_sharding = create_sharding(mesh, logical_axes, rules=rules)
7979

8080
if debug_sharding and isinstance(inputs, Tracer):
81-
log_key = (str(jax.typeof(inputs)), logical_axes, extra_stack_level)
81+
log_key = (str(jax.typeof(inputs)), tuple(logical_axes), extra_stack_level)
8282

8383
if log_key not in _LOGGED_LOGICAL_AXES:
8484
max_logging.info(f"Logical: {log_key[0]:.<60} {log_key[1]}", stacklevel=3 + extra_stack_level)

0 commit comments

Comments
 (0)