Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 0 additions & 95 deletions integrations/meta_llama/tests/test_llama_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,35 +322,6 @@ def test_live_run_wrong_model(self, chat_messages):
with pytest.raises(OpenAIError):
component.run(chat_messages)

@pytest.mark.skipif(
not os.environ.get("LLAMA_API_KEY", None),
reason="Export an env var called LLAMA_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.integration
def test_live_run_streaming(self):
class Callback:
def __init__(self):
self.responses = ""
self.counter = 0

def __call__(self, chunk: StreamingChunk) -> None:
self.counter += 1
self.responses += chunk.content if chunk.content else ""

callback = Callback()
component = MetaLlamaChatGenerator(streaming_callback=callback)
results = component.run([ChatMessage.from_user("What's the capital of France?")])

assert len(results["replies"]) == 1
message: ChatMessage = results["replies"][0]
assert "Paris" in message.text

assert "Llama-4-Scout-17B-16E-Instruct-FP8" in message.meta["model"]
assert message.meta["finish_reason"] == "stop"

assert callback.counter > 1
assert "Paris" in callback.responses

@pytest.mark.skipif(
not os.environ.get("LLAMA_API_KEY", None),
reason="Export an env var called LLAMA_API_KEY containing the Llama API key to run this test.",
Expand Down Expand Up @@ -638,69 +609,3 @@ def tool_fn(city: str) -> str:
)

assert generator.tools == [weather_tool, toolset]

@pytest.mark.skipif(
not os.environ.get("LLAMA_API_KEY", None),
reason="Export an env var called LLAMA_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.integration
def test_live_run_with_mixed_tools(self, mixed_tools):
"""
Integration test that verifies MetaLlamaChatGenerator works with mixed Tool and Toolset.
This tests that the LLM can correctly invoke tools from both a standalone Tool and a Toolset.
"""
initial_messages = [
ChatMessage.from_user("What's the weather like in Paris and what is the population of Berlin?")
]
# default model Llama-4-Scout-17B-16E-Instruct-FP8 can't handle multiple tool responses well
# we use stronger model Llama-4-Maverick-17B-128E-Instruct-FP8 for this test
component = MetaLlamaChatGenerator(model="Llama-4-Maverick-17B-128E-Instruct-FP8", tools=mixed_tools)
results = component.run(messages=initial_messages)

assert len(results["replies"]) > 0, "No replies received"

# Find the message with tool calls
tool_call_message = None
for message in results["replies"]:
if message.tool_calls:
tool_call_message = message
break

assert tool_call_message is not None, "No message with tool call found"
assert isinstance(tool_call_message, ChatMessage), "Tool message is not a ChatMessage instance"
assert ChatMessage.is_from(tool_call_message, ChatRole.ASSISTANT), "Tool message is not from the assistant"

tool_calls = tool_call_message.tool_calls
assert len(tool_calls) == 2, f"Expected 2 tool calls, got {len(tool_calls)}"

# Verify we got calls to both weather and population tools
tool_names = {tc.tool_name for tc in tool_calls}
assert "weather" in tool_names, "Expected 'weather' tool call"
assert "population" in tool_names, "Expected 'population' tool call"

# Verify tool call details
for tool_call in tool_calls:
assert tool_call.id, "Tool call does not contain value for 'id' key"
assert tool_call.tool_name in ["weather", "population"]
assert "city" in tool_call.arguments
assert tool_call.arguments["city"] in ["Paris", "Berlin"]
assert tool_call_message.meta["finish_reason"] == "tool_calls"

# Mock the response we'd get from ToolInvoker
tool_result_messages = []
for tool_call in tool_calls:
if tool_call.tool_name == "weather":
result = "The weather in Paris is sunny and 32°C"
else: # population
result = "The population of Berlin is 2.2 million"
tool_result_messages.append(ChatMessage.from_tool(tool_result=result, origin=tool_call))

new_messages = [*initial_messages, tool_call_message, *tool_result_messages]
results = component.run(new_messages)

assert len(results["replies"]) == 1
final_message = results["replies"][0]
assert not final_message.tool_calls
assert len(final_message.text) > 0
assert "paris" in final_message.text.lower()
assert "berlin" in final_message.text.lower()
Loading