Skip to content

Commit 3bd46d7

Browse files
committed
fix
1 parent c5b2b81 commit 3bd46d7

3 files changed

Lines changed: 94 additions & 87 deletions

File tree

lightllm/models/gemma4/model.py

Lines changed: 0 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
import math
21
import os
32
import json
43
import torch
54
from lightllm.models.registry import ModelRegistry
6-
from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer
75
from lightllm.common.basemodel.attention.triton.fp import TritonAttBackend
86
from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
97
from lightllm.common.build_utils import repair_config
@@ -21,90 +19,6 @@
2119
logger = init_logger(__name__)
2220

2321

24-
class Gemma4Tokenizer(BaseMultiModalTokenizer):
25-
def __init__(self, tokenizer, model_cfg, image_processor=None):
26-
super().__init__(tokenizer)
27-
self.image_token_index = model_cfg.get("image_token_id", 258880)
28-
self.boi_token_index = model_cfg.get("boi_token_id", 255999)
29-
self.eoi_token_index = model_cfg.get("eoi_token_id", 258882)
30-
self.image_processor = image_processor
31-
self.image_length = model_cfg.get("vision_soft_tokens_per_image", 280)
32-
self.patch_size = getattr(self.image_processor, "patch_size", 16)
33-
self.pooling_kernel_size = getattr(self.image_processor, "pooling_kernel_size", 3)
34-
self.max_soft_tokens = getattr(self.image_processor, "max_soft_tokens", self.image_length)
35-
# HF Gemma-4 tokenizer does not prepend BOS even with add_special_tokens=True.
36-
self.bos_token_id = tokenizer.bos_token_id
37-
38-
def init_imageitem_extral_params(self, img, multi_params, sampling_params):
39-
return
40-
41-
def init_audioitem_extral_params(self, audio, multi_params, sampling_params):
42-
raise NotImplementedError
43-
44-
def get_image_token_length(self, img):
45-
if self.image_processor is None or img.image_w <= 0 or img.image_h <= 0:
46-
return self.image_length
47-
48-
patch, kernel = self.patch_size, self.pooling_kernel_size
49-
unit = patch * kernel
50-
num_patches_orig = (img.image_h / patch) * (img.image_w / patch)
51-
scale = math.sqrt(self.max_soft_tokens * kernel ** 2 / num_patches_orig)
52-
target_h = max(unit, int(math.floor(img.image_h * scale / unit)) * unit)
53-
target_w = max(unit, int(math.floor(img.image_w * scale / unit)) * unit)
54-
num_patches = (target_h // patch) * (target_w // patch)
55-
return min(num_patches // kernel ** 2, self.max_soft_tokens)
56-
57-
def get_audio_token_length(self, audio):
58-
raise NotImplementedError
59-
60-
def encode(self, prompt, multimodal_params=None, add_special_tokens=False):
61-
origin_ids = self.tokenizer(prompt, add_special_tokens=False).input_ids
62-
if (
63-
add_special_tokens
64-
and self.bos_token_id is not None
65-
and (len(origin_ids) == 0 or origin_ids[0] != self.bos_token_id)
66-
):
67-
origin_ids = [self.bos_token_id] + origin_ids
68-
69-
images = [] if multimodal_params is None else getattr(multimodal_params, "images", [])
70-
if not images:
71-
return origin_ids
72-
73-
input_ids = []
74-
image_id = 0
75-
start = 0
76-
while True:
77-
try:
78-
image_start = origin_ids.index(self.image_token_index, start)
79-
except ValueError:
80-
break
81-
82-
input_ids.extend(origin_ids[start:image_start])
83-
image_end = image_start + 1
84-
while image_end < len(origin_ids) and origin_ids[image_end] == self.image_token_index:
85-
image_end += 1
86-
if image_id >= len(images):
87-
raise ValueError("image token error")
88-
89-
img = images[image_id]
90-
if not input_ids or input_ids[-1] != self.boi_token_index:
91-
input_ids.append(self.boi_token_index)
92-
img.start_idx = len(input_ids)
93-
input_ids.extend(range(img.token_id, img.token_id + img.token_num))
94-
input_ids.append(self.eoi_token_index)
95-
96-
if image_end < len(origin_ids) and origin_ids[image_end] == self.eoi_token_index:
97-
image_end += 1
98-
start = image_end
99-
image_id += 1
100-
101-
input_ids.extend(origin_ids[start:])
102-
image_cnt = len(images)
103-
if image_cnt != image_id:
104-
raise ValueError(f"invalid image tag num: {image_cnt} vs {image_id}!")
105-
return input_ids
106-
107-
10822
@ModelRegistry("gemma4", is_multimodal=True)
10923
class Gemma4TpPartModel(LlamaTpPartModel):
11024
pre_and_post_weight_class = Gemma4PreAndPostLayerWeight
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import math
2+
3+
from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer
4+
from lightllm.server.core.objs.sampling_params import SamplingParams
5+
from lightllm.server.multimodal_params import AudioItem, ImageItem, MultimodalParams
6+
7+
8+
class Gemma4Tokenizer(BaseMultiModalTokenizer):
9+
def __init__(self, tokenizer, model_cfg, image_processor=None):
10+
super().__init__(tokenizer)
11+
self.image_token_index = model_cfg.get("image_token_id", 258880)
12+
self.boi_token_index = model_cfg.get("boi_token_id", 255999)
13+
self.eoi_token_index = model_cfg.get("eoi_token_id", 258882)
14+
self.image_processor = image_processor
15+
self.image_length = model_cfg.get("vision_soft_tokens_per_image", 280)
16+
self.patch_size = getattr(self.image_processor, "patch_size", 16)
17+
self.pooling_kernel_size = getattr(self.image_processor, "pooling_kernel_size", 3)
18+
self.max_soft_tokens = getattr(self.image_processor, "max_soft_tokens", self.image_length)
19+
# HF Gemma-4 tokenizer does not prepend BOS even with add_special_tokens=True.
20+
self.bos_token_id = tokenizer.bos_token_id
21+
22+
def init_imageitem_extral_params(
23+
self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams
24+
):
25+
return
26+
27+
def init_audioitem_extral_params(
28+
self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams
29+
):
30+
raise NotImplementedError
31+
32+
def get_image_token_length(self, img: ImageItem):
33+
if self.image_processor is None or img.image_w <= 0 or img.image_h <= 0:
34+
return self.image_length
35+
36+
patch, kernel = self.patch_size, self.pooling_kernel_size
37+
unit = patch * kernel
38+
num_patches_orig = (img.image_h / patch) * (img.image_w / patch)
39+
scale = math.sqrt(self.max_soft_tokens * kernel ** 2 / num_patches_orig)
40+
target_h = max(unit, int(math.floor(img.image_h * scale / unit)) * unit)
41+
target_w = max(unit, int(math.floor(img.image_w * scale / unit)) * unit)
42+
num_patches = (target_h // patch) * (target_w // patch)
43+
return min(num_patches // kernel ** 2, self.max_soft_tokens)
44+
45+
def get_audio_token_length(self, audio: AudioItem):
46+
raise NotImplementedError
47+
48+
def encode(self, prompt, multimodal_params: MultimodalParams = None, add_special_tokens=False):
49+
origin_ids = self.tokenizer(prompt, add_special_tokens=False).input_ids
50+
if (
51+
add_special_tokens
52+
and self.bos_token_id is not None
53+
and (len(origin_ids) == 0 or origin_ids[0] != self.bos_token_id)
54+
):
55+
origin_ids = [self.bos_token_id] + origin_ids
56+
57+
images = [] if multimodal_params is None else getattr(multimodal_params, "images", [])
58+
if not images:
59+
return origin_ids
60+
61+
input_ids = []
62+
image_id = 0
63+
start = 0
64+
while True:
65+
try:
66+
image_start = origin_ids.index(self.image_token_index, start)
67+
except ValueError:
68+
break
69+
70+
input_ids.extend(origin_ids[start:image_start])
71+
image_end = image_start + 1
72+
while image_end < len(origin_ids) and origin_ids[image_end] == self.image_token_index:
73+
image_end += 1
74+
if image_id >= len(images):
75+
raise ValueError("image token error")
76+
77+
img = images[image_id]
78+
if not input_ids or input_ids[-1] != self.boi_token_index:
79+
input_ids.append(self.boi_token_index)
80+
img.start_idx = len(input_ids)
81+
input_ids.extend(range(img.token_id, img.token_id + img.token_num))
82+
input_ids.append(self.eoi_token_index)
83+
84+
if image_end < len(origin_ids) and origin_ids[image_end] == self.eoi_token_index:
85+
image_end += 1
86+
start = image_end
87+
image_id += 1
88+
89+
input_ids.extend(origin_ids[start:])
90+
image_cnt = len(images)
91+
if image_cnt != image_id:
92+
raise ValueError(f"invalid image tag num: {image_cnt} vs {image_id}!")
93+
return input_ids

lightllm/server/tokenizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from ..models.qwen3_vl.model import QWen3VLTokenizer
3232
from ..models.internvl.model import InternvlTokenizer
3333
from ..models.gemma3.model import Gemma3Tokenizer
34-
from ..models.gemma4.model import Gemma4Tokenizer
34+
from ..models.gemma4.tokenizer import Gemma4Tokenizer
3535
from ..models.qwen3_omni_moe_thinker.model import QWen3OmniTokenizer
3636

3737
# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.

0 commit comments

Comments
 (0)