Skip to content

Commit a130e0c

Browse files
authored
[FE] Add Gemini support to parse PDF with schema
1 parent 9768786 commit a130e0c

1 file changed

Lines changed: 34 additions & 12 deletions

File tree

lexoid/core/parse_type/llm_parser.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,6 @@ def parse_llm_doc(path: str, **kwargs) -> List[Dict] | str:
7272

7373

7474
def parse_with_gemini(path: str, **kwargs) -> List[Dict] | str:
75-
logger.debug(f"Parsing with Gemini API and model {kwargs['model']}")
76-
api_key = os.environ.get("GOOGLE_API_KEY")
77-
if not api_key:
78-
raise ValueError("GOOGLE_API_KEY environment variable is not set")
79-
80-
url = f"https://generativelanguage.googleapis.com/v1beta/models/{kwargs['model']}:generateContent?key={api_key}"
81-
8275
# Check if the file is an image and convert to PDF if necessary
8376
mime_type, _ = mimetypes.guess_type(path)
8477
if mime_type and mime_type.startswith("image"):
@@ -90,6 +83,20 @@ def parse_with_gemini(path: str, **kwargs) -> List[Dict] | str:
9083
file_content = file.read()
9184
base64_file = base64.b64encode(file_content).decode("utf-8")
9285

86+
return parse_image_with_gemini(
87+
base64_file=base64_file, mime_type=mime_type, **kwargs
88+
)
89+
90+
91+
def parse_image_with_gemini(
92+
base64_file: str, mime_type: str = "image/png", **kwargs
93+
) -> List[Dict] | str:
94+
api_key = os.environ.get("GOOGLE_API_KEY")
95+
if not api_key:
96+
raise ValueError("GOOGLE_API_KEY environment variable is not set")
97+
98+
url = f"https://generativelanguage.googleapis.com/v1beta/models/{kwargs['model']}:generateContent?key={api_key}"
99+
93100
if "system_prompt" in kwargs:
94101
prompt = kwargs["system_prompt"]
95102
else:
@@ -129,24 +136,23 @@ def parse_with_gemini(path: str, **kwargs) -> List[Dict] | str:
129136
if "text" in part
130137
)
131138

132-
combined_text = ""
139+
combined_text = raw_text
133140
if "<output>" in raw_text:
134141
combined_text = raw_text.split("<output>")[-1].strip()
135-
if "</output>" in result:
136-
combined_text = result.split("</output>")[0].strip()
142+
if "</output>" in combined_text:
143+
combined_text = combined_text.split("</output>")[0].strip()
137144

138145
token_usage = result["usageMetadata"]
139146
input_tokens = token_usage.get("promptTokenCount", 0)
140147
output_tokens = token_usage.get("candidatesTokenCount", 0)
141148
total_tokens = input_tokens + output_tokens
142-
143149
return {
144150
"raw": combined_text.replace("<page-break>", "\n\n"),
145151
"segments": [
146152
{"metadata": {"page": kwargs.get("start", 0) + page_no}, "content": page}
147153
for page_no, page in enumerate(combined_text.split("<page-break>"), start=1)
148154
],
149-
"title": kwargs["title"],
155+
"title": kwargs.get("title", ""),
150156
"url": kwargs.get("url", ""),
151157
"parent_title": kwargs.get("parent_title", ""),
152158
"recursive_docs": [],
@@ -236,8 +242,24 @@ def create_response(
236242
base_url="https://api.fireworks.ai/inference/v1",
237243
api_key=os.environ["FIREWORKS_API_KEY"],
238244
),
245+
"gemini": lambda: None, # Gemini is handled separately
239246
}
240247
assert api in clients, f"Unsupported API: {api}"
248+
249+
if api == "gemini":
250+
image_url = image_url.split("data:image/png;base64,")[1]
251+
response = parse_image_with_gemini(
252+
base64_file=image_url,
253+
model=model,
254+
temperature=temperature,
255+
max_tokens=max_tokens,
256+
system_prompt=system_prompt,
257+
)
258+
return {
259+
"response": response["raw"],
260+
"usage": response["token_usage"],
261+
}
262+
241263
client = clients[api]()
242264

243265
# Prepare messages for the API call

0 commit comments

Comments
 (0)