Skip to content

Commit 0f4beb8

Browse files
committed
update debug sharding logic for decode
1 parent 4d486c4 commit 0f4beb8

7 files changed

Lines changed: 122 additions & 110 deletions

File tree

src/maxtext/common/common_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
AxisIdxes = tuple[int, ...]
3333

3434
BATCH = "activation_batch"
35+
BATCH_ATTN = "activation_batch_attn"
3536

3637
ATTN_LENGTH = "activation_attn_length"
3738

src/maxtext/configs/custom_mesh_and_rule/vllm-attn-ep.yml

Lines changed: 0 additions & 53 deletions
This file was deleted.

src/maxtext/configs/inference/vllm.yml

Lines changed: 64 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -29,59 +29,78 @@ weight_dtype: bfloat16
2929
# -------------- Logical Axis Rules --------------
3030
mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
3131
logical_axis_rules: [
32-
['activation_batch', ['data']],
33-
['activation_batch_moe', ['data']],
34-
['activation_embed_and_logits_batch', ['data', 'expert']],
35-
['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
32+
# ==========================================
33+
# Vocabulary Embedding
34+
# ==========================================
35+
# Vocab Activations
36+
['activation_embed_and_logits_batch', ['data']],
37+
['activation_embed_and_logits_batch_sequence', ['data']],
38+
['activation_vocab', ['model', 'expert', 'attn_dp', 'attn_dp_expert']],
39+
# Vocab Weights
40+
['vocab', ['model', 'expert', 'attn_dp', 'attn_dp_expert']],
41+
['embed_vocab', []],
42+
# ==========================================
43+
# Attention
44+
# ==========================================
45+
# Attention Activations
46+
['activation_batch_attn', ['data', 'attn_dp', 'attn_dp_expert']],
3647
['activation_heads', ['model', 'expert']],
3748
['activation_kv_heads', ['model', 'expert']],
38-
['activation_attn_length', []],
39-
['activation_length', []],
40-
['activation_length_moe', []],
41-
['activation_q_length', ['expert', 'attn_dp_expert']],
42-
['activation_attn_embed', 'model'],
43-
# Expert is missing explicitly from activation_embed despite using TP.
44-
# We are going for a replicate-AR style of TP as opposed to our typical AG-RS style of TP
45-
# due to the output sharding of the fused_moe_gmm kernel in tpu-inference.
46-
['activation_embed', ['model', 'attn_dp']],
47-
['activation_embed_moe', ['model', 'attn_dp']],
48-
['activation_mlp', ['model', 'attn_dp']],
49-
['activation_mlp_moe', ['model', 'attn_dp']],
50-
['activation_kv', ['model']],
51-
['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']],
52-
['activation_kv_batch', ['data']],
53-
['activation_kv_head_dim', ['model']],
54-
['activation_vocab', ['model', 'attn_dp']],
55-
['activation_norm_length', []],
56-
['activation_norm_length_moe', []],
57-
['activation_exp', ['expert', 'attn_dp_expert']],
58-
['decode_batch', ['data']],
59-
['decode_batch_moe', ['data']],
60-
['decode_length', []],
61-
['mlp', ['model', 'attn_dp']],
62-
['mlp_moe', ['model', 'attn_dp']],
63-
['mlp_no_fsdp', ['model', 'attn_dp']],
64-
['vocab', ['model', 'attn_dp']],
65-
# Expert is intended to act like TP for attention.
66-
# We target two all-reduces, one at the end of attention out projection and one at the end of the feedforward.
49+
['activation_attn_embed', []],
50+
['activation_kv', ['model', 'expert']],
51+
['activation_kv_batch', ['data', 'attn_dp', 'attn_dp_expert']],
52+
['activation_kv_head_dim', []],
53+
# Attention Weights
6754
['heads', ['model', 'expert']],
6855
['q_heads', ['model', 'expert']],
6956
['kv_heads', ['model', 'expert']],
70-
['kv_head_dim', []],
57+
['qkv', []],
7158
['kv', []],
72-
['embed', ['expert', 'attn_dp_expert']],
73-
['embed', ['attn_dp_expert']],
74-
['embed_vocab', ['expert', 'attn_dp_expert']],
75-
['embed_vocab', ['attn_dp_expert']],
76-
['embed_moe', []],
59+
['kv_head_dim', []],
60+
['q_lora', []],
61+
['q_lora_up_proj', []],
62+
['kv_lora', []],
63+
['kv_lora_up_proj', []],
64+
# ==========================================
65+
# Mixture of Experts (MoE)
66+
# ==========================================
67+
# MoE Activations
68+
['activation_batch_moe', ['data']],
69+
['activation_embed_moe', ['model']],
70+
['activation_mlp_moe', []],
71+
['activation_exp', ['expert', 'attn_dp', 'attn_dp_expert']],
72+
# MoE Weights
73+
['exp', ['expert', 'attn_dp', 'attn_dp_expert']],
74+
['mlp_moe', []],
7775
['embed_moe', []],
78-
['embed_tensor_transpose', ['attn_dp', 'model']],
79-
['q_lora', ['expert', 'attn_dp_expert']],
80-
['kv_lora', ['expert', 'attn_dp_expert']],
76+
# ==========================================
77+
# Standard MLP / Dense Layers / Model Structure
78+
# ==========================================
79+
# Dense Activations
80+
['activation_mlp', ['model', 'expert', 'attn_dp', 'attn_dp_expert']],
81+
# Note: the activation batch and length axes are also used in attention and vocab
82+
['activation_batch', ['data']],
83+
['activation_embed', ['model', 'expert', 'attn_dp', 'attn_dp_expert']],
84+
# General Weights
85+
['mlp', ['model', 'expert', 'attn_dp', 'attn_dp_expert']],
86+
['embed', []],
8187
['norm', []],
88+
# ==========================================
89+
# Inference (Prefill, Decode, Cache)
90+
# ==========================================
91+
['activation_prefill_kv_batch', ['data', 'attn_dp', 'attn_dp_expert']],
92+
['decode_batch', ['data', 'attn_dp', 'attn_dp_expert']],
93+
['cache_heads', ['model', 'expert']],
8294
['cache_heads', ['model']],
83-
['exp', ['expert', 'attn_dp_expert']],
84-
['paged_kv_heads', ['model']],
85-
]
95+
['paged_kv_heads', ['model', 'expert']],
96+
['cache_batch_prefill', []],
97+
['cache_batch', []],
98+
['cache_heads_none', []],
99+
['cache_kv', []],
100+
['cache_sequence', []],
101+
['num_pages', []],
102+
['tokens_per_page', []],
103+
['paged_kv_head_dim_size', []],
104+
]
86105
data_sharding: [['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']]
87106
input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']

src/maxtext/layers/attention_mla.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
Array,
3737
AxisIdxes,
3838
AxisNames,
39-
BATCH,
39+
BATCH_ATTN as BATCH,
4040
CACHE_BATCH,
4141
CACHE_BATCH_PREFILL,
4242
CACHE_SEQUENCE,

src/maxtext/layers/attention_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
AttentionType,
3939
AxisIdxes,
4040
AxisNames,
41-
BATCH,
41+
BATCH_ATTN as BATCH,
4242
CACHE_BATCH,
4343
CACHE_BATCH_PREFILL,
4444
CACHE_HEADS,

src/maxtext/layers/attentions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
from maxtext.common.common_types import (
2929
DecoderBlockType,
30-
BATCH,
30+
BATCH_ATTN as BATCH,
3131
HEAD,
3232
PREFILL_LENGTH,
3333
D_KV,

src/maxtext/utils/maxtext_utils.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
# limitations under the License.
1414

1515
# pylint: disable=line-too-long, disable=bare-except, consider-using-generator
16-
""" Utils that are only interesting to MaxText. """
16+
"""Utils that are only interesting to MaxText."""
1717

1818
import functools
1919
import pickle
2020
import os
2121

22-
from flax import linen as nn
22+
from flax import nnx, linen as nn
2323
from flax.linen import partitioning as nn_partitioning
2424
from flax.training import train_state
2525

@@ -1625,7 +1625,35 @@ def schedule(step):
16251625
return optax.join_schedules(pieces, boundaries)
16261626

16271627

1628-
def print_shardings_params(params, params_sharding, mesh, logical_annotations=None):
1628+
# def print_shardings_params(params, params_sharding, mesh, logical_annotations=None):
1629+
# """
1630+
# Print state shardings comparing Logical Definition vs Physical Result.
1631+
# """
1632+
# if not hasattr(params, "params"):
1633+
# params = {"params": params}
1634+
# if not hasattr(params_sharding, "params"):
1635+
# params_sharding = {"params": params_sharding}
1636+
# if logical_annotations and not hasattr(logical_annotations, "params"):
1637+
# logical_annotations = {"params": logical_annotations}
1638+
1639+
# leaves_params, _ = jax.tree_util.tree_flatten_with_path(params, is_leaf=lambda x: isinstance(x, nnx.Variable))
1640+
# leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding, is_leaf=lambda x: isinstance(x, nnx.Variable))
1641+
# leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations, is_leaf=lambda x: isinstance(x, nnx.Variable))
1642+
1643+
# for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(leaves_params, leaves_sharding, leaves_logical):
1644+
# path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
1645+
# shape = jax.typeof(getattr(leaf_val, "value"))
1646+
# pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
1647+
# pspec_str = str(tuple(pspec))
1648+
# logical_str = str(getattr(leaf_logical_val, "out_sharding", None))
1649+
1650+
# message = f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}"
1651+
# max_logging.info(message)
1652+
1653+
# print(flush=True)
1654+
1655+
1656+
def print_shardings_params(params, params_sharding, mesh, logical_annotations=None, target_layer=0):
16291657
"""
16301658
Print state shardings comparing Logical Definition vs Physical Result.
16311659
"""
@@ -1636,16 +1664,33 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No
16361664
if logical_annotations and not hasattr(logical_annotations, "params"):
16371665
logical_annotations = {"params": logical_annotations}
16381666

1639-
leaves_params, _ = jax.tree_util.tree_flatten_with_path(params)
1640-
leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding)
1641-
leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations)
1667+
leaves_params, _ = jax.tree_util.tree_flatten_with_path(params, is_leaf=lambda x: isinstance(x, nnx.Variable))
1668+
leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(
1669+
params_sharding, is_leaf=lambda x: isinstance(x, nnx.Variable)
1670+
)
1671+
leaves_logical, _ = jax.tree_util.tree_flatten_with_path(
1672+
logical_annotations, is_leaf=lambda x: isinstance(x, nnx.Variable)
1673+
)
16421674

16431675
for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(leaves_params, leaves_sharding, leaves_logical):
1644-
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
1645-
shape = jax.typeof(leaf_val)
1676+
# Extract path keys to accurately check for layer names
1677+
path_keys = [str(p.key if hasattr(p, "key") else p.name) for p in path]
1678+
path_str = "/".join(path_keys)
1679+
1680+
# Check if param is inside a layer block, and if it's the target layer
1681+
is_layer_param = any(k.startswith("layers_") for k in path_keys)
1682+
is_target_layer = any(k == f"layers_{target_layer}" for k in path_keys)
1683+
# Skip logging if it belongs to a layer that isn't our target
1684+
if is_layer_param and not is_target_layer:
1685+
continue
1686+
1687+
if "to_nnx__rngs" in path_str:
1688+
continue
1689+
1690+
shape = jax.typeof(getattr(leaf_val, "value"))
16461691
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
16471692
pspec_str = str(tuple(pspec))
1648-
logical_str = str(leaf_logical_val)
1693+
logical_str = str(getattr(leaf_logical_val, "out_sharding", None))
16491694

16501695
message = f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}"
16511696
max_logging.info(message)

0 commit comments

Comments
 (0)