Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
402 changes: 402 additions & 0 deletions requirements.txt

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/maxtext/common/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
AxisIdxes = tuple[int, ...]

BATCH = "activation_batch"
BATCH_ATTN = "activation_batch_attn"

ATTN_LENGTH = "activation_attn_length"

Expand Down
3 changes: 2 additions & 1 deletion src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,7 @@ ici_tensor_sequence_parallelism: 1
ici_autoregressive_parallelism: 1
ici_pipeline_parallelism: 1
ici_expert_parallelism: 1
ici_attn_dp_expert_parallelism: 1

# Enable ZeRO-1 optimizer sharding over data axis
shard_optimizer_over_data: False
Expand Down Expand Up @@ -985,7 +986,7 @@ xprof_e2e_enable_fw_power_level_event: False
xprof_e2e_enable_fw_thermal_event: False
profile_power_events: False # Set to True to enable TPU-specific power/thermal profiling events. Defaults to False to avoid breaking GPU xplane tracing.

log_config: True # Prints the config (after defaults have been set by pyconfig logic)
log_config: False # Prints the config (after defaults have been set by pyconfig logic)
debug_sharding: False # Prints model weights sharding info

# Checkpoint Structured logging
Expand Down
107 changes: 65 additions & 42 deletions src/maxtext/configs/inference/vllm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,55 +29,78 @@ weight_dtype: bfloat16
# -------------- Logical Axis Rules --------------
mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
logical_axis_rules: [
['activation_batch', ['data']],
['activation_batch_moe', []],
['activation_embed_and_logits_batch', ['data', 'expert']],
['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
# ==========================================
# Vocabulary Embedding
# ==========================================
# Vocab Activations
['activation_embed_and_logits_batch', ['data']],
['activation_embed_and_logits_batch_sequence', ['data']],
['activation_vocab', ['model', 'expert', 'attn_dp', 'attn_dp_expert']],
# Vocab Weights
['vocab', ['model', 'expert', 'attn_dp', 'attn_dp_expert']],
['embed_vocab', []],
# ==========================================
# Attention
# ==========================================
# Attention Activations
['activation_batch_attn', ['data', 'attn_dp', 'attn_dp_expert']],
['activation_heads', ['model', 'expert']],
['activation_kv_heads', ['model', 'expert']],
['activation_attn_length', []],
['activation_length', ['data']],
['activation_length_moe', ['data', 'expert']],
['activation_length_moe', 'data'],
['activation_q_length', ['expert', 'attn_dp_expert']],
['activation_attn_embed', 'model'],
['activation_embed', ['model', 'attn_dp']],
['activation_embed_moe', ['model', 'attn_dp']],
['activation_mlp', ['model', 'attn_dp']],
['activation_mlp_moe', ['model', 'attn_dp']],
['activation_kv', ['model']],
['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']],
['activation_kv_batch', ['data']],
['activation_kv_head_dim', ['model']],
['activation_vocab', ['model', 'attn_dp']],
['activation_norm_length', []],
['activation_norm_length_moe', []],
['activation_exp', ['expert', 'attn_dp_expert']],
['decode_batch', ['expert', 'attn_dp_expert']],
['decode_batch_moe', []],
['decode_length', []],
['mlp', ['model', 'attn_dp']],
['mlp_moe', ['model', 'attn_dp']],
['mlp_no_fsdp', ['model', 'attn_dp']],
['vocab', ['model', 'attn_dp']],
['heads', ['model']],
['activation_attn_embed', []],
['activation_kv', ['model', 'expert']],
['activation_kv_batch', ['data', 'attn_dp', 'attn_dp_expert']],
['activation_kv_head_dim', []],
# Attention Weights
['heads', ['model', 'expert']],
['q_heads', ['model', 'expert']],
['kv_heads', ['model', 'expert']],
['kv_head_dim', []],
['qkv', []],
['kv', []],
['embed', ['expert', 'attn_dp_expert']],
['embed', ['attn_dp_expert']],
['embed_vocab', ['expert', 'attn_dp_expert']],
['embed_vocab', ['attn_dp_expert']],
['embed_moe', []],
['kv_head_dim', []],
['q_lora', []],
["q_lora_up_proj", []],
['kv_lora', []],
["kv_lora_up_proj", []],
# ==========================================
# Mixture of Experts (MoE)
# ==========================================
# MoE Activations
['activation_batch_moe', ['data']],
['activation_embed_moe', ['model']],
['activation_mlp_moe', []],
['activation_exp', ['expert', 'attn_dp', 'attn_dp_expert']],
# MoE Weights
['exp', ['expert', 'attn_dp', 'attn_dp_expert']],
['mlp_moe', []],
['embed_moe', []],
['embed_tensor_transpose', ['attn_dp', 'model']],
['q_lora', ['expert', 'attn_dp_expert']],
['kv_lora', ['expert', 'attn_dp_expert']],
# ==========================================
# Standard MLP / Dense Layers / Model Structure
# ==========================================
# Dense Activations
['activation_mlp', ['model', 'expert', 'attn_dp', 'attn_dp_expert']],
# Note: activation batch and length are also used in attention and vocab.
['activation_batch', ['data']],
['activation_embed', ['model', 'expert', 'attn_dp', 'attn_dp_expert']],
# General Weights
['mlp', ['model', 'expert', 'attn_dp', 'attn_dp_expert']],
['embed', []],
['norm', []],
# ==========================================
# Inference (Prefill, Decode, Cache)
# ==========================================
['activation_prefill_kv_batch', ['data', 'attn_dp', 'attn_dp_expert']],
['decode_batch', ['data', 'attn_dp', 'attn_dp_expert']],
['cache_heads', ['model', 'expert']],
['cache_heads', ['model']],
['exp', ['expert', 'attn_dp_expert']],
['paged_kv_heads', ['model']],
]
['paged_kv_heads', ['model', 'expert']],
['cache_batch_prefill', []],
['cache_batch', []],
['cache_heads_none', []],
['cache_kv', []],
['cache_sequence', []],
['num_pages', []],
['tokens_per_page', []],
['paged_kv_head_dim_size', []],
]
data_sharding: [['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']]
input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']
8 changes: 7 additions & 1 deletion src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,11 @@ class MoEGeneral(BaseModel):
False,
description="Whether to cast inputs to fp32 to compute MoE gate logits for numerical stability.",
)
prefuse_moe_weights: bool = Field(
False,
description="Whether to pre-fuse MoE weights (w0 and w1) during initialization. "
"This is useful for inference performance in vllm_rpa mode.",
)


class MoEKernels(BaseModel):
Expand Down Expand Up @@ -881,6 +886,7 @@ class IciParallelism(BaseModel):
ici_autoregressive_parallelism: int = Field(1, description="ICI axis for autoregressive parallelism.")
ici_pipeline_parallelism: int = Field(1, description="ICI axis for pipeline parallelism.")
ici_expert_parallelism: int = Field(1, description="ICI axis for expert parallelism.")
ici_attn_dp_expert_parallelism: int = Field(1, description="ICI axis for attn dp expert parallelism.")


class PipelineParallelism(BaseModel):
Expand Down Expand Up @@ -2741,7 +2747,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
"expert": self.ici_expert_parallelism,
"autoregressive": self.ici_autoregressive_parallelism,
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
"attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP
"attn_dp_expert": self.ici_attn_dp_expert_parallelism,
}
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]

Expand Down
4 changes: 4 additions & 0 deletions src/maxtext/inference/vllm_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def decode_with_vllm(config: Config) -> None:
"weight_dtype": "bfloat16",
"allow_split_physical_axes": True,
"debug_sharding": config.debug_sharding,
"prefuse_moe_weights": config.prefuse_moe_weights,
},
"sharding": {
"sharding_strategy": {
Expand All @@ -99,6 +100,9 @@ def decode_with_vllm(config: Config) -> None:
enable_expert_parallel = config.ici_expert_parallelism > 1
if enable_expert_parallel:
vllm_args["additional_config"]["sharding"]["sharding_strategy"]["expert_parallelism"] = config.ici_expert_parallelism
vllm_args["additional_config"]["sharding"]["sharding_strategy"][
"attention_data_expert_parallelism"
] = config.ici_attn_dp_expert_parallelism
vllm_args["enable_expert_parallel"] = enable_expert_parallel

max_logging.log(
Expand Down
11 changes: 7 additions & 4 deletions src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
# Model creation
self.model: nnx.Module | None = None

# Indicates that the model handles its own sharding logic
self._self_manages_sharding = True

# Handle dummy weight loading during initialization
if vllm_config.load_config.load_format == "dummy":
self.load_weights(rng_key)
Expand Down Expand Up @@ -161,8 +164,8 @@ def __call__(
raise ValueError("Model must be an instance of type nnx.Module.")

# Ensure inputs are at least 2D with a batch dimension
input_ids = jnp.atleast_2d(input_ids)
input_positions = jnp.atleast_2d(attention_metadata.input_positions)
input_ids = jnp.expand_dims(input_ids, axis=1)
input_positions = jnp.expand_dims(attention_metadata.input_positions, axis=1)

with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
aux_hidden_states = []
Expand Down Expand Up @@ -233,7 +236,7 @@ def compute_logits(self, hidden_states: jax.Array) -> jax.Array:

with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
# Reshape to (num_tokens, 1, hidden_dim) for decoder output head
y = hidden_states[:, jnp.newaxis, :]
y = jnp.expand_dims(hidden_states, axis=1)

# Compute logits using the MaxText decoder's output head
logits = self.model.decoder.apply_output_head(self.model.token_embedder, y, True, self.model_mode)
Expand All @@ -250,7 +253,7 @@ def load_weights(self, rng_key: jax.Array) -> None:
if self.model is not None:
return

with self.mesh, nn.logical_axis_rules(""):
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
model, _ = model_creation_utils.create_nnx_model(
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
)
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/layers/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
Array,
AxisIdxes,
AxisNames,
BATCH,
BATCH_ATTN as BATCH,
CACHE_BATCH,
CACHE_BATCH_PREFILL,
CACHE_SEQUENCE,
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/layers/attention_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
AttentionType,
AxisIdxes,
AxisNames,
BATCH,
BATCH_ATTN as BATCH,
CACHE_BATCH,
CACHE_BATCH_PREFILL,
CACHE_HEADS,
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from maxtext.common.common_types import (
DecoderBlockType,
BATCH,
BATCH_ATTN as BATCH,
HEAD,
PREFILL_LENGTH,
D_KV,
Expand Down
Loading
Loading