Skip to content

Commit 410890a

Browse files
authored
Fix GLM OCR token forwarding (#2216)
1 parent 859f6af commit 410890a

7 files changed

Lines changed: 192 additions & 8 deletions

File tree

inference/core/workflows/core_steps/models/foundation/glm_ocr/v1.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ class BlockManifest(WorkflowBlockManifest):
106106
},
107107
},
108108
)
109+
max_new_tokens: Optional[int] = Field(
110+
default=None,
111+
description="Maximum number of tokens to generate. If not set, the model default will be used.",
112+
)
109113

110114
model_config = ConfigDict(
111115
json_schema_extra={
@@ -199,19 +203,22 @@ def run(
199203
model_version: str,
200204
task_type: str,
201205
prompt: Optional[str],
206+
max_new_tokens: Optional[int] = None,
202207
) -> BlockResult:
203208
resolved_prompt = _resolve_prompt(task_type, prompt)
204209
if self._step_execution_mode == StepExecutionMode.LOCAL:
205210
return self.run_locally(
206211
images=images,
207212
model_version=model_version,
208213
prompt=resolved_prompt,
214+
max_new_tokens=max_new_tokens,
209215
)
210216
elif self._step_execution_mode == StepExecutionMode.REMOTE:
211217
return self.run_remotely(
212218
images=images,
213219
model_version=model_version,
214220
prompt=resolved_prompt,
221+
max_new_tokens=max_new_tokens,
215222
)
216223
else:
217224
raise ValueError(
@@ -223,6 +230,7 @@ def run_remotely(
223230
images: Batch[WorkflowImageData],
224231
model_version: str,
225232
prompt: str,
233+
max_new_tokens: Optional[int] = None,
226234
) -> BlockResult:
227235
api_url = (
228236
LOCAL_INFERENCE_API_URL
@@ -243,6 +251,7 @@ def run_remotely(
243251
model_id=model_version,
244252
prompt=prompt,
245253
model_id_in_path=True,
254+
max_new_tokens=max_new_tokens,
246255
)
247256
response_text = result.get("response", result)
248257
predictions.append({"parsed_output": response_text})
@@ -254,6 +263,7 @@ def run_locally(
254263
images: Batch[WorkflowImageData],
255264
model_version: str,
256265
prompt: str,
266+
max_new_tokens: Optional[int] = None,
257267
) -> BlockResult:
258268
inference_images = [
259269
i.to_inference_format(numpy_preferred=False) for i in images
@@ -263,13 +273,16 @@ def run_locally(
263273

264274
predictions = []
265275
for image in inference_images:
266-
request = LMMInferenceRequest(
276+
request_kwargs = dict(
267277
api_key=self._api_key,
268278
model_id=model_version,
269279
image=image,
270280
source="workflow-execution",
271281
prompt=prompt,
272282
)
283+
if max_new_tokens is not None:
284+
request_kwargs["max_new_tokens"] = max_new_tokens
285+
request = LMMInferenceRequest(**request_kwargs)
273286
prediction = self._model_manager.infer_from_request_sync(
274287
model_id=model_version, request=request
275288
)

inference/core/workflows/core_steps/models/foundation/qwen3_5vl/v1.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,8 @@ def run(
182182
model_version=model_version,
183183
prompt=prompt,
184184
system_prompt=system_prompt,
185+
enable_thinking=enable_thinking,
186+
max_new_tokens=max_new_tokens,
185187
)
186188
else:
187189
raise ValueError(
@@ -194,6 +196,8 @@ def run_remotely(
194196
model_version: str,
195197
prompt: Optional[str],
196198
system_prompt: Optional[str],
199+
enable_thinking: bool = False,
200+
max_new_tokens: Optional[int] = None,
197201
) -> BlockResult:
198202
api_url = (
199203
LOCAL_INFERENCE_API_URL
@@ -221,6 +225,8 @@ def run_remotely(
221225
model_id=model_version,
222226
prompt=combined_prompt,
223227
model_id_in_path=True,
228+
enable_thinking=enable_thinking,
229+
max_new_tokens=max_new_tokens,
224230
)
225231
response_text = result.get("response", result)
226232
predictions.append({"parsed_output": response_text, "thinking": ""})

inference_models/inference_models/models/glm_ocr/glm_ocr_hf.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
from threading import Lock
6-
from typing import Any, List, Union
6+
from typing import Any, List, Optional, Union
77

88
import numpy as np
99
import torch
@@ -99,7 +99,7 @@ def recognize_table(
9999
self,
100100
images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
101101
input_color_format: ColorFormat = None,
102-
max_new_tokens: int = INFERENCE_MODELS_GLM_OCR_DEFAULT_MAX_NEW_TOKENS,
102+
max_new_tokens: Optional[int] = INFERENCE_MODELS_GLM_OCR_DEFAULT_MAX_NEW_TOKENS,
103103
do_sample: bool = INFERENCE_MODELS_GLM_OCR_DEFAULT_DO_SAMPLE,
104104
skip_special_tokens: bool = True,
105105
**kwargs,
@@ -118,7 +118,7 @@ def recognize_formula(
118118
self,
119119
images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
120120
input_color_format: ColorFormat = None,
121-
max_new_tokens: int = INFERENCE_MODELS_GLM_OCR_DEFAULT_MAX_NEW_TOKENS,
121+
max_new_tokens: Optional[int] = INFERENCE_MODELS_GLM_OCR_DEFAULT_MAX_NEW_TOKENS,
122122
do_sample: bool = INFERENCE_MODELS_GLM_OCR_DEFAULT_DO_SAMPLE,
123123
skip_special_tokens: bool = True,
124124
**kwargs,
@@ -137,7 +137,7 @@ def recognize_text(
137137
self,
138138
images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
139139
input_color_format: ColorFormat = None,
140-
max_new_tokens: int = INFERENCE_MODELS_GLM_OCR_DEFAULT_MAX_NEW_TOKENS,
140+
max_new_tokens: Optional[int] = INFERENCE_MODELS_GLM_OCR_DEFAULT_MAX_NEW_TOKENS,
141141
do_sample: bool = INFERENCE_MODELS_GLM_OCR_DEFAULT_DO_SAMPLE,
142142
skip_special_tokens: bool = True,
143143
**kwargs,
@@ -157,7 +157,7 @@ def prompt(
157157
images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
158158
prompt: str = None,
159159
input_color_format: ColorFormat = None,
160-
max_new_tokens: int = INFERENCE_MODELS_GLM_OCR_DEFAULT_MAX_NEW_TOKENS,
160+
max_new_tokens: Optional[int] = INFERENCE_MODELS_GLM_OCR_DEFAULT_MAX_NEW_TOKENS,
161161
do_sample: bool = INFERENCE_MODELS_GLM_OCR_DEFAULT_DO_SAMPLE,
162162
skip_special_tokens: bool = True,
163163
**kwargs,
@@ -211,10 +211,12 @@ def pre_process_generation(
211211
def generate(
212212
self,
213213
inputs: dict,
214-
max_new_tokens: int = INFERENCE_MODELS_GLM_OCR_DEFAULT_MAX_NEW_TOKENS,
214+
max_new_tokens: Optional[int] = INFERENCE_MODELS_GLM_OCR_DEFAULT_MAX_NEW_TOKENS,
215215
do_sample: bool = INFERENCE_MODELS_GLM_OCR_DEFAULT_DO_SAMPLE,
216216
**kwargs,
217217
) -> torch.Tensor:
218+
if max_new_tokens is None:
219+
max_new_tokens = INFERENCE_MODELS_GLM_OCR_DEFAULT_MAX_NEW_TOKENS
218220
input_len = inputs["input_ids"].shape[-1]
219221

220222
with self._lock, torch.inference_mode():
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from unittest.mock import MagicMock
2+
3+
import numpy as np
4+
5+
from inference_models.configuration import (
6+
INFERENCE_MODELS_GLM_OCR_DEFAULT_MAX_NEW_TOKENS,
7+
)
8+
from inference_models.models.glm_ocr.glm_ocr_hf import GlmOcrHF
9+
10+
11+
def test_generate_uses_default_max_new_tokens_when_none_is_given() -> None:
12+
model = MagicMock()
13+
model.generate.return_value = np.array([[11, 12, 21, 22]])
14+
glm_ocr = GlmOcrHF(model=model, processor=MagicMock(), device=MagicMock())
15+
16+
result = glm_ocr.generate(
17+
inputs={"input_ids": np.array([[11, 12]])},
18+
max_new_tokens=None,
19+
)
20+
21+
assert model.generate.call_args.kwargs["max_new_tokens"] == (
22+
INFERENCE_MODELS_GLM_OCR_DEFAULT_MAX_NEW_TOKENS
23+
)
24+
assert result.tolist() == [[21, 22]]

inference_sdk/http/client.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1601,6 +1601,8 @@ def infer_lmm(
16011601
model_id: str,
16021602
prompt: Optional[str] = None,
16031603
model_id_in_path: bool = False,
1604+
max_new_tokens: Optional[int] = None,
1605+
enable_thinking: Optional[bool] = None,
16041606
) -> Union[dict, List[dict]]:
16051607
"""Run inference using a Large Multimodal Model (LMM).
16061608
@@ -1620,6 +1622,10 @@ def infer_lmm(
16201622
model_id_in_path (bool, optional): If True, includes model_id in the URL path
16211623
(e.g., /infer/lmm/florence-2-base) which enables path-based routing.
16221624
If False (default), model_id is only sent in the request body.
1625+
max_new_tokens (Optional[int], optional): Maximum number of tokens to generate.
1626+
If not provided, the server-side model default is used.
1627+
enable_thinking (Optional[bool], optional): Enables reasoning mode for models
1628+
that support it. If not provided, the server-side model default is used.
16231629
16241630
Returns:
16251631
Union[dict, List[dict]]: Inference results containing the model response.
@@ -1632,6 +1638,10 @@ def infer_lmm(
16321638
extra_payload = {"model_id": model_id}
16331639
if prompt is not None:
16341640
extra_payload["prompt"] = prompt
1641+
if max_new_tokens is not None:
1642+
extra_payload["max_new_tokens"] = max_new_tokens
1643+
if enable_thinking is not None:
1644+
extra_payload["enable_thinking"] = enable_thinking
16351645

16361646
if model_id_in_path:
16371647
endpoint = f"/infer/lmm/{model_id}"
@@ -1652,6 +1662,8 @@ async def infer_lmm_async(
16521662
model_id: str,
16531663
prompt: Optional[str] = None,
16541664
model_id_in_path: bool = False,
1665+
max_new_tokens: Optional[int] = None,
1666+
enable_thinking: Optional[bool] = None,
16551667
) -> Union[dict, List[dict]]:
16561668
"""Run inference using a Large Multimodal Model (LMM) asynchronously.
16571669
@@ -1666,6 +1678,10 @@ async def infer_lmm_async(
16661678
model_id_in_path (bool, optional): If True, includes model_id in the URL path
16671679
(e.g., /infer/lmm/florence-2-base) which enables path-based routing.
16681680
If False (default), model_id is only sent in the request body.
1681+
max_new_tokens (Optional[int], optional): Maximum number of tokens to generate.
1682+
If not provided, the server-side model default is used.
1683+
enable_thinking (Optional[bool], optional): Enables reasoning mode for models
1684+
that support it. If not provided, the server-side model default is used.
16691685
16701686
Returns:
16711687
Union[dict, List[dict]]: Inference results containing the model response.
@@ -1677,6 +1693,10 @@ async def infer_lmm_async(
16771693
extra_payload = {"model_id": model_id}
16781694
if prompt is not None:
16791695
extra_payload["prompt"] = prompt
1696+
if max_new_tokens is not None:
1697+
extra_payload["max_new_tokens"] = max_new_tokens
1698+
if enable_thinking is not None:
1699+
extra_payload["enable_thinking"] = enable_thinking
16801700

16811701
if model_id_in_path:
16821702
endpoint = f"/infer/lmm/{model_id}"

tests/inference_sdk/unit_tests/http/test_client.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3755,6 +3755,39 @@ def test_infer_from_workflow_when_no_parameters_given(
37553755
}, "Request payload must contain api key and inputs"
37563756

37573757

3758+
@mock.patch.object(client, "load_static_inference_input")
3759+
def test_infer_lmm_when_generation_parameters_given(
3760+
load_static_inference_input_mock: MagicMock,
3761+
requests_mock: Mocker,
3762+
) -> None:
3763+
api_url = "http://some.com"
3764+
http_client = InferenceHTTPClient(api_key="my-api-key", api_url=api_url)
3765+
load_static_inference_input_mock.return_value = [("base64_image", 0.5)]
3766+
requests_mock.post(
3767+
f"{api_url}/infer/lmm/glm-ocr",
3768+
json={"response": "recognized text"},
3769+
)
3770+
3771+
result = http_client.infer_lmm(
3772+
inference_input="/some/image.jpg",
3773+
model_id="glm-ocr",
3774+
prompt="Text Recognition:",
3775+
model_id_in_path=True,
3776+
max_new_tokens=4096,
3777+
enable_thinking=True,
3778+
)
3779+
3780+
assert result == {"response": "recognized text"}
3781+
assert requests_mock.request_history[0].json() == {
3782+
"api_key": "my-api-key",
3783+
"image": {"type": "base64", "value": "base64_image"},
3784+
"model_id": "glm-ocr",
3785+
"prompt": "Text Recognition:",
3786+
"max_new_tokens": 4096,
3787+
"enable_thinking": True,
3788+
}
3789+
3790+
37583791
@mock.patch.object(client, "load_nested_batches_of_inference_input")
37593792
@pytest.mark.parametrize(
37603793
"legacy_endpoints, endpoint_to_use, parameter_name",

tests/workflows/unit_tests/core_steps/models/foundation/test_vlm_remote_execution.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Unit tests for VLM blocks remote execution (Florence2, Moondream2, SmolVLM, Qwen)."""
1+
"""Unit tests for VLM blocks remote execution."""
22

33
from unittest.mock import MagicMock, patch
44

@@ -171,6 +171,48 @@ def test_run_remotely_calls_infer_lmm(
171171
mock_client.infer_lmm.assert_called_once()
172172

173173

174+
class TestGLMOCRRemote:
175+
"""Tests for GLM-OCR remote execution."""
176+
177+
@patch(
178+
"inference.core.workflows.core_steps.models.foundation.glm_ocr.v1.InferenceHTTPClient"
179+
)
180+
def test_run_remotely_forwards_max_new_tokens(
181+
self, mock_client_cls, mock_model_manager, mock_workflow_image_data
182+
):
183+
from inference.core.workflows.core_steps.models.foundation.glm_ocr.v1 import (
184+
GLMOCRBlockV1,
185+
)
186+
187+
mock_client = MagicMock()
188+
mock_client.infer_lmm.return_value = {"response": "recognized text"}
189+
mock_client_cls.return_value = mock_client
190+
191+
block = GLMOCRBlockV1(
192+
model_manager=mock_model_manager,
193+
api_key="test_api_key",
194+
step_execution_mode=StepExecutionMode.REMOTE,
195+
)
196+
197+
result = block.run(
198+
images=[mock_workflow_image_data],
199+
model_version="glm-ocr",
200+
task_type="text-recognition",
201+
prompt=None,
202+
max_new_tokens=4096,
203+
)
204+
205+
assert len(result) == 1
206+
assert result[0]["parsed_output"] == "recognized text"
207+
mock_client.infer_lmm.assert_called_once_with(
208+
inference_input=mock_workflow_image_data.base64_image,
209+
model_id="glm-ocr",
210+
prompt="Text Recognition:",
211+
model_id_in_path=True,
212+
max_new_tokens=4096,
213+
)
214+
215+
174216
class TestQwen25VLRemote:
175217
"""Tests for Qwen2.5-VL remote execution."""
176218

@@ -206,6 +248,50 @@ def test_run_remotely_calls_infer_lmm(
206248
mock_client.infer_lmm.assert_called_once()
207249

208250

251+
class TestQwen35VLRemote:
252+
"""Tests for Qwen3.5-VL remote execution."""
253+
254+
@patch(
255+
"inference.core.workflows.core_steps.models.foundation.qwen3_5vl.v1.InferenceHTTPClient"
256+
)
257+
def test_run_remotely_forwards_generation_parameters(
258+
self, mock_client_cls, mock_model_manager, mock_workflow_image_data
259+
):
260+
from inference.core.workflows.core_steps.models.foundation.qwen3_5vl.v1 import (
261+
Qwen35VLBlockV1,
262+
)
263+
264+
mock_client = MagicMock()
265+
mock_client.infer_lmm.return_value = {"response": "This is a test response."}
266+
mock_client_cls.return_value = mock_client
267+
268+
block = Qwen35VLBlockV1(
269+
model_manager=mock_model_manager,
270+
api_key="test_api_key",
271+
step_execution_mode=StepExecutionMode.REMOTE,
272+
)
273+
274+
result = block.run(
275+
images=[mock_workflow_image_data],
276+
model_version="qwen3_5-2b",
277+
prompt="Describe this image",
278+
system_prompt="You are helpful.",
279+
enable_thinking=True,
280+
max_new_tokens=1024,
281+
)
282+
283+
assert len(result) == 1
284+
assert "parsed_output" in result[0]
285+
mock_client.infer_lmm.assert_called_once_with(
286+
inference_input=mock_workflow_image_data.base64_image,
287+
model_id="qwen3_5-2b",
288+
prompt="Describe this image<system_prompt>You are helpful.",
289+
model_id_in_path=True,
290+
enable_thinking=True,
291+
max_new_tokens=1024,
292+
)
293+
294+
209295
class TestQwen3VLRemote:
210296
"""Tests for Qwen3-VL remote execution."""
211297

0 commit comments

Comments
 (0)