@@ -4,6 +4,7 @@
 
 import copy
 import json
+import warnings
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from typing import Any, Dict, List, Optional, Union
@@ -14,7 +15,9 @@
 from haystack import Document, component, default_from_dict, default_to_dict, logging
 from haystack.components.builders import PromptBuilder
 from haystack.components.generators.chat import AzureOpenAIChatGenerator, OpenAIChatGenerator
+from haystack.components.generators.chat.types import ChatGenerator
 from haystack.components.preprocessors import DocumentSplitter
+from haystack.core.serialization import import_class_by_name
 from haystack.dataclasses import ChatMessage
 from haystack.lazy_imports import LazyImport
 from haystack.utils import deserialize_callable, deserialize_secrets_inplace, expand_page_range
@@ -76,7 +79,8 @@ class LLMMetadataExtractor:
 
 ```python
 from haystack import Document
-from haystack_experimental.components.extractors.llm_metadata_extractor import LLMMetadataExtractor
+from haystack.components.extractors.llm_metadata_extractor import LLMMetadataExtractor
+from haystack.components.generators.chat import OpenAIChatGenerator
 
 NER_PROMPT = '''
 -Goal-
@@ -122,22 +126,24 @@ class LLMMetadataExtractor:
     Document(content="Hugging Face is a company that was founded in New York, USA and is known for its Transformers library")
 ]
 
+chat_generator = OpenAIChatGenerator(
+    generation_kwargs={
+        "max_tokens": 500,
+        "temperature": 0.0,
+        "seed": 0,
+        "response_format": {"type": "json_object"},
+    },
+    max_retries=1,
+    timeout=60.0,
+)
+
 extractor = LLMMetadataExtractor(
     prompt=NER_PROMPT,
-    generator_api="openai",
-    generator_api_params={
-        "generation_kwargs": {
-            "max_tokens": 500,
-            "temperature": 0.0,
-            "seed": 0,
-            "response_format": {"type": "json_object"},
-        },
-        "max_retries": 1,
-        "timeout": 60.0,
-    },
+    chat_generator=chat_generator,
     expected_keys=["entities"],
     raise_on_failure=False,
 )
+
 extractor.warm_up()
 extractor.run(documents=docs)
 >> {'documents': [
@@ -159,8 +165,9 @@ class LLMMetadataExtractor:
 def __init__(  # pylint: disable=R0917
     self,
     prompt: str,
-    generator_api: Union[str, LLMProvider],
+    generator_api: Optional[Union[str, LLMProvider]] = None,
     generator_api_params: Optional[Dict[str, Any]] = None,
+    chat_generator: Optional[ChatGenerator] = None,
     expected_keys: Optional[List[str]] = None,
     page_range: Optional[List[Union[str, int]]] = None,
     raise_on_failure: bool = False,
@@ -170,18 +177,20 @@ def __init__( # pylint: disable=R0917
 Initializes the LLMMetadataExtractor.
 
 :param prompt: The prompt to be used for the LLM.
-:param generator_api: The API provider for the LLM. Currently supported providers are:
-                      "openai", "openai_azure", "aws_bedrock", "google_vertex"
-:param generator_api_params: The parameters for the LLM generator.
+:param generator_api: The API provider for the LLM. Deprecated. Use chat_generator to configure the LLM.
+    Currently supported providers are: "openai", "openai_azure", "aws_bedrock", "google_vertex".
+:param generator_api_params: The parameters for the LLM generator. Deprecated. Use chat_generator to configure
+    the LLM.
+:param chat_generator: A ChatGenerator instance which represents the LLM. If provided, this will override
+    the settings in generator_api and generator_api_params.
 :param expected_keys: The keys expected in the JSON output from the LLM.
 :param page_range: A range of pages to extract metadata from. For example, page_range=['1', '3'] will extract
-                   metadata from the first and third pages of each document. It also accepts printable range
-                   strings, e.g.: ['1-3', '5', '8', '10-12'] will extract metadata from pages 1, 2, 3, 5, 8, 10,
-                   11, 12. If None, metadata will be extracted from the entire document for each document in the
-                   documents list.
-                   This parameter is optional and can be overridden in the `run` method.
+    metadata from the first and third pages of each document. It also accepts printable range strings, e.g.:
+    ['1-3', '5', '8', '10-12'] will extract metadata from pages 1, 2, 3, 5, 8, 10, 11, 12.
+    If None, metadata will be extracted from the entire document for each document in the documents list.
+    This parameter is optional and can be overridden in the `run` method.
 :param raise_on_failure: Whether to raise an error on failure during the execution of the Generator or
-                         validation of the JSON output.
+    validation of the JSON output.
 :param max_workers: The maximum number of workers to use in the thread pool executor.
 """
 self.prompt = prompt
@@ -195,11 +204,32 @@ def __init__( # pylint: disable=R0917
 self.builder = PromptBuilder(prompt, required_variables=variables)
 self.raise_on_failure = raise_on_failure
 self.expected_keys = expected_keys or []
-self.generator_api = (
-    generator_api if isinstance(generator_api, LLMProvider) else LLMProvider.from_str(generator_api)
-)
-self.generator_api_params = generator_api_params or {}
-self.llm_provider = self._init_generator(self.generator_api, self.generator_api_params)
+generator_api_params = generator_api_params or {}
+
+if generator_api is None and chat_generator is None:
+    raise ValueError("Either generator_api or chat_generator must be provided.")
+
+if chat_generator is not None:
+    self._chat_generator = chat_generator
+    if generator_api is not None:
+        logger.warning(
+            "Both chat_generator and generator_api are provided. "
+            "chat_generator will be used. generator_api/generator_api_params/LLMProvider are deprecated and "
+            "will be removed in Haystack 2.13.0."
+        )
+else:
+    warnings.warn(
+        "generator_api, generator_api_params, and LLMProvider are deprecated and will be removed in Haystack "
+        "2.13.0. Use chat_generator instead. For example, change `generator_api=LLMProvider.OPENAI` to "
+        "`chat_generator=OpenAIChatGenerator()`.",
+        DeprecationWarning,
+    )
+    assert generator_api is not None  # verified by the checks above
+    generator_api = (
+        generator_api if isinstance(generator_api, LLMProvider) else LLMProvider.from_str(generator_api)
+    )
+    self._chat_generator = self._init_generator(generator_api, generator_api_params)
+
 self.splitter = DocumentSplitter(split_by="page", split_length=1)
 self.expanded_range = expand_page_range(page_range) if page_range else None
 self.max_workers = max_workers
@@ -233,8 +263,8 @@ def warm_up(self):
     """
     Warm up the LLM provider component.
     """
-    if hasattr(self.llm_provider, "warm_up"):
-        self.llm_provider.warm_up()
+    if hasattr(self._chat_generator, "warm_up"):
+        self._chat_generator.warm_up()
 
 def to_dict(self) -> Dict[str, Any]:
     """
@@ -244,13 +274,10 @@ def to_dict(self) -> Dict[str, Any]:
 Dictionary with serialized data.
 """
 
-llm_provider = self.llm_provider.to_dict()
-
 return default_to_dict(
     self,
     prompt=self.prompt,
-    generator_api=self.generator_api.value,
-    generator_api_params=llm_provider["init_parameters"],
+    chat_generator=self._chat_generator.to_dict(),
     expected_keys=self.expected_keys,
     page_range=self.expanded_range,
     raise_on_failure=self.raise_on_failure,
@@ -270,6 +297,15 @@ def from_dict(cls, data: Dict[str, Any]) -> "LLMMetadataExtractor":
 
 init_parameters = data.get("init_parameters", {})
 
+# new deserialization with chat_generator
+if init_parameters.get("chat_generator") is not None:
+    chat_generator_class = import_class_by_name(init_parameters["chat_generator"]["type"])
+    assert hasattr(chat_generator_class, "from_dict")  # we know but mypy doesn't
+    chat_generator_instance = chat_generator_class.from_dict(init_parameters["chat_generator"])
+    data["init_parameters"]["chat_generator"] = chat_generator_instance
+    return default_from_dict(cls, data)
+
+# legacy deserialization
 if "generator_api" in init_parameters:
     data["init_parameters"]["generator_api"] = LLMProvider.from_str(data["init_parameters"]["generator_api"])
 if "generator_api_params" in init_parameters:
@@ -364,15 +400,15 @@ def _run_on_thread(self, prompt: Optional[ChatMessage]) -> Dict[str, Any]:
     return {"replies": ["{}"]}
 
 try:
-    result = self.llm_provider.run(messages=[prompt])
+    result = self._chat_generator.run(messages=[prompt])
 except Exception as e:
+    if self.raise_on_failure:
+        raise e
     logger.error(
         "LLM {class_name} execution failed. Skipping metadata extraction. Failed with exception '{error}'.",
-        class_name=self.llm_provider.__class__.__name__,
+        class_name=self._chat_generator.__class__.__name__,
         error=e,
     )
-    if self.raise_on_failure:
-        raise e
     result = {"error": "LLM failed with exception: " + str(e)}
 return result
 
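For reference (not part of the diff): a minimal usage sketch of the new `chat_generator` path next to the deprecated `generator_api` path it replaces. It assumes a Haystack install that already contains this change and an `OPENAI_API_KEY` in the environment; the prompt text is illustrative, while the class, parameter, and warning names are taken from the hunks above.

```python
# Sketch only: assumes haystack-ai with this change applied and OPENAI_API_KEY set.
import warnings

from haystack import Document
from haystack.components.extractors.llm_metadata_extractor import LLMMetadataExtractor
from haystack.components.generators.chat import OpenAIChatGenerator

# Illustrative prompt: the extractor expects a single `document` template variable.
PROMPT = "Extract the entities from this text as a JSON object under the key 'entities': {{ document.content }}"

# New style: pass a configured ChatGenerator instance directly.
extractor = LLMMetadataExtractor(
    prompt=PROMPT,
    chat_generator=OpenAIChatGenerator(generation_kwargs={"response_format": {"type": "json_object"}}),
    expected_keys=["entities"],
)

# Legacy style: still accepted, but now emits a DeprecationWarning per the __init__ change above.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    LLMMetadataExtractor(prompt=PROMPT, generator_api="openai")
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

extractor.warm_up()
result = extractor.run(documents=[Document(content="Hugging Face was founded in New York, USA.")])
```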
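And a small round-trip sketch for the new (de)serialization path, again assuming `OPENAI_API_KEY` is set: `to_dict` now embeds the chat generator's own dictionary, and `from_dict` re-imports its class via `import_class_by_name` before falling back to the legacy `generator_api` branch.

```python
# Sketch only: assumes OPENAI_API_KEY so OpenAIChatGenerator can be constructed.
from haystack.components.extractors.llm_metadata_extractor import LLMMetadataExtractor
from haystack.components.generators.chat import OpenAIChatGenerator

extractor = LLMMetadataExtractor(
    prompt="Extract entities from: {{ document.content }}",
    chat_generator=OpenAIChatGenerator(),
    expected_keys=["entities"],
)

data = extractor.to_dict()
# The serialized form carries the generator's import path, so from_dict can rebuild it.
assert data["init_parameters"]["chat_generator"]["type"].endswith("OpenAIChatGenerator")

restored = LLMMetadataExtractor.from_dict(data)
# _chat_generator is the private attribute introduced by this diff; inspected here only for illustration.
assert isinstance(restored._chat_generator, OpenAIChatGenerator)
```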