Skip to content

Commit ca2799b

Browse files
test: add integration tests for vLLM tool parsing workflow
1 parent 060090d commit ca2799b

1 file changed

Lines changed: 199 additions & 0 deletions

File tree

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Integration tests for vLLM tool parsing in forge.
9+
10+
Tests the full tool-calling workflow: model generates tool call -> parse -> execute -> return result.
11+
12+
Requires GPU access.
13+
14+
Run:
15+
pytest tests/integration_tests/test_tool_parsing.py -v -s
16+
"""
17+
18+
import json
19+
import logging
20+
21+
import pytest
22+
import pytest_asyncio
23+
import torch
24+
25+
from forge.rl import Policy
26+
from huggingface_hub import snapshot_download
27+
from vllm.transformers_utils.tokenizer import get_tokenizer
28+
29+
# Module-level logger. INFO is enabled so the progress messages logged by the
# fixtures and tests below are not filtered out by this logger.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Marker that skips a test when torch reports that CUDA is unavailable.
requires_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="CUDA not available"
)

# Hugging Face repo id of the model under test.
MODEL_NAME = "Qwen/Qwen3-0.6B"

# Tool schema handed to the chat template via `tools=`. It declares a single
# function, `calculator`, whose only (required) argument is the string
# `equation`.
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "calculator",
            "description": "Evaluate a mathematical equation.",
            "parameters": {
                "type": "object",
                "properties": {
                    "equation": {
                        "type": "string",
                        "description": "The mathematical equation to evaluate",
                    },
                },
                "required": ["equation"],
            },
        },
    },
]
58+
59+
60+
def calculator(equation: str) -> str:
    """Evaluate a basic arithmetic expression and return the result as text.

    Only digits, spaces, parentheses, ``.`` and the operators ``+ - * / ^``
    are accepted; ``^`` is rewritten to ``**`` (exponentiation) before
    evaluation. The tests in this module pass model-generated arguments to
    this function, so the input must be treated as untrusted.

    Args:
        equation: The expression to evaluate, e.g. ``"123 + 456"``.

    Returns:
        ``str(result)`` on success. If the input contains a character outside
        the allowed set, ``"Error: Invalid characters in equation"``. If
        anything raises (bad syntax, division by zero, non-string input, a
        result too large to stringify), ``"Error: <exception message>"``.
    """
    allowed = frozenset("0123456789+-*/().^ ")
    try:
        # Guard clause: reject anything that is not plain arithmetic before it
        # gets anywhere near eval().
        if not all(ch in allowed for ch in equation):
            return "Error: Invalid characters in equation"

        # SECURITY: eval() on untrusted text. The character whitelist above
        # rules out names and attribute access, and the empty globals/builtins
        # below add defense in depth, but resource exhaustion is still
        # possible (e.g. "9^9^9^9"). Acceptable for a test helper; do not
        # reuse this in production code.
        result = eval(equation.replace("^", "**"), {"__builtins__": {}}, {})

        # str() stays inside the try: very large ints can raise on conversion.
        return str(result)
    except Exception as e:
        return f"Error: {e}"
71+
72+
73+
@pytest.fixture(scope="module")
def model_path():
    """Fetch the model snapshot a single time and share it across the module.

    Returns the directory reported by ``snapshot_download`` for ``MODEL_NAME``.
    """
    logger.info(f"Downloading model checkpoint: {MODEL_NAME}")
    snapshot_dir = snapshot_download(repo_id=MODEL_NAME)
    logger.info(f"Model downloaded to: {snapshot_dir}")
    return snapshot_dir
80+
81+
82+
@pytest.fixture(scope="module")
def tokenizer(model_path):
    """Create the tokenizer once for all tests in this module.

    The tokenizer is loaded from ``model_path``, the same local snapshot
    directory that the ``policy`` fixture passes to the engine, so the chat
    template and vocabulary always match the weights being served and no
    second hub lookup is needed.
    """
    return get_tokenizer(model_path)
86+
87+
88+
@pytest_asyncio.fixture
async def policy(model_path):
    """Yield a freshly started policy service, then shut it down.

    The service is started from the downloaded snapshot at ``model_path``
    with the ``hermes`` tool-call parser and a 256-token generation cap.
    """
    logger.info("Setting up policy service...")
    configured = Policy.options(procs=1, num_replicas=1, with_gpus=True)
    service = await configured.as_service(
        engine_args={"model": model_path},
        sampling_params={"n": 1, "max_tokens": 256},
        tool_call_parser="hermes",
    )

    yield service

    # Everything after the yield is the fixture's teardown.
    logger.info("Shutting down policy service...")
    await service.shutdown()
107+
108+
109+
@requires_cuda
@pytest.mark.asyncio
async def test_tool_parsing_multi_turn(policy, tokenizer):
    """
    Exercise a two-turn tool-calling exchange end to end.

    Turn one must yield a tool call; its ``equation`` argument is evaluated
    with the local ``calculator`` helper and the result is appended to the
    conversation. Turn two must then mention the expected sum.
    """
    system_turn = {
        "role": "system",
        "content": "/no_think Use the calculator tool for math.",
    }
    user_turn = {"role": "user", "content": "Calculate 123 + 456"}
    messages = [system_turn, user_turn]

    # Turn 1: prompt with the tool schema and expect a parsed tool call.
    first_prompt = tokenizer.apply_chat_template(
        messages, tools=TOOLS, tokenize=False, add_generation_prompt=True
    )
    completion = (await policy.generate.route(first_prompt))[0]

    assert completion.has_tool_calls, "Expected tool calls"
    tool_call = completion.tool_calls[0]
    call_args = json.loads(tool_call.function.arguments)
    tool_output = calculator(call_args["equation"])

    # Replay the assistant's raw output and the tool result into the history.
    messages.extend(
        [
            {"role": "assistant", "content": completion.text},
            {"role": "tool", "tool_call_id": tool_call.id, "content": tool_output},
        ]
    )

    # Turn 2: the model should now produce the final answer.
    second_prompt = tokenizer.apply_chat_template(
        messages, tools=TOOLS, tokenize=False, add_generation_prompt=True
    )
    final = (await policy.generate.route(second_prompt))[0]

    logger.info(f"Final answer: {final.text}")
    assert "579" in final.text, "Expected 123 + 456 = 579"

    logger.info("✅ test_tool_parsing_multi_turn passed!")
161+
162+
163+
@requires_cuda
@pytest.mark.asyncio
async def test_content_without_tool_calls(policy, tokenizer):
    """
    Check that ``content`` mirrors ``text`` when no tool call is produced.

    The prompt is rendered without a tool schema and asks a plain factual
    question, so the completion is expected to carry no tool calls and its
    ``content`` field should equal the raw ``text`` output.
    """
    # A non-math question, so the calculator tool is not expected to fire.
    system_turn = {
        "role": "system",
        "content": "/no_think You are a helpful assistant.",
    }
    user_turn = {"role": "user", "content": "What is the capital of France?"}

    prompt = tokenizer.apply_chat_template(
        [system_turn, user_turn], tokenize=False, add_generation_prompt=True
    )
    completion = (await policy.generate.route(prompt))[0]

    logger.info(f"Response text: {completion.text}")
    logger.info(f"Response content: {completion.content}")

    assert completion.tool_calls == [], "Should have no tool calls"
    assert completion.content is not None, "Should have content when no tools called"
    assert (
        completion.content == completion.text
    ), "Content should equal text when no tools"

    logger.info("✅ test_content_without_tool_calls passed!")

0 commit comments

Comments
 (0)