Skip to content

Commit 59655cc

Browse files
authored
Refactor API calling implementation
1 parent d431933 commit 59655cc

1 file changed

Lines changed: 109 additions & 63 deletions

File tree

lexoid/core/parse_type/llm_parser.py

Lines changed: 109 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,24 @@
33
import mimetypes
44
import os
55
import time
6+
from functools import wraps
7+
from typing import Dict, List, Optional
8+
69
import pypdfium2 as pdfium
710
import requests
8-
from functools import wraps
11+
from huggingface_hub import InferenceClient
12+
from loguru import logger
13+
from openai import OpenAI
914
from requests.exceptions import HTTPError
10-
from typing import Dict, List
15+
from together import Together
1116

1217
from lexoid.core.prompt_templates import (
1318
INSTRUCTIONS_ADD_PG_BREAK,
19+
LLAMA_PARSER_PROMPT,
1420
OPENAI_USER_PROMPT,
1521
PARSER_PROMPT,
16-
LLAMA_PARSER_PROMPT,
1722
)
1823
from lexoid.core.utils import convert_image_to_pdf
19-
from loguru import logger
20-
from openai import OpenAI
21-
from together import Together
22-
from huggingface_hub import InferenceClient
2324

2425

2526
def retry_on_http_error(func):
@@ -172,18 +173,54 @@ def convert_pdf_page_to_base64(
172173
return base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
173174

174175

175-
def parse_with_api(path: str, api: str, **kwargs) -> List[Dict] | str:
176-
"""
177-
Parse documents (PDFs or images) using various vision model APIs.
176+
def get_messages(
177+
system_prompt: Optional[str], user_prompt: Optional[str], image_url: Optional[str]
178+
) -> List[Dict]:
179+
messages = []
180+
if system_prompt:
181+
messages.append(
182+
{
183+
"role": "system",
184+
"content": system_prompt,
185+
}
186+
)
187+
base_message = (
188+
[
189+
{"type": "text", "text": user_prompt},
190+
]
191+
if user_prompt
192+
else []
193+
)
194+
image_message = (
195+
[
196+
{
197+
"type": "image_url",
198+
"image_url": {"url": image_url},
199+
}
200+
]
201+
if image_url
202+
else []
203+
)
178204

179-
Args:
180-
path (str): Path to the document to parse
181-
api (str): Which API to use ("openai", "huggingface", or "together")
182-
**kwargs: Additional arguments including model, temperature, title, etc.
205+
messages.append(
206+
{
207+
"role": "user",
208+
"content": base_message + image_message,
209+
}
210+
)
183211

184-
Returns:
185-
Dict: Dictionary containing parsed document data
186-
"""
212+
return messages
213+
214+
215+
def create_response(
216+
api: str,
217+
model: str,
218+
system_prompt: Optional[str] = None,
219+
user_prompt: Optional[str] = None,
220+
image_url: Optional[str] = None,
221+
temperature: float = 0.2,
222+
max_tokens: int = 1024,
223+
) -> Dict:
187224
# Initialize appropriate client
188225
clients = {
189226
"openai": lambda: OpenAI(),
@@ -201,9 +238,46 @@ def parse_with_api(path: str, api: str, **kwargs) -> List[Dict] | str:
201238
),
202239
}
203240
assert api in clients, f"Unsupported API: {api}"
204-
logger.debug(f"Parsing with {api} API and model {kwargs['model']}")
205241
client = clients[api]()
206242

243+
# Prepare messages for the API call
244+
messages = get_messages(system_prompt, user_prompt, image_url)
245+
246+
# Common completion parameters
247+
completion_params = {
248+
"model": model,
249+
"messages": messages,
250+
"max_tokens": max_tokens,
251+
"temperature": temperature,
252+
}
253+
254+
# Get completion from selected API
255+
response = client.chat.completions.create(**completion_params)
256+
token_usage = response.usage
257+
258+
# Extract the response text
259+
page_text = response.choices[0].message.content
260+
261+
return {
262+
"response": page_text,
263+
"usage": token_usage,
264+
}
265+
266+
267+
def parse_with_api(path: str, api: str, **kwargs) -> List[Dict] | str:
268+
"""
269+
Parse documents (PDFs or images) using various vision model APIs.
270+
271+
Args:
272+
path (str): Path to the document to parse
273+
api (str): Which API to use ("openai", "huggingface", or "together")
274+
**kwargs: Additional arguments including model, temperature, title, etc.
275+
276+
Returns:
277+
List[Dict] | str: Parsed document data (matches the declared return annotation)
278+
"""
279+
logger.debug(f"Parsing with {api} API and model {kwargs['model']}")
280+
207281
# Handle different input types
208282
mime_type, _ = mimetypes.guess_type(path)
209283
if mime_type and mime_type.startswith("image"):
@@ -222,60 +296,32 @@ def parse_with_api(path: str, api: str, **kwargs) -> List[Dict] | str:
222296
for page_num in range(len(pdf_document))
223297
]
224298

225-
# API-specific message formatting
226-
def get_messages(page_num: int, image_url: str) -> List[Dict]:
227-
image_message = {
228-
"type": "image_url",
229-
"image_url": {"url": image_url},
230-
}
231-
299+
# Process each page/image
300+
all_results = []
301+
for page_num, image_url in images:
232302
if api == "openai":
233303
system_prompt = kwargs.get(
234304
"system_prompt", PARSER_PROMPT.format(custom_instructions="")
235305
)
236306
user_prompt = kwargs.get("user_prompt", OPENAI_USER_PROMPT)
237-
return [
238-
{
239-
"role": "system",
240-
"content": system_prompt,
241-
},
242-
{
243-
"role": "user",
244-
"content": [
245-
{"type": "text", "text": user_prompt},
246-
image_message,
247-
],
248-
},
249-
]
250307
else:
251-
prompt = kwargs.get("system_prompt", LLAMA_PARSER_PROMPT)
252-
base_message = {"type": "text", "text": prompt}
253-
return [
254-
{
255-
"role": "user",
256-
"content": [base_message, image_message],
257-
}
258-
]
259-
260-
# Process each page/image
261-
all_results = []
262-
for page_num, image_url in images:
263-
messages = get_messages(page_num, image_url)
264-
265-
# Common completion parameters
266-
completion_params = {
267-
"model": kwargs["model"],
268-
"messages": messages,
269-
"max_tokens": kwargs.get("max_tokens", 1024),
270-
"temperature": kwargs.get("temperature", 0.2),
271-
}
308+
system_prompt = kwargs.get("system_prompt", None)
309+
user_prompt = kwargs.get("user_prompt", LLAMA_PARSER_PROMPT)
310+
311+
response = create_response(
312+
api=api,
313+
model=kwargs["model"],
314+
system_prompt=system_prompt,
315+
user_prompt=user_prompt,
316+
image_url=image_url,
317+
temperature=kwargs.get("temperature", 0.2),
318+
max_tokens=kwargs.get("max_tokens", 1024),
319+
)
272320

273321
# Get completion from selected API
274-
response = client.chat.completions.create(**completion_params)
275-
token_usage = response.usage
322+
page_text = response["response"]
323+
token_usage = response["usage"]
276324

277-
# Extract the response text
278-
page_text = response.choices[0].message.content
279325
if kwargs.get("verbose", None):
280326
logger.debug(f"Page {page_num + 1} response: {page_text}")
281327

0 commit comments

Comments
 (0)