Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def __init__(
streaming_callback: StreamingCallbackT | None = None,
api_base_url: str | None = "https://api.llama.com/compat/v1/",
generation_kwargs: dict[str, Any] | None = None,
timeout: float | None = None,
max_retries: int | None = None,
tools: ToolsType | None = None,
):
"""
Expand Down Expand Up @@ -99,6 +101,10 @@ def __init__(
For details, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs).
For structured outputs with streaming, the `response_format` must be a JSON
schema and not a Pydantic model.
:param timeout:
Timeout for Llama API client calls.
:param max_retries:
Maximum number of retries to attempt for failed requests.
:param tools:
A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
Each tool should have a unique name.
Expand All @@ -110,6 +116,8 @@ def __init__(
api_base_url=api_base_url,
organization=None,
generation_kwargs=generation_kwargs,
timeout=timeout,
max_retries=max_retries,
tools=tools,
)

Expand Down Expand Up @@ -166,5 +174,7 @@ def to_dict(self) -> dict[str, Any]:
api_base_url=self.api_base_url,
generation_kwargs=generation_kwargs,
api_key=self.api_key.to_dict(),
timeout=self.timeout,
max_retries=self.max_retries,
tools=serialize_tools_or_toolset(self.tools),
)
18 changes: 18 additions & 0 deletions integrations/meta_llama/tests/test_llama_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ def test_init_default(self, monkeypatch):
assert component.api_base_url == "https://api.llama.com/compat/v1/"
assert component.streaming_callback is None
assert not component.generation_kwargs
assert component.timeout is None
assert component.max_retries is None

def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("LLAMA_API_KEY", raising=False)
Expand All @@ -124,6 +126,8 @@ def test_init_with_parameters(self):
streaming_callback=print_streaming_chunk,
api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
timeout=15.0,
max_retries=3,
)
assert component.client.api_key == "test-api-key"
assert component.model == "Llama-4-Scout-17B-16E-Instruct-FP8"
Expand All @@ -132,6 +136,8 @@ def test_init_with_parameters(self):
"max_tokens": 10,
"some_test_param": "test-params",
}
assert component.timeout == 15.0
assert component.max_retries == 3

def test_to_dict_default(self, monkeypatch):
monkeypatch.setenv("LLAMA_API_KEY", "test-api-key")
Expand All @@ -153,6 +159,8 @@ def test_to_dict_default(self, monkeypatch):
"streaming_callback": None,
"api_base_url": "https://api.llama.com/compat/v1/",
"generation_kwargs": {},
"timeout": None,
"max_retries": None,
}

for key, value in expected_params.items():
Expand Down Expand Up @@ -212,6 +220,8 @@ class NobelPrizeInfo(BaseModel):
"api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params", "response_format": schema},
"timeout": None,
"max_retries": None,
}

for key, value in expected_params.items():
Expand All @@ -234,6 +244,8 @@ def test_from_dict(self, monkeypatch):
"max_tokens": 10,
"some_test_param": "test-params",
},
"timeout": 30.0,
"max_retries": 5,
},
}
component = MetaLlamaChatGenerator.from_dict(data)
Expand All @@ -245,6 +257,8 @@ def test_from_dict(self, monkeypatch):
"some_test_param": "test-params",
}
assert component.api_key == Secret.from_env_var("LLAMA_API_KEY")
assert component.timeout == 30.0
assert component.max_retries == 5

def test_from_dict_fail_wo_env_var(self, monkeypatch):
monkeypatch.delenv("LLAMA_API_KEY", raising=False)
Expand All @@ -263,6 +277,8 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
"max_tokens": 10,
"some_test_param": "test-params",
},
"timeout": 30.0,
"max_retries": 5,
},
}
with pytest.raises(ValueError, match=r"None of the .* environment variables are set"):
Expand Down Expand Up @@ -561,6 +577,8 @@ def test_serde_in_pipeline(self, monkeypatch):
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"api_base_url": "https://api.llama.com/compat/v1/",
"generation_kwargs": {"temperature": 0.7},
"timeout": None,
"max_retries": None,
"tools": [
{
"type": "haystack.tools.tool.Tool",
Expand Down