Skip to content

Commit 6db8f0a

Browse files
authored
refactor: LLMMetadataExtractor - adopt ChatGenerator protocol: deprecate generator_api, generator_api_params and LLMProvider (#9099)
* draft * improvements + tests * release note * mypy fixes * improve relnote * serialize chat_generator only * small simplification * clarify that also LLMProvider is deprecated * revert from_dict * test_from_dict_openai_using_chat_generator
1 parent dae8c7b commit 6db8f0a

3 files changed

Lines changed: 195 additions & 58 deletions

File tree

haystack/components/extractors/llm_metadata_extractor.py

Lines changed: 73 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import copy
66
import json
7+
import warnings
78
from concurrent.futures import ThreadPoolExecutor
89
from enum import Enum
910
from typing import Any, Dict, List, Optional, Union
@@ -14,7 +15,9 @@
1415
from haystack import Document, component, default_from_dict, default_to_dict, logging
1516
from haystack.components.builders import PromptBuilder
1617
from haystack.components.generators.chat import AzureOpenAIChatGenerator, OpenAIChatGenerator
18+
from haystack.components.generators.chat.types import ChatGenerator
1719
from haystack.components.preprocessors import DocumentSplitter
20+
from haystack.core.serialization import import_class_by_name
1821
from haystack.dataclasses import ChatMessage
1922
from haystack.lazy_imports import LazyImport
2023
from haystack.utils import deserialize_callable, deserialize_secrets_inplace, expand_page_range
@@ -76,7 +79,8 @@ class LLMMetadataExtractor:
7679
7780
```python
7881
from haystack import Document
79-
from haystack_experimental.components.extractors.llm_metadata_extractor import LLMMetadataExtractor
82+
from haystack.components.extractors.llm_metadata_extractor import LLMMetadataExtractor
83+
from haystack.components.generators.chat import OpenAIChatGenerator
8084
8185
NER_PROMPT = '''
8286
-Goal-
@@ -122,22 +126,24 @@ class LLMMetadataExtractor:
122126
Document(content="Hugging Face is a company that was founded in New York, USA and is known for its Transformers library")
123127
]
124128
129+
chat_generator = OpenAIChatGenerator(
130+
generation_kwargs={
131+
"max_tokens": 500,
132+
"temperature": 0.0,
133+
"seed": 0,
134+
"response_format": {"type": "json_object"},
135+
},
136+
max_retries=1,
137+
timeout=60.0,
138+
)
139+
125140
extractor = LLMMetadataExtractor(
126141
prompt=NER_PROMPT,
127-
generator_api="openai",
128-
generator_api_params={
129-
"generation_kwargs": {
130-
"max_tokens": 500,
131-
"temperature": 0.0,
132-
"seed": 0,
133-
"response_format": {"type": "json_object"},
134-
},
135-
"max_retries": 1,
136-
"timeout": 60.0,
137-
},
142+
chat_generator=chat_generator,
138143
expected_keys=["entities"],
139144
raise_on_failure=False,
140145
)
146+
141147
extractor.warm_up()
142148
extractor.run(documents=docs)
143149
>> {'documents': [
@@ -159,8 +165,9 @@ class LLMMetadataExtractor:
159165
def __init__( # pylint: disable=R0917
160166
self,
161167
prompt: str,
162-
generator_api: Union[str, LLMProvider],
168+
generator_api: Optional[Union[str, LLMProvider]] = None,
163169
generator_api_params: Optional[Dict[str, Any]] = None,
170+
chat_generator: Optional[ChatGenerator] = None,
164171
expected_keys: Optional[List[str]] = None,
165172
page_range: Optional[List[Union[str, int]]] = None,
166173
raise_on_failure: bool = False,
@@ -170,18 +177,20 @@ def __init__( # pylint: disable=R0917
170177
Initializes the LLMMetadataExtractor.
171178
172179
:param prompt: The prompt to be used for the LLM.
173-
:param generator_api: The API provider for the LLM. Currently supported providers are:
174-
"openai", "openai_azure", "aws_bedrock", "google_vertex"
175-
:param generator_api_params: The parameters for the LLM generator.
180+
:param generator_api: The API provider for the LLM. Deprecated. Use chat_generator to configure the LLM.
181+
Currently supported providers are: "openai", "openai_azure", "aws_bedrock", "google_vertex".
182+
:param generator_api_params: The parameters for the LLM generator. Deprecated. Use chat_generator to configure
183+
the LLM.
184+
:param chat_generator: a ChatGenerator instance which represents the LLM. If provided, this will override
185+
settings in generator_api and generator_api_params.
176186
:param expected_keys: The keys expected in the JSON output from the LLM.
177187
:param page_range: A range of pages to extract metadata from. For example, page_range=['1', '3'] will extract
178-
metadata from the first and third pages of each document. It also accepts printable range
179-
strings, e.g.: ['1-3', '5', '8', '10-12'] will extract metadata from pages 1, 2, 3, 5, 8, 10,
180-
11, 12. If None, metadata will be extracted from the entire document for each document in the
181-
documents list.
182-
This parameter is optional and can be overridden in the `run` method.
188+
metadata from the first and third pages of each document. It also accepts printable range strings, e.g.:
189+
['1-3', '5', '8', '10-12'] will extract metadata from pages 1, 2, 3, 5, 8, 10, 11, 12.
190+
If None, metadata will be extracted from the entire document for each document in the documents list.
191+
This parameter is optional and can be overridden in the `run` method.
183192
:param raise_on_failure: Whether to raise an error on failure during the execution of the Generator or
184-
validation of the JSON output.
193+
validation of the JSON output.
185194
:param max_workers: The maximum number of workers to use in the thread pool executor.
186195
"""
187196
self.prompt = prompt
@@ -195,11 +204,32 @@ def __init__( # pylint: disable=R0917
195204
self.builder = PromptBuilder(prompt, required_variables=variables)
196205
self.raise_on_failure = raise_on_failure
197206
self.expected_keys = expected_keys or []
198-
self.generator_api = (
199-
generator_api if isinstance(generator_api, LLMProvider) else LLMProvider.from_str(generator_api)
200-
)
201-
self.generator_api_params = generator_api_params or {}
202-
self.llm_provider = self._init_generator(self.generator_api, self.generator_api_params)
207+
generator_api_params = generator_api_params or {}
208+
209+
if generator_api is None and chat_generator is None:
210+
raise ValueError("Either generator_api or chat_generator must be provided.")
211+
212+
if chat_generator is not None:
213+
self._chat_generator = chat_generator
214+
if generator_api is not None:
215+
logger.warning(
216+
"Both chat_generator and generator_api are provided. "
217+
"chat_generator will be used. generator_api/generator_api_params/LLMProvider are deprecated and "
218+
"will be removed in Haystack 2.13.0."
219+
)
220+
else:
221+
warnings.warn(
222+
"generator_api, generator_api_params, and LLMProvider are deprecated and will be removed in Haystack "
223+
"2.13.0. Use chat_generator instead. For example, change `generator_api=LLMProvider.OPENAI` to "
224+
"`chat_generator=OpenAIChatGenerator()`.",
225+
DeprecationWarning,
226+
)
227+
assert generator_api is not None # verified by the checks above
228+
generator_api = (
229+
generator_api if isinstance(generator_api, LLMProvider) else LLMProvider.from_str(generator_api)
230+
)
231+
self._chat_generator = self._init_generator(generator_api, generator_api_params)
232+
203233
self.splitter = DocumentSplitter(split_by="page", split_length=1)
204234
self.expanded_range = expand_page_range(page_range) if page_range else None
205235
self.max_workers = max_workers
@@ -233,8 +263,8 @@ def warm_up(self):
233263
"""
234264
Warm up the LLM provider component.
235265
"""
236-
if hasattr(self.llm_provider, "warm_up"):
237-
self.llm_provider.warm_up()
266+
if hasattr(self._chat_generator, "warm_up"):
267+
self._chat_generator.warm_up()
238268

239269
def to_dict(self) -> Dict[str, Any]:
240270
"""
@@ -244,13 +274,10 @@ def to_dict(self) -> Dict[str, Any]:
244274
Dictionary with serialized data.
245275
"""
246276

247-
llm_provider = self.llm_provider.to_dict()
248-
249277
return default_to_dict(
250278
self,
251279
prompt=self.prompt,
252-
generator_api=self.generator_api.value,
253-
generator_api_params=llm_provider["init_parameters"],
280+
chat_generator=self._chat_generator.to_dict(),
254281
expected_keys=self.expected_keys,
255282
page_range=self.expanded_range,
256283
raise_on_failure=self.raise_on_failure,
@@ -270,6 +297,15 @@ def from_dict(cls, data: Dict[str, Any]) -> "LLMMetadataExtractor":
270297

271298
init_parameters = data.get("init_parameters", {})
272299

300+
# new deserialization with chat_generator
301+
if init_parameters.get("chat_generator") is not None:
302+
chat_generator_class = import_class_by_name(init_parameters["chat_generator"]["type"])
303+
assert hasattr(chat_generator_class, "from_dict") # we know but mypy doesn't
304+
chat_generator_instance = chat_generator_class.from_dict(init_parameters["chat_generator"])
305+
data["init_parameters"]["chat_generator"] = chat_generator_instance
306+
return default_from_dict(cls, data)
307+
308+
# legacy deserialization
273309
if "generator_api" in init_parameters:
274310
data["init_parameters"]["generator_api"] = LLMProvider.from_str(data["init_parameters"]["generator_api"])
275311

@@ -364,15 +400,15 @@ def _run_on_thread(self, prompt: Optional[ChatMessage]) -> Dict[str, Any]:
364400
return {"replies": ["{}"]}
365401

366402
try:
367-
result = self.llm_provider.run(messages=[prompt])
403+
result = self._chat_generator.run(messages=[prompt])
368404
except Exception as e:
405+
if self.raise_on_failure:
406+
raise e
369407
logger.error(
370408
"LLM {class_name} execution failed. Skipping metadata extraction. Failed with exception '{error}'.",
371-
class_name=self.llm_provider.__class__.__name__,
409+
class_name=self._chat_generator.__class__.__name__,
372410
error=e,
373411
)
374-
if self.raise_on_failure:
375-
raise e
376412
result = {"error": "LLM failed with exception: " + str(e)}
377413
return result
378414

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
---
2+
enhancements:
3+
- |
4+
`LLMMetadataExtractor` now accepts a `chat_generator` initialization parameter, consisting of a pre-configured
5+
`ChatGenerator` instance.
6+
Regardless of whether `LLMMetadataExtractor` is initialized using `generator_api` and `generator_api_params` or
7+
the new `chat_generator` parameter, the serialization format will only include `chat_generator` in preparation
8+
for the future removal of `generator_api` and `generator_api_params`.
9+
deprecations:
10+
- |
11+
The `generator_api` and `generator_api_params` initialization parameters of `LLMMetadataExtractor` and the
12+
`LLMProvider` enum are deprecated and will be removed in Haystack 2.13.0. Use `chat_generator` instead to configure
13+
the underlying LLM. For example, change `generator_api=LLMProvider.OPENAI` to
14+
`chat_generator=OpenAIChatGenerator()`.

0 commit comments

Comments
 (0)