3838from torchstore.api import _controller as get_torchstore_controller
3939from vllm.engine.arg_utils import EngineArgs
4040from vllm.entrypoints.llm import UsageContext
41+ from vllm.entrypoints.openai.protocol import (
42+ ChatCompletionRequest,
43+ ExtractedToolCallInformation,
44+ ToolCall,
45+ )
46+ from vllm.entrypoints.openai.tool_parsers import ToolParserManager
4147from vllm.outputs import RequestOutput
4248from vllm.sampling_params import RequestOutputKind, SamplingParams
49+ from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
4350from vllm.v1.engine.async_llm import AsyncLLM
4451
4552logger = logging.getLogger(__name__)
@@ -78,6 +85,7 @@ class Generator(ForgeActor):
7885 sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
7986 prefetch_weights_to_shm: bool = True
8087 n_fetcher_procs: int = 8
88+ tool_call_parser: str | None = None
8189
8290 def __post_init__(self):
8391 super().__init__()
@@ -91,6 +99,8 @@ def __post_init__(self):
9199 self.engine_args = EngineArgs(**self.engine_args)
92100 self.vllm_config = self.engine_args.create_engine_config(UsageContext.LLM_CLASS)
93101
102+ self._tool_parser = None # Will hold ToolParser instance if configured
103+
94104 if isinstance(self.sampling_params, Mapping):
95105 self.sampling_params = SamplingParams.from_optional(**self.sampling_params)
96106 self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
@@ -273,9 +283,44 @@ async def setup(self, host_mesh, worker_registry, gpu_ids: list[str]):
273283 )
274284 logger.info(f"Retrieved workers from registry: {self.workers}")
275285
286+ if self.tool_call_parser is not None:
287+ self._tool_parser = self._init_tool_parser()
288+
276289 if self.prefetch_weights_to_shm:
277290 self._spawn_fetchers()
278291
def _init_tool_parser(self, tokenizer=None):  # type: ignore[no-untyped-def]
    """Build the tool parser named by ``self.tool_call_parser``.

    Initialization is best-effort: every failure is logged and reported as
    ``None`` rather than raised, so the caller can carry on without tool
    parsing.

    Args:
        tokenizer: Optional tokenizer wrapper (with .tokenizer attr). If not provided,
            one is created from vllm_config. Passing explicitly is useful for testing.

    Returns:
        Initialized ToolParser instance, or None if the parser name is unknown
        or any step of the initialization fails.
    """
    # Resolve the parser class in its own try block so that only a failed
    # name lookup is reported as "unknown parser". A KeyError raised while
    # building the tokenizer or the parser must not be mistaken for one.
    # Doing the lookup first also avoids creating a tokenizer for a name
    # that cannot be resolved.
    try:
        parser_cls = ToolParserManager.get_tool_parser(self.tool_call_parser)  # type: ignore[union-attr]
    except KeyError:
        available = list(ToolParserManager.tool_parsers.keys())
        logger.error(
            f"Unknown tool parser: '{self.tool_call_parser}'. "
            f"Available parsers: {available}"
        )
        return None
    except Exception as e:
        # logger.exception keeps the traceback, which str(e) alone loses.
        logger.exception(f"Failed to initialize tool parser: {e}")
        return None

    try:
        if tokenizer is None:
            tokenizer = init_tokenizer_from_configs(
                model_config=self.vllm_config.model_config,
                scheduler_config=self.vllm_config.scheduler_config,
                lora_config=self.vllm_config.lora_config,
            )
        # The parser is handed the underlying tokenizer, not the wrapper.
        parser = parser_cls(tokenizer.tokenizer)
    except Exception as e:
        logger.exception(f"Failed to initialize tool parser: {e}")
        return None

    logger.info(f"Initialized tool parser: {self.tool_call_parser}")
    return parser
323+
279324 def _spawn_fetchers(self):
280325 """Spawn weight fetchers that prefetch weights from torchstore to shared memory.
281326
@@ -545,6 +590,38 @@ def _extract_logprobs(self, output) -> torch.Tensor | None:
545590 )
546591 return None
547592
def _extract_tool_calls(self, model_output: str) -> ExtractedToolCallInformation:
    """Run the configured tool parser over a piece of generated text.

    When no parser is configured, or the parser raises, the text is returned
    unparsed: ``tools_called=False``, no tool calls, and ``content`` set to
    the original ``model_output``.

    Args:
        model_output: Raw text output from the model.

    Returns:
        ExtractedToolCallInformation with parsed tool calls and remaining content.
    """
    parser = self._tool_parser
    if parser is not None:
        try:
            # extract_tool_calls takes a request argument; a placeholder
            # request with an empty user message is supplied for it.
            placeholder_request = ChatCompletionRequest(
                model=self.vllm_config.model_config.model,
                messages=[{"role": "user", "content": ""}],
                seed=42,  # to calm the linter
            )
            return parser.extract_tool_calls(model_output, placeholder_request)
        except Exception as e:
            # Best-effort: a parse failure downgrades to "no tool calls".
            logger.warning(f"Failed to parse tool calls: {e}")

    # Shared fallback for "no parser configured" and "parser failed".
    return ExtractedToolCallInformation(
        tools_called=False, tool_calls=[], content=model_output
    )
624+
548625 def _to_completions(
549626 self, request_output: RequestOutput, prompt: str
550627 ) -> list[Completion]:
@@ -560,6 +637,14 @@ def _to_completions(
560637 completions = []
561638
562639 for output in request_output.outputs:
640+ tool_calls: list[ToolCall] = []
641+ content: str | None = None
642+
643+ if self._tool_parser is not None:
644+ extracted = self._extract_tool_calls(output.text)
645+ tool_calls = extracted.tool_calls
646+ content = extracted.content
647+
563648 completion = Completion(
564649 prompt=to_prompt(prompt),
565650 text=output.text,
@@ -575,6 +660,8 @@ def _to_completions(
575660 stop_reason=output.finish_reason,
576661 generator_version=self.generator_version,
577662 metadata={"num_cached_tokens": request_output.num_cached_tokens},
663+ tool_calls=tool_calls,
664+ content=content,
578665 )
579666 completions.append(completion)
580667
0 commit comments