Skip to content

Commit bf0bd43

Browse files
feat(meta-llama): add timeout and max_retries to chat generator (#2872)
1 parent a7f4732 commit bf0bd43

2 files changed

Lines changed: 28 additions & 0 deletions

File tree

integrations/meta_llama/src/haystack_integrations/components/generators/meta_llama/chat/chat_generator.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def __init__(
6262
streaming_callback: StreamingCallbackT | None = None,
6363
api_base_url: str | None = "https://api.llama.com/compat/v1/",
6464
generation_kwargs: dict[str, Any] | None = None,
65+
timeout: float | None = None,
66+
max_retries: int | None = None,
6567
tools: ToolsType | None = None,
6668
):
6769
"""
@@ -99,6 +101,10 @@ def __init__(
99101
For details, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs).
100102
For structured outputs with streaming, the `response_format` must be a JSON
101103
schema and not a Pydantic model.
104+
:param timeout:
105+
Timeout for Llama API client calls.
106+
:param max_retries:
107+
Maximum number of retries to attempt for failed requests.
102108
:param tools:
103109
A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
104110
Each tool should have a unique name.
@@ -110,6 +116,8 @@ def __init__(
110116
api_base_url=api_base_url,
111117
organization=None,
112118
generation_kwargs=generation_kwargs,
119+
timeout=timeout,
120+
max_retries=max_retries,
113121
tools=tools,
114122
)
115123

@@ -166,5 +174,7 @@ def to_dict(self) -> dict[str, Any]:
166174
api_base_url=self.api_base_url,
167175
generation_kwargs=generation_kwargs,
168176
api_key=self.api_key.to_dict(),
177+
timeout=self.timeout,
178+
max_retries=self.max_retries,
169179
tools=serialize_tools_or_toolset(self.tools),
170180
)

integrations/meta_llama/tests/test_llama_chat_generator.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ def test_init_default(self, monkeypatch):
111111
assert component.api_base_url == "https://api.llama.com/compat/v1/"
112112
assert component.streaming_callback is None
113113
assert not component.generation_kwargs
114+
assert component.timeout is None
115+
assert component.max_retries is None
114116

115117
def test_init_fail_wo_api_key(self, monkeypatch):
116118
monkeypatch.delenv("LLAMA_API_KEY", raising=False)
@@ -124,6 +126,8 @@ def test_init_with_parameters(self):
124126
streaming_callback=print_streaming_chunk,
125127
api_base_url="test-base-url",
126128
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
129+
timeout=15.0,
130+
max_retries=3,
127131
)
128132
assert component.client.api_key == "test-api-key"
129133
assert component.model == "Llama-4-Scout-17B-16E-Instruct-FP8"
@@ -132,6 +136,8 @@ def test_init_with_parameters(self):
132136
"max_tokens": 10,
133137
"some_test_param": "test-params",
134138
}
139+
assert component.timeout == 15.0
140+
assert component.max_retries == 3
135141

136142
def test_to_dict_default(self, monkeypatch):
137143
monkeypatch.setenv("LLAMA_API_KEY", "test-api-key")
@@ -153,6 +159,8 @@ def test_to_dict_default(self, monkeypatch):
153159
"streaming_callback": None,
154160
"api_base_url": "https://api.llama.com/compat/v1/",
155161
"generation_kwargs": {},
162+
"timeout": None,
163+
"max_retries": None,
156164
}
157165

158166
for key, value in expected_params.items():
@@ -212,6 +220,8 @@ class NobelPrizeInfo(BaseModel):
212220
"api_base_url": "test-base-url",
213221
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
214222
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params", "response_format": schema},
223+
"timeout": None,
224+
"max_retries": None,
215225
}
216226

217227
for key, value in expected_params.items():
@@ -234,6 +244,8 @@ def test_from_dict(self, monkeypatch):
234244
"max_tokens": 10,
235245
"some_test_param": "test-params",
236246
},
247+
"timeout": 30.0,
248+
"max_retries": 5,
237249
},
238250
}
239251
component = MetaLlamaChatGenerator.from_dict(data)
@@ -245,6 +257,8 @@ def test_from_dict(self, monkeypatch):
245257
"some_test_param": "test-params",
246258
}
247259
assert component.api_key == Secret.from_env_var("LLAMA_API_KEY")
260+
assert component.timeout == 30.0
261+
assert component.max_retries == 5
248262

249263
def test_from_dict_fail_wo_env_var(self, monkeypatch):
250264
monkeypatch.delenv("LLAMA_API_KEY", raising=False)
@@ -263,6 +277,8 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
263277
"max_tokens": 10,
264278
"some_test_param": "test-params",
265279
},
280+
"timeout": 30.0,
281+
"max_retries": 5,
266282
},
267283
}
268284
with pytest.raises(ValueError, match=r"None of the .* environment variables are set"):
@@ -561,6 +577,8 @@ def test_serde_in_pipeline(self, monkeypatch):
561577
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
562578
"api_base_url": "https://api.llama.com/compat/v1/",
563579
"generation_kwargs": {"temperature": 0.7},
580+
"timeout": None,
581+
"max_retries": None,
564582
"tools": [
565583
{
566584
"type": "haystack.tools.tool.Tool",

0 commit comments

Comments
 (0)