Skip to content

Commit 819497c

Browse files
committed
Merge branch 'main' of https://github.com/ModelTC/LightLLM into support_gemma4
2 parents 87da477 + f850264 commit 819497c

12 files changed

Lines changed: 741 additions & 10 deletions

File tree

docs/CN/source/tutorial/api_server_args.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,18 @@ PD 分离模式参数
272272

273273
多模态资源的缓存服务器容量,默认为 ``200``
274274

275+
.. option:: --max_image_token_count
276+
277+
单张图片在转换为 token 后允许的最大 token 数量,默认为 ``6128``
278+
279+
当任意图片超过该阈值时,请求会被拒绝。
280+
281+
.. option:: --max_image_pixels
282+
283+
单张图片在预处理缩放前允许的最大像素数量,默认为 ``8294400``(约等于 4K 图片像素总量)。
284+
285+
当输入图片超过该阈值时,LightLLM 会先自动将其缩放到该像素预算内,再继续后续流程。
286+
275287
.. option:: --visual_infer_batch_size
276288
277289
每次推理批次中处理的图像数量,默认为 ``1``

docs/EN/source/tutorial/api_server_args.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,18 @@ Multimodal Parameters
270270

271271
Cache server capacity for multimodal resources, default is ``200``
272272

273+
.. option:: --max_image_token_count
274+
275+
Maximum allowed token count for a single image after tokenization, default is ``6128``
276+
277+
Requests are rejected when any image exceeds this limit.
278+
279+
.. option:: --max_image_pixels
280+
281+
Maximum allowed pixel count for a single image before preprocessing resize, default is ``8294400`` (about 4K image pixels).
282+
283+
If an input image exceeds this threshold, LightLLM automatically resizes it down to this pixel budget before continuing.
284+
273285
.. option:: --visual_infer_batch_size
274286

275287
Number of images processed in each inference batch, default is ``1``

lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def _context_attention_kernel(
6565
out=None,
6666
):
6767
if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention":
68-
window_size = (self.sliding_window - 1, self.sliding_window - 1)
68+
window_size = (self.sliding_window - 1, 0)
6969
use_sliding_window = True
7070
else:
7171
window_size = (-1, -1)
@@ -92,7 +92,7 @@ def _token_attention_kernel(
9292
self, q: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: GptOssTransformerLayerWeight, out=None
9393
):
9494
if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention":
95-
window_size = (self.sliding_window - 1, self.sliding_window - 1)
95+
window_size = (self.sliding_window - 1, 0)
9696
use_sliding_window = True
9797
else:
9898
window_size = (-1, -1)

lightllm/server/api_cli.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,18 @@ def make_argument_parser() -> argparse.ArgumentParser:
443443
parser.add_argument(
444444
"--cache_capacity", type=int, default=200, help="cache server capacity for multimodal resources"
445445
)
446+
parser.add_argument(
447+
"--max_image_token_count",
448+
type=int,
449+
default=8192,
450+
help="maximum allowed token count for one image after tokenization",
451+
)
452+
parser.add_argument(
453+
"--max_image_pixels",
454+
type=int,
455+
default=8294400,
456+
help="maximum allowed pixel count for one image before resize preprocessing",
457+
)
446458
parser.add_argument(
447459
"--embed_cache_storage_size",
448460
type=float,

lightllm/server/core/objs/start_args_type.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ class StartArgs:
9595
enable_decode_microbatch_overlap: bool = field(default=False)
9696
enable_prefill_microbatch_overlap: bool = field(default=False)
9797
cache_capacity: int = field(default=200)
98+
max_image_token_count: int = field(default=8192)
99+
max_image_pixels: int = field(default=8294400)
98100
embed_cache_storage_size: float = field(default=4)
99101
data_type: Optional[str] = field(
100102
default=None, metadata={"choices": ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"]}

lightllm/server/httpserver/manager.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,17 @@ async def _alloc_resource(self, items, md5sums, token_nums, datas):
181181
self.cache_client.root.set_items_data(update_data_ids)
182182
return
183183

184+
def _assert_image_token_count(self, token_num: int):
185+
if token_num > self.args.max_image_token_count:
186+
err_msg = (
187+
f"single image token count {token_num} exceeds max_image_token_count {self.args.max_image_token_count}."
188+
f"You can increase this limit by setting --max_image_token_count to a larger value when starting "
189+
f"LightLLM. Warning: increasing this limit raises runtime OOM risk."
190+
)
191+
logger.warning(err_msg)
192+
raise ValueError(err_msg)
193+
return
194+
184195
async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, sampling_params: SamplingParams):
185196
# 只有 P 和 NORMAL 节点需要真的管理多模态资源
186197
if self.pd_mode.is_P_or_NORMAL():
@@ -190,6 +201,7 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams,
190201
data = img.read()
191202
# must after init_imageitem_extral_params
192203
token_num = self.tokenizer.get_image_token_length(img)
204+
self._assert_image_token_count(token_num)
193205
md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(img.extra_params)))
194206
md5sums.append(md5sum)
195207
img.md5 = md5sum
@@ -245,7 +257,9 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar
245257
for img in multimodal_params.images:
246258
img_count += 1
247259
self.tokenizer.init_imageitem_extral_params(img, multimodal_params, samping_params)
248-
image_tokens += self.tokenizer.get_image_token_length(img)
260+
token_num = self.tokenizer.get_image_token_length(img)
261+
self._assert_image_token_count(token_num)
262+
image_tokens += token_num
249263
for audio in multimodal_params.audios:
250264
audio_count += 1
251265
self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, samping_params)

lightllm/server/httpserver_for_pd_master/manager.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,16 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar
8181
for img in multimodal_params.images:
8282
img_count += 1
8383
self.tokenizer.init_imageitem_extral_params(img, multimodal_params, samping_params)
84-
image_tokens += self.tokenizer.get_image_token_length(img)
84+
token_num = self.tokenizer.get_image_token_length(img)
85+
if token_num > self.args.max_image_token_count:
86+
err_msg = (
87+
f"the image token count {token_num} > max_image_token_count {self.args.max_image_token_count}. "
88+
f"You can increase this limit by setting --max_image_token_count to a larger value when starting "
89+
f"LightLLM. Warning: increasing this limit raises runtime OOM risk."
90+
)
91+
logger.warning(err_msg)
92+
raise ValueError(err_msg)
93+
image_tokens += token_num
8594
for audio in multimodal_params.audios:
8695
audio_count += 1
8796
self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, samping_params)

lightllm/server/multimodal_params.py

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import librosa
55
import base64
66
import numpy as np
7-
from typing import List, Tuple
7+
from typing import List, Tuple, Optional
88
from io import BytesIO
99
from concurrent.futures import ThreadPoolExecutor
1010
from PIL import Image, ImageFile
@@ -13,6 +13,7 @@
1313
from lightllm.utils.error_utils import ClientDisconnected
1414
from lightllm.utils.multimodal_utils import fetch_resource
1515
from lightllm.utils.log_utils import init_logger
16+
from lightllm.utils.envs_utils import get_env_start_args
1617

1718

1819
logger = init_logger(__name__)
@@ -131,6 +132,9 @@ def __init__(self, **kwargs):
131132
self.extra_params = {}
132133

133134
async def preload(self, request: Request):
135+
136+
max_image_pixels = get_env_start_args().max_image_pixels
137+
134138
try:
135139
if self._type == "url":
136140
timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
@@ -141,8 +145,14 @@ async def preload(self, request: Request):
141145
elif self._type == "image_size":
142146
# image_size 代表直接传入图片的 width,height,主要是用于一些场景
143147
# 的 token 计数判断, 所以只需要图片长宽信息,不需要具体图片的内容信息
144-
self.image_w = self._data[0]
145-
self.image_h = self._data[1]
148+
src_w = self._data[0]
149+
src_h = self._data[1]
150+
self.image_w, self.image_h = _resize_image_dimensions_if_needed(src_w, src_h, max_image_pixels)
151+
if (self.image_w, self.image_h) != (src_w, src_h):
152+
logger.warning(
153+
f"image_size pixels {src_w * src_h} exceed max_image_pixels={max_image_pixels}, "
154+
f"resized to {self.image_w}x{self.image_h}"
155+
)
146156
return
147157
else:
148158
raise ValueError(f"cannot read image which type is {self._type}!")
@@ -151,7 +161,24 @@ async def preload(self, request: Request):
151161
# Decoding is mainly done in the C libraries (libjpeg/libpng/libwebp), which releases the GIL,
152162
# and multiple threads can achieve true parallelism.
153163
loop = asyncio.get_running_loop()
154-
self.image_w, self.image_h = await loop.run_in_executor(_IMAGE_VERIFY_POOL, _verify_image_bytes, img_data)
164+
# 1) Verify original input bytes first.
165+
src_w, src_h = await loop.run_in_executor(_IMAGE_VERIFY_POOL, _verify_image_bytes, img_data)
166+
# 2) Resize (or no-op) after verification.
167+
img_data, resized_w, resized_h = await loop.run_in_executor(
168+
_IMAGE_VERIFY_POOL,
169+
_resize_image_bytes_if_needed,
170+
img_data,
171+
src_w,
172+
src_h,
173+
max_image_pixels,
174+
)
175+
self.image_w, self.image_h = resized_w, resized_h
176+
177+
if (resized_w, resized_h) != (src_w, src_h):
178+
logger.warning(
179+
f"image pixels {src_w * src_h} exceed max_image_pixels={max_image_pixels},"
180+
f" resized to {self.image_w}x{self.image_h}"
181+
)
155182

156183
self._preload_data = img_data
157184
return
@@ -245,3 +272,45 @@ def _verify_image_bytes(img_data: bytes) -> Tuple[int, int]:
245272
w, h = image.size
246273
image.load()
247274
return w, h
275+
276+
277+
def _resize_image_bytes_if_needed(
278+
img_data: bytes, src_w: int, src_h: int, max_image_pixels: int
279+
) -> Tuple[bytes, int, int]:
280+
"""
281+
Resize image bytes to satisfy max pixel constraint and return resized bytes with size.
282+
"""
283+
new_w, new_h = _resize_image_dimensions_if_needed(src_w, src_h, max_image_pixels)
284+
if (new_w, new_h) == (src_w, src_h):
285+
return img_data, src_w, src_h
286+
287+
with Image.open(BytesIO(img_data)) as image:
288+
resampling = Image.Resampling.LANCZOS if hasattr(Image, "Resampling") else Image.LANCZOS
289+
resized_image = image.resize((new_w, new_h), resampling).convert("RGB")
290+
291+
buffer = BytesIO()
292+
resized_image.save(buffer, format="JPEG", quality=96, optimize=True)
293+
return buffer.getvalue(), new_w, new_h
294+
295+
296+
def _resize_image_dimensions_if_needed(src_w: int, src_h: int, max_image_pixels: int) -> Tuple[int, int]:
297+
"""
298+
Compute resized (w, h) under a max pixel budget while preserving aspect ratio.
299+
"""
300+
old_pixels = src_w * src_h
301+
if old_pixels <= max_image_pixels:
302+
return src_w, src_h
303+
304+
scale = (max_image_pixels / old_pixels) ** 0.5
305+
new_w = max(1, int(src_w * scale))
306+
new_h = max(1, int(src_h * scale))
307+
308+
# Avoid overflow from integer rounding.
309+
while new_w * new_h > max_image_pixels:
310+
if new_w >= new_h:
311+
new_w = max(1, new_w - 1)
312+
else:
313+
new_h = max(1, new_h - 1)
314+
315+
assert new_w > 0 and new_h > 0, "resized image dimensions must be positive"
316+
return new_w, new_h

lightllm/server/router/manager.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,9 +436,11 @@ def _generate_new_batch(self):
436436
new_batch = self.req_queue.generate_new_batch(
437437
Batch.merge_two_batch(self.running_batch, self.schedule_new_batch)
438438
)
439+
440+
if new_batch is not None and len(new_batch.reqs) > 0:
441+
logger.info(f"generate new batch, {new_batch.simple_log()}")
442+
439443
self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch)
440-
if self.schedule_new_batch is not None:
441-
logger.info(f"gen new batch, {self.schedule_new_batch.simple_log()}")
442444
return
443445

444446
def _multinode_tp_generate_new_batch(self):
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
quant_type: none
2+
mix_bits:
3+
- name: "fused_moe"
4+
quant_type: "deepgemm-fp8w8a8-b128"

0 commit comments

Comments
 (0)