Skip to content

Commit e1f2cc9

Browse files
authored
[mm] fix broken MRoPE for GLM-4.1/4.5V (#1575)
Signed-off-by: AlpinDale <alpindale@gmail.com>
1 parent 253aec1 commit e1f2cc9

1 file changed

Lines changed: 138 additions & 8 deletions

File tree

aphrodite/modeling/models/glm4_1v.py

Lines changed: 138 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
# limitations under the License.
2424
"""Inference-only GLM-4V model compatible with HuggingFace weights."""
2525

26+
import itertools
2627
import math
2728
from collections.abc import Callable, Iterable, Mapping, Sequence
2829
from functools import partial
@@ -33,14 +34,20 @@
3334
import torch.nn as nn
3435
import torch.nn.functional as F
3536
from einops import rearrange
36-
from transformers import BatchFeature
37+
from transformers import BatchFeature, PretrainedConfig
3738
from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
38-
from transformers.models.glm4v.image_processing_glm4v import Glm4vImageProcessor, smart_resize
39+
from transformers.models.glm4v.image_processing_glm4v import (
40+
Glm4vImageProcessor,
41+
smart_resize,
42+
)
3943
from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor
4044
from transformers.video_utils import VideoMetadata
4145

4246
from aphrodite.attention.backends.registry import _Backend
43-
from aphrodite.attention.layer import check_upstream_fa_availability, maybe_get_vit_flash_attn_backend
47+
from aphrodite.attention.layer import (
48+
check_upstream_fa_availability,
49+
maybe_get_vit_flash_attn_backend,
50+
)
4451
from aphrodite.common.sequence import IntermediateTensors
4552
from aphrodite.config import AphroditeConfig
4653
from aphrodite.config.multimodal import BaseDummyOptions, VideoDummyOptions
@@ -58,7 +65,12 @@
5865
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
5966
from aphrodite.modeling.models.module_mapping import MultiModelKeys
6067
from aphrodite.multimodal import MULTIMODAL_REGISTRY
61-
from aphrodite.multimodal.inputs import MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, VideoItem
68+
from aphrodite.multimodal.inputs import (
69+
MultiModalDataDict,
70+
MultiModalFieldConfig,
71+
MultiModalKwargsItems,
72+
VideoItem,
73+
)
6274
from aphrodite.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser
6375
from aphrodite.multimodal.processing import (
6476
BaseMultiModalProcessor,
@@ -72,10 +84,25 @@
7284
from aphrodite.utils.tensor_schema import TensorSchema, TensorShape
7385

7486
from ..layers.activation import SiluAndMul
75-
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP
87+
from .interfaces import (
88+
MultiModalEmbeddings,
89+
SupportsLoRA,
90+
SupportsMRoPE,
91+
SupportsMultiModal,
92+
SupportsPP,
93+
)
7694
from .qwen2_vl import _create_qwen2vl_field_factory, apply_rotary_pos_emb_vision
77-
from .utils import AutoWeightsLoader, WeightsMapper, init_aphrodite_registered_model, maybe_prefix
78-
from .vision import conv3d_to_linear_weight, get_vit_attn_backend, run_dp_sharded_mrope_vision_model
95+
from .utils import (
96+
AutoWeightsLoader,
97+
WeightsMapper,
98+
init_aphrodite_registered_model,
99+
maybe_prefix,
100+
)
101+
from .vision import (
102+
conv3d_to_linear_weight,
103+
get_vit_attn_backend,
104+
run_dp_sharded_mrope_vision_model,
105+
)
79106

80107
logger = init_logger(__name__)
81108

@@ -1262,7 +1289,7 @@ def get_video_replacement_glm4v(item_idx: int):
12621289
info=Glm4vProcessingInfo,
12631290
dummy_inputs=Glm4vDummyInputsBuilder,
12641291
)
1265-
class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP):
1292+
class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE):
12661293
merge_by_field_config = True
12671294

12681295
packed_modules_mapping = {
@@ -1452,6 +1479,109 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings |
14521479
multimodal_embeddings += tuple(video_embeddings)
14531480
return multimodal_embeddings
14541481

1482+
def get_mrope_input_positions(
1483+
self,
1484+
input_tokens: list[int],
1485+
hf_config: "PretrainedConfig",
1486+
image_grid_thw: list[list[int]] | torch.Tensor | None,
1487+
video_grid_thw: list[list[int]] | torch.Tensor | None,
1488+
second_per_grid_ts: list[float] | None = None,
1489+
context_len: int = 0,
1490+
seq_len: int | None = None,
1491+
audio_feature_lengths: torch.Tensor | None = None,
1492+
use_audio_in_video: bool = False,
1493+
) -> tuple[torch.Tensor, int]:
1494+
"""Get mrope input positions and delta value for GLM4V."""
1495+
1496+
image_token_id = hf_config.image_token_id
1497+
video_start_token_id = hf_config.video_start_token_id
1498+
video_end_token_id = hf_config.video_end_token_id
1499+
spatial_merge_size = hf_config.vision_config.spatial_merge_size
1500+
llm_pos_ids_list: list = []
1501+
1502+
if not (image_grid_thw is None and video_grid_thw is None):
1503+
if isinstance(image_grid_thw, torch.Tensor):
1504+
image_grid_thw = image_grid_thw.tolist()
1505+
1506+
input_token_type: list[str] = []
1507+
video_check_flg = False
1508+
for token in input_tokens:
1509+
if token == video_start_token_id:
1510+
video_check_flg = True
1511+
elif token == video_end_token_id:
1512+
video_check_flg = False
1513+
1514+
if (token == image_token_id) and (video_check_flg is False):
1515+
input_token_type.append("image")
1516+
elif (token == image_token_id) and (video_check_flg is True):
1517+
input_token_type.append("video")
1518+
else:
1519+
input_token_type.append("text")
1520+
1521+
input_type_group: list[tuple[str, int, int]] = []
1522+
for key, group_iter in itertools.groupby(enumerate(input_token_type), lambda x: x[1]):
1523+
group_list = list(group_iter)
1524+
start_index = group_list[0][0]
1525+
end_index = group_list[-1][0] + 1
1526+
input_type_group.append((key, start_index, end_index))
1527+
1528+
video_frame_num = 1
1529+
mm_data_idx = 0
1530+
for modality_type, start_idx, end_idx in input_type_group:
1531+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
1532+
if modality_type == "image":
1533+
t, h, w = (
1534+
image_grid_thw[mm_data_idx][0],
1535+
image_grid_thw[mm_data_idx][1],
1536+
image_grid_thw[mm_data_idx][2],
1537+
)
1538+
llm_grid_t, llm_grid_h, llm_grid_w = (
1539+
t,
1540+
h // spatial_merge_size,
1541+
w // spatial_merge_size,
1542+
)
1543+
1544+
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
1545+
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
1546+
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
1547+
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
1548+
mm_data_idx += 1
1549+
1550+
elif modality_type == "video":
1551+
t, h, w = (
1552+
video_frame_num,
1553+
image_grid_thw[mm_data_idx][1],
1554+
image_grid_thw[mm_data_idx][2],
1555+
)
1556+
llm_grid_t, llm_grid_h, llm_grid_w = (
1557+
t,
1558+
h // spatial_merge_size,
1559+
w // spatial_merge_size,
1560+
)
1561+
1562+
for t_idx in range(llm_grid_t):
1563+
t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
1564+
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten()
1565+
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
1566+
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
1567+
1568+
mm_data_idx += 1
1569+
video_frame_num += 1
1570+
1571+
else:
1572+
text_len = end_idx - start_idx
1573+
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
1574+
video_frame_num = 1
1575+
1576+
else:
1577+
text_len = len(input_tokens)
1578+
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
1579+
1580+
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
1581+
llm_positions = llm_positions[:, context_len:seq_len]
1582+
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
1583+
return llm_positions, mrope_position_delta
1584+
14551585
def forward(
14561586
self,
14571587
input_ids: torch.Tensor,

0 commit comments

Comments
 (0)