Skip to content

Commit c64953d

Browse files
committed
lint
1 parent 9af3c53 commit c64953d

1 file changed

Lines changed: 4 additions & 1 deletion

File tree

  • src/maxtext/integration/vllm/maxtext_vllm_adapter

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,10 @@ def generate_maxtext_config(vllm_config: VllmConfig, mesh: Mesh) -> pyconfig.Hyp
8282
else:
8383
max_tp_size = mesh.shape[ShardingAxisName.ATTN_HEAD]
8484

85-
if vllm_config.model_config.get_total_num_kv_heads() < max_tp_size:
85+
if (
86+
max_tp_size % vllm_config.model_config.get_total_num_kv_heads() == 0
87+
and vllm_config.model_config.get_total_num_kv_heads() < max_tp_size
88+
):
8689
max_logging.log(
8790
f"Padding num_kv_heads from {vllm_config.model_config.get_total_num_kv_heads()} to {max_tp_size} to match tp_size."
8891
)

0 commit comments

Comments
 (0)