@@ -72,13 +72,6 @@ def parse_llm_doc(path: str, **kwargs) -> List[Dict] | str:
7272
7373
7474def parse_with_gemini (path : str , ** kwargs ) -> List [Dict ] | str :
75- logger .debug (f"Parsing with Gemini API and model { kwargs ['model' ]} " )
76- api_key = os .environ .get ("GOOGLE_API_KEY" )
77- if not api_key :
78- raise ValueError ("GOOGLE_API_KEY environment variable is not set" )
79-
80- url = f"https://generativelanguage.googleapis.com/v1beta/models/{ kwargs ['model' ]} :generateContent?key={ api_key } "
81-
8275 # Check if the file is an image and convert to PDF if necessary
8376 mime_type , _ = mimetypes .guess_type (path )
8477 if mime_type and mime_type .startswith ("image" ):
@@ -90,6 +83,20 @@ def parse_with_gemini(path: str, **kwargs) -> List[Dict] | str:
9083 file_content = file .read ()
9184 base64_file = base64 .b64encode (file_content ).decode ("utf-8" )
9285
86+ return parse_image_with_gemini (
87+ base64_file = base64_file , mime_type = mime_type , ** kwargs
88+ )
89+
90+
91+ def parse_image_with_gemini (
92+ base64_file : str , mime_type : str = "image/png" , ** kwargs
93+ ) -> List [Dict ] | str :
94+ api_key = os .environ .get ("GOOGLE_API_KEY" )
95+ if not api_key :
96+ raise ValueError ("GOOGLE_API_KEY environment variable is not set" )
97+
98+ url = f"https://generativelanguage.googleapis.com/v1beta/models/{ kwargs ['model' ]} :generateContent?key={ api_key } "
99+
93100 if "system_prompt" in kwargs :
94101 prompt = kwargs ["system_prompt" ]
95102 else :
@@ -129,24 +136,23 @@ def parse_with_gemini(path: str, **kwargs) -> List[Dict] | str:
129136 if "text" in part
130137 )
131138
132- combined_text = ""
139+ combined_text = raw_text
133140 if "<output>" in raw_text :
134141 combined_text = raw_text .split ("<output>" )[- 1 ].strip ()
135- if "</output>" in result :
136- combined_text = result .split ("</output>" )[0 ].strip ()
142+ if "</output>" in combined_text :
143+ combined_text = combined_text .split ("</output>" )[0 ].strip ()
137144
138145 token_usage = result ["usageMetadata" ]
139146 input_tokens = token_usage .get ("promptTokenCount" , 0 )
140147 output_tokens = token_usage .get ("candidatesTokenCount" , 0 )
141148 total_tokens = input_tokens + output_tokens
142-
143149 return {
144150 "raw" : combined_text .replace ("<page-break>" , "\n \n " ),
145151 "segments" : [
146152 {"metadata" : {"page" : kwargs .get ("start" , 0 ) + page_no }, "content" : page }
147153 for page_no , page in enumerate (combined_text .split ("<page-break>" ), start = 1 )
148154 ],
149- "title" : kwargs [ "title" ] ,
155+ "title" : kwargs . get ( "title" , "" ) ,
150156 "url" : kwargs .get ("url" , "" ),
151157 "parent_title" : kwargs .get ("parent_title" , "" ),
152158 "recursive_docs" : [],
@@ -236,8 +242,24 @@ def create_response(
236242 base_url = "https://api.fireworks.ai/inference/v1" ,
237243 api_key = os .environ ["FIREWORKS_API_KEY" ],
238244 ),
245+ "gemini" : lambda : None , # Gemini is handled separately
239246 }
240247 assert api in clients , f"Unsupported API: { api } "
248+
249+ if api == "gemini" :
250+ image_url = image_url .split ("data:image/png;base64," )[1 ]
251+ response = parse_image_with_gemini (
252+ base64_file = image_url ,
253+ model = model ,
254+ temperature = temperature ,
255+ max_tokens = max_tokens ,
256+ system_prompt = system_prompt ,
257+ )
258+ return {
259+ "response" : response ["raw" ],
260+ "usage" : response ["token_usage" ],
261+ }
262+
241263 client = clients [api ]()
242264
243265 # Prepare messages for the API call
0 commit comments