|
23 | 23 | # limitations under the License. |
24 | 24 | """Inference-only GLM-4V model compatible with HuggingFace weights.""" |
25 | 25 |
|
| 26 | +import itertools |
26 | 27 | import math |
27 | 28 | from collections.abc import Callable, Iterable, Mapping, Sequence |
28 | 29 | from functools import partial |
|
33 | 34 | import torch.nn as nn |
34 | 35 | import torch.nn.functional as F |
35 | 36 | from einops import rearrange |
36 | | -from transformers import BatchFeature |
| 37 | +from transformers import BatchFeature, PretrainedConfig |
37 | 38 | from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig |
38 | | -from transformers.models.glm4v.image_processing_glm4v import Glm4vImageProcessor, smart_resize |
| 39 | +from transformers.models.glm4v.image_processing_glm4v import ( |
| 40 | + Glm4vImageProcessor, |
| 41 | + smart_resize, |
| 42 | +) |
39 | 43 | from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor |
40 | 44 | from transformers.video_utils import VideoMetadata |
41 | 45 |
|
42 | 46 | from aphrodite.attention.backends.registry import _Backend |
43 | | -from aphrodite.attention.layer import check_upstream_fa_availability, maybe_get_vit_flash_attn_backend |
| 47 | +from aphrodite.attention.layer import ( |
| 48 | + check_upstream_fa_availability, |
| 49 | + maybe_get_vit_flash_attn_backend, |
| 50 | +) |
44 | 51 | from aphrodite.common.sequence import IntermediateTensors |
45 | 52 | from aphrodite.config import AphroditeConfig |
46 | 53 | from aphrodite.config.multimodal import BaseDummyOptions, VideoDummyOptions |
|
58 | 65 | from aphrodite.modeling.model_loader.weight_utils import default_weight_loader |
59 | 66 | from aphrodite.modeling.models.module_mapping import MultiModelKeys |
60 | 67 | from aphrodite.multimodal import MULTIMODAL_REGISTRY |
61 | | -from aphrodite.multimodal.inputs import MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, VideoItem |
| 68 | +from aphrodite.multimodal.inputs import ( |
| 69 | + MultiModalDataDict, |
| 70 | + MultiModalFieldConfig, |
| 71 | + MultiModalKwargsItems, |
| 72 | + VideoItem, |
| 73 | +) |
62 | 74 | from aphrodite.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser |
63 | 75 | from aphrodite.multimodal.processing import ( |
64 | 76 | BaseMultiModalProcessor, |
|
72 | 84 | from aphrodite.utils.tensor_schema import TensorSchema, TensorShape |
73 | 85 |
|
74 | 86 | from ..layers.activation import SiluAndMul |
75 | | -from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP |
| 87 | +from .interfaces import ( |
| 88 | + MultiModalEmbeddings, |
| 89 | + SupportsLoRA, |
| 90 | + SupportsMRoPE, |
| 91 | + SupportsMultiModal, |
| 92 | + SupportsPP, |
| 93 | +) |
76 | 94 | from .qwen2_vl import _create_qwen2vl_field_factory, apply_rotary_pos_emb_vision |
77 | | -from .utils import AutoWeightsLoader, WeightsMapper, init_aphrodite_registered_model, maybe_prefix |
78 | | -from .vision import conv3d_to_linear_weight, get_vit_attn_backend, run_dp_sharded_mrope_vision_model |
| 95 | +from .utils import ( |
| 96 | + AutoWeightsLoader, |
| 97 | + WeightsMapper, |
| 98 | + init_aphrodite_registered_model, |
| 99 | + maybe_prefix, |
| 100 | +) |
| 101 | +from .vision import ( |
| 102 | + conv3d_to_linear_weight, |
| 103 | + get_vit_attn_backend, |
| 104 | + run_dp_sharded_mrope_vision_model, |
| 105 | +) |
79 | 106 |
|
80 | 107 | logger = init_logger(__name__) |
81 | 108 |
|
@@ -1262,7 +1289,7 @@ def get_video_replacement_glm4v(item_idx: int): |
1262 | 1289 | info=Glm4vProcessingInfo, |
1263 | 1290 | dummy_inputs=Glm4vDummyInputsBuilder, |
1264 | 1291 | ) |
1265 | | -class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP): |
| 1292 | +class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE): |
1266 | 1293 | merge_by_field_config = True |
1267 | 1294 |
|
1268 | 1295 | packed_modules_mapping = { |
@@ -1452,6 +1479,109 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings | |
1452 | 1479 | multimodal_embeddings += tuple(video_embeddings) |
1453 | 1480 | return multimodal_embeddings |
1454 | 1481 |
|
| 1482 | + def get_mrope_input_positions( |
| 1483 | + self, |
| 1484 | + input_tokens: list[int], |
| 1485 | + hf_config: "PretrainedConfig", |
| 1486 | + image_grid_thw: list[list[int]] | torch.Tensor | None, |
| 1487 | + video_grid_thw: list[list[int]] | torch.Tensor | None, |
| 1488 | + second_per_grid_ts: list[float] | None = None, |
| 1489 | + context_len: int = 0, |
| 1490 | + seq_len: int | None = None, |
| 1491 | + audio_feature_lengths: torch.Tensor | None = None, |
| 1492 | + use_audio_in_video: bool = False, |
| 1493 | + ) -> tuple[torch.Tensor, int]: |
| 1494 | + """Get mrope input positions and delta value for GLM4V.""" |
| 1495 | + |
| 1496 | + image_token_id = hf_config.image_token_id |
| 1497 | + video_start_token_id = hf_config.video_start_token_id |
| 1498 | + video_end_token_id = hf_config.video_end_token_id |
| 1499 | + spatial_merge_size = hf_config.vision_config.spatial_merge_size |
| 1500 | + llm_pos_ids_list: list = [] |
| 1501 | + |
| 1502 | + if not (image_grid_thw is None and video_grid_thw is None): |
| 1503 | + if isinstance(image_grid_thw, torch.Tensor): |
| 1504 | + image_grid_thw = image_grid_thw.tolist() |
| 1505 | + |
| 1506 | + input_token_type: list[str] = [] |
| 1507 | + video_check_flg = False |
| 1508 | + for token in input_tokens: |
| 1509 | + if token == video_start_token_id: |
| 1510 | + video_check_flg = True |
| 1511 | + elif token == video_end_token_id: |
| 1512 | + video_check_flg = False |
| 1513 | + |
| 1514 | + if (token == image_token_id) and (video_check_flg is False): |
| 1515 | + input_token_type.append("image") |
| 1516 | + elif (token == image_token_id) and (video_check_flg is True): |
| 1517 | + input_token_type.append("video") |
| 1518 | + else: |
| 1519 | + input_token_type.append("text") |
| 1520 | + |
| 1521 | + input_type_group: list[tuple[str, int, int]] = [] |
| 1522 | + for key, group_iter in itertools.groupby(enumerate(input_token_type), lambda x: x[1]): |
| 1523 | + group_list = list(group_iter) |
| 1524 | + start_index = group_list[0][0] |
| 1525 | + end_index = group_list[-1][0] + 1 |
| 1526 | + input_type_group.append((key, start_index, end_index)) |
| 1527 | + |
| 1528 | + video_frame_num = 1 |
| 1529 | + mm_data_idx = 0 |
| 1530 | + for modality_type, start_idx, end_idx in input_type_group: |
| 1531 | + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 |
| 1532 | + if modality_type == "image": |
| 1533 | + t, h, w = ( |
| 1534 | + image_grid_thw[mm_data_idx][0], |
| 1535 | + image_grid_thw[mm_data_idx][1], |
| 1536 | + image_grid_thw[mm_data_idx][2], |
| 1537 | + ) |
| 1538 | + llm_grid_t, llm_grid_h, llm_grid_w = ( |
| 1539 | + t, |
| 1540 | + h // spatial_merge_size, |
| 1541 | + w // spatial_merge_size, |
| 1542 | + ) |
| 1543 | + |
| 1544 | + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() |
| 1545 | + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() |
| 1546 | + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() |
| 1547 | + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx) |
| 1548 | + mm_data_idx += 1 |
| 1549 | + |
| 1550 | + elif modality_type == "video": |
| 1551 | + t, h, w = ( |
| 1552 | + video_frame_num, |
| 1553 | + image_grid_thw[mm_data_idx][1], |
| 1554 | + image_grid_thw[mm_data_idx][2], |
| 1555 | + ) |
| 1556 | + llm_grid_t, llm_grid_h, llm_grid_w = ( |
| 1557 | + t, |
| 1558 | + h // spatial_merge_size, |
| 1559 | + w // spatial_merge_size, |
| 1560 | + ) |
| 1561 | + |
| 1562 | + for t_idx in range(llm_grid_t): |
| 1563 | + t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() |
| 1564 | + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten() |
| 1565 | + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten() |
| 1566 | + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx) |
| 1567 | + |
| 1568 | + mm_data_idx += 1 |
| 1569 | + video_frame_num += 1 |
| 1570 | + |
| 1571 | + else: |
| 1572 | + text_len = end_idx - start_idx |
| 1573 | + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) |
| 1574 | + video_frame_num = 1 |
| 1575 | + |
| 1576 | + else: |
| 1577 | + text_len = len(input_tokens) |
| 1578 | + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) |
| 1579 | + |
| 1580 | + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) |
| 1581 | + llm_positions = llm_positions[:, context_len:seq_len] |
| 1582 | + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() |
| 1583 | + return llm_positions, mrope_position_delta |
| 1584 | + |
1455 | 1585 | def forward( |
1456 | 1586 | self, |
1457 | 1587 | input_ids: torch.Tensor, |
|
0 commit comments