|
1 | | -import math |
2 | 1 | import os |
3 | 2 | import json |
4 | 3 | import torch |
5 | 4 | from lightllm.models.registry import ModelRegistry |
6 | | -from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer |
7 | 5 | from lightllm.common.basemodel.attention.triton.fp import TritonAttBackend |
8 | 6 | from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class |
9 | 7 | from lightllm.common.build_utils import repair_config |
|
21 | 19 | logger = init_logger(__name__) |
22 | 20 |
|
23 | 21 |
|
24 | | -class Gemma4Tokenizer(BaseMultiModalTokenizer): |
25 | | - def __init__(self, tokenizer, model_cfg, image_processor=None): |
26 | | - super().__init__(tokenizer) |
27 | | - self.image_token_index = model_cfg.get("image_token_id", 258880) |
28 | | - self.boi_token_index = model_cfg.get("boi_token_id", 255999) |
29 | | - self.eoi_token_index = model_cfg.get("eoi_token_id", 258882) |
30 | | - self.image_processor = image_processor |
31 | | - self.image_length = model_cfg.get("vision_soft_tokens_per_image", 280) |
32 | | - self.patch_size = getattr(self.image_processor, "patch_size", 16) |
33 | | - self.pooling_kernel_size = getattr(self.image_processor, "pooling_kernel_size", 3) |
34 | | - self.max_soft_tokens = getattr(self.image_processor, "max_soft_tokens", self.image_length) |
35 | | - # HF Gemma-4 tokenizer does not prepend BOS even with add_special_tokens=True. |
36 | | - self.bos_token_id = tokenizer.bos_token_id |
37 | | - |
38 | | - def init_imageitem_extral_params(self, img, multi_params, sampling_params): |
39 | | - return |
40 | | - |
41 | | - def init_audioitem_extral_params(self, audio, multi_params, sampling_params): |
42 | | - raise NotImplementedError |
43 | | - |
44 | | - def get_image_token_length(self, img): |
45 | | - if self.image_processor is None or img.image_w <= 0 or img.image_h <= 0: |
46 | | - return self.image_length |
47 | | - |
48 | | - patch, kernel = self.patch_size, self.pooling_kernel_size |
49 | | - unit = patch * kernel |
50 | | - num_patches_orig = (img.image_h / patch) * (img.image_w / patch) |
51 | | - scale = math.sqrt(self.max_soft_tokens * kernel ** 2 / num_patches_orig) |
52 | | - target_h = max(unit, int(math.floor(img.image_h * scale / unit)) * unit) |
53 | | - target_w = max(unit, int(math.floor(img.image_w * scale / unit)) * unit) |
54 | | - num_patches = (target_h // patch) * (target_w // patch) |
55 | | - return min(num_patches // kernel ** 2, self.max_soft_tokens) |
56 | | - |
57 | | - def get_audio_token_length(self, audio): |
58 | | - raise NotImplementedError |
59 | | - |
60 | | - def encode(self, prompt, multimodal_params=None, add_special_tokens=False): |
61 | | - origin_ids = self.tokenizer(prompt, add_special_tokens=False).input_ids |
62 | | - if ( |
63 | | - add_special_tokens |
64 | | - and self.bos_token_id is not None |
65 | | - and (len(origin_ids) == 0 or origin_ids[0] != self.bos_token_id) |
66 | | - ): |
67 | | - origin_ids = [self.bos_token_id] + origin_ids |
68 | | - |
69 | | - images = [] if multimodal_params is None else getattr(multimodal_params, "images", []) |
70 | | - if not images: |
71 | | - return origin_ids |
72 | | - |
73 | | - input_ids = [] |
74 | | - image_id = 0 |
75 | | - start = 0 |
76 | | - while True: |
77 | | - try: |
78 | | - image_start = origin_ids.index(self.image_token_index, start) |
79 | | - except ValueError: |
80 | | - break |
81 | | - |
82 | | - input_ids.extend(origin_ids[start:image_start]) |
83 | | - image_end = image_start + 1 |
84 | | - while image_end < len(origin_ids) and origin_ids[image_end] == self.image_token_index: |
85 | | - image_end += 1 |
86 | | - if image_id >= len(images): |
87 | | - raise ValueError("image token error") |
88 | | - |
89 | | - img = images[image_id] |
90 | | - if not input_ids or input_ids[-1] != self.boi_token_index: |
91 | | - input_ids.append(self.boi_token_index) |
92 | | - img.start_idx = len(input_ids) |
93 | | - input_ids.extend(range(img.token_id, img.token_id + img.token_num)) |
94 | | - input_ids.append(self.eoi_token_index) |
95 | | - |
96 | | - if image_end < len(origin_ids) and origin_ids[image_end] == self.eoi_token_index: |
97 | | - image_end += 1 |
98 | | - start = image_end |
99 | | - image_id += 1 |
100 | | - |
101 | | - input_ids.extend(origin_ids[start:]) |
102 | | - image_cnt = len(images) |
103 | | - if image_cnt != image_id: |
104 | | - raise ValueError(f"invalid image tag num: {image_cnt} vs {image_id}!") |
105 | | - return input_ids |
106 | | - |
107 | | - |
108 | 22 | @ModelRegistry("gemma4", is_multimodal=True) |
109 | 23 | class Gemma4TpPartModel(LlamaTpPartModel): |
110 | 24 | pre_and_post_weight_class = Gemma4PreAndPostLayerWeight |
|
0 commit comments