Skip to content

Commit 67aeb93

Browse files
fix: port tool parsing to v1 Generator after rebase
The rebase onto main dropped tool parsing changes from generator.py because main refactored it into a thin version-detection wrapper. This ports the tool parsing logic to the v1 Generator:
- Add tool_call_parser field and _init_tool_parser() method
- Add _extract_tool_calls() using vLLM's ToolParserManager
- Update _to_completions() to populate tool_calls/content on Completion
- Fix tests to match v1 Generator interface (_to_completions takes prompt)
- Fix integration test to import Generator directly (Policy doesn't exist)
1 parent 878d0b9 commit 67aeb93

3 files changed

Lines changed: 92 additions & 7 deletions

File tree

src/forge/actors/vllm/v1/generator.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,15 @@
3838
from torchstore.api import _controller as get_torchstore_controller
3939
from vllm.engine.arg_utils import EngineArgs
4040
from vllm.entrypoints.llm import UsageContext
41+
from vllm.entrypoints.openai.protocol import (
42+
ChatCompletionRequest,
43+
ExtractedToolCallInformation,
44+
ToolCall,
45+
)
46+
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
4147
from vllm.outputs import RequestOutput
4248
from vllm.sampling_params import RequestOutputKind, SamplingParams
49+
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
4350
from vllm.v1.engine.async_llm import AsyncLLM
4451

4552
logger = logging.getLogger(__name__)
@@ -78,6 +85,7 @@ class Generator(ForgeActor):
7885
sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
7986
prefetch_weights_to_shm: bool = True
8087
n_fetcher_procs: int = 8
88+
tool_call_parser: str | None = None
8189

8290
def __post_init__(self):
8391
super().__init__()
@@ -91,6 +99,8 @@ def __post_init__(self):
9199
self.engine_args = EngineArgs(**self.engine_args)
92100
self.vllm_config = self.engine_args.create_engine_config(UsageContext.LLM_CLASS)
93101

102+
self._tool_parser = None # Will hold ToolParser instance if configured
103+
94104
if isinstance(self.sampling_params, Mapping):
95105
self.sampling_params = SamplingParams.from_optional(**self.sampling_params)
96106
self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
@@ -273,9 +283,44 @@ async def setup(self, host_mesh, worker_registry, gpu_ids: list[str]):
273283
)
274284
logger.info(f"Retrieved workers from registry: {self.workers}")
275285

286+
if self.tool_call_parser is not None:
287+
self._tool_parser = self._init_tool_parser()
288+
276289
if self.prefetch_weights_to_shm:
277290
self._spawn_fetchers()
278291

292+
def _init_tool_parser(self, tokenizer=None): # type: ignore[no-untyped-def]
293+
"""Initialize the tool parser based on configuration.
294+
295+
Args:
296+
tokenizer: Optional tokenizer wrapper (with .tokenizer attr). If not provided,
297+
one is created from vllm_config. Passing explicitly is useful for testing.
298+
299+
Returns:
300+
Initialized ToolParser instance, or None if tool parsing is not configured.
301+
"""
302+
try:
303+
if tokenizer is None:
304+
tokenizer = init_tokenizer_from_configs(
305+
model_config=self.vllm_config.model_config,
306+
scheduler_config=self.vllm_config.scheduler_config,
307+
lora_config=self.vllm_config.lora_config,
308+
)
309+
parser_cls = ToolParserManager.get_tool_parser(self.tool_call_parser) # type: ignore[union-attr]
310+
parser = parser_cls(tokenizer.tokenizer)
311+
logger.info(f"Initialized tool parser: {self.tool_call_parser}")
312+
return parser
313+
except KeyError:
314+
available = list(ToolParserManager.tool_parsers.keys())
315+
logger.error(
316+
f"Unknown tool parser: '{self.tool_call_parser}'. "
317+
f"Available parsers: {available}"
318+
)
319+
return None
320+
except Exception as e:
321+
logger.error(f"Failed to initialize tool parser: {e}")
322+
return None
323+
279324
def _spawn_fetchers(self):
280325
"""Spawn weight fetchers that prefetch weights from torchstore to shared memory.
281326

@@ -545,6 +590,38 @@ def _extract_logprobs(self, output) -> torch.Tensor | None:
545590
)
546591
return None
547592

593+
def _extract_tool_calls(self, model_output: str) -> ExtractedToolCallInformation:
594+
"""Extract tool calls from model output using the configured tool parser.
595+
596+
Args:
597+
model_output: Raw text output from the model.
598+
599+
Returns:
600+
ExtractedToolCallInformation with parsed tool calls and remaining content.
601+
"""
602+
if self._tool_parser is None:
603+
return ExtractedToolCallInformation(
604+
tools_called=False, tool_calls=[], content=model_output
605+
)
606+
607+
try:
608+
dummy_request = ChatCompletionRequest(
609+
model=self.vllm_config.model_config.model,
610+
messages=[{"role": "user", "content": ""}],
611+
seed=42, # to calm the linter
612+
)
613+
614+
extracted = self._tool_parser.extract_tool_calls(
615+
model_output, dummy_request
616+
)
617+
618+
return extracted
619+
except Exception as e:
620+
logger.warning(f"Failed to parse tool calls: {e}")
621+
return ExtractedToolCallInformation(
622+
tools_called=False, tool_calls=[], content=model_output
623+
)
624+
548625
def _to_completions(
549626
self, request_output: RequestOutput, prompt: str
550627
) -> list[Completion]:
@@ -560,6 +637,14 @@ def _to_completions(
560637
completions = []
561638

562639
for output in request_output.outputs:
640+
tool_calls: list[ToolCall] = []
641+
content: str | None = None
642+
643+
if self._tool_parser is not None:
644+
extracted = self._extract_tool_calls(output.text)
645+
tool_calls = extracted.tool_calls
646+
content = extracted.content
647+
563648
completion = Completion(
564649
prompt=to_prompt(prompt),
565650
text=output.text,
@@ -575,6 +660,8 @@ def _to_completions(
575660
stop_reason=output.finish_reason,
576661
generator_version=self.generator_version,
577662
metadata={"num_cached_tokens": request_output.num_cached_tokens},
663+
tool_calls=tool_calls,
664+
content=content,
578665
)
579666
completions.append(completion)
580667

tests/integration_tests/test_tool_parsing.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@
2121
import pytest
2222
import pytest_asyncio
2323
import torch
24-
25-
from forge.rl import Policy
24+
from forge.actors.generator import Generator
2625
from huggingface_hub import snapshot_download
2726
from vllm.transformers_utils.tokenizer import get_tokenizer
2827

@@ -89,7 +88,7 @@ def tokenizer():
8988
async def policy(model_path):
9089
"""Create and teardown policy service for each test."""
9190
logger.info("Setting up policy service...")
92-
policy = await Policy.options(
91+
policy = await Generator.options(
9392
procs=1,
9493
num_replicas=1,
9594
with_gpus=True,

tests/unit_tests/test_generator.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from unittest.mock import MagicMock
1313

1414
import pytest
15-
1615
from vllm.outputs import CompletionOutput, RequestOutput
1716

1817

@@ -298,7 +297,7 @@ def test_to_completions_without_tool_parser(self):
298297
outputs=[{"text": "The answer is 4.", "token_ids": [10, 20, 30]}],
299298
)
300299

301-
completions = generator._to_completions(request_output)
300+
completions = generator._to_completions(request_output, request_output.prompt)
302301

303302
assert len(completions) == 1
304303
completion = completions[0]
@@ -319,7 +318,7 @@ def test_to_completions_no_tool_call_with_parser(self, generator_with_hermes):
319318
],
320319
)
321320

322-
completions = generator._to_completions(request_output)
321+
completions = generator._to_completions(request_output, request_output.prompt)
323322

324323
assert len(completions) == 1
325324
completion = completions[0]
@@ -345,7 +344,7 @@ def test_to_completions_multiple_outputs(self, generator_with_hermes):
345344
],
346345
)
347346

348-
completions = generator._to_completions(request_output)
347+
completions = generator._to_completions(request_output, request_output.prompt)
349348

350349
assert len(completions) == 2
351350
# First completion has tool call

0 commit comments

Comments (0)