Skip to content

Commit 0425801

Browse files
authored
[#15613][fix] Gemma4 multimodal: fix vision TP and xgrammar startup crashes (#15566)
Signed-off-by: Thach Nguyen <thach@deepinfra.com>
1 parent 2e33221 commit 0425801

2 files changed

Lines changed: 10 additions & 2 deletions

File tree

tensorrt_llm/_torch/models/modeling_gemma4_vision.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -949,8 +949,12 @@ def _pad_attention_head_dim(self, weights: Dict[str, torch.Tensor]) -> Dict[str,
949949

950950
hf_hd = first_attn.hf_head_dim
951951
padded_hd = first_attn.head_dim
952-
nh = first_attn.num_heads
953-
nkv = first_attn.num_key_value_heads
952+
# first_attn.num_heads / num_key_value_heads are already divided by
953+
# tp_size, but these weights are still unsharded here, so read the full
954+
# head counts from the vision config (no-op at tp1).
955+
vc = first_attn.vision_config
956+
nh = vc.num_attention_heads
957+
nkv = getattr(vc, "num_key_value_heads", nh)
954958
pad_w = padded_hd - hf_hd
955959

956960
# HF keys at this point have already had ``.linear.`` stripped if the

tensorrt_llm/_torch/models/modeling_gemma4mm.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -768,6 +768,10 @@ def post_config(self):
768768
def infer_max_seq_len(self) -> int:
769769
return self.llm.infer_max_seq_len()
770770

771+
@property
772+
def vocab_size_padded(self) -> int:
773+
return self.llm.vocab_size_padded
774+
771775
@property
772776
def multimodal_data_device_paths(self) -> List[str]:
773777
"""Dotted paths in ``multimodal_data`` that the engine should ship to

0 commit comments

Comments
 (0)