Skip to content

Commit 2403be8

Browse files
Integrate LiteLLM for multi-provider LLM support (#168)
* Integrate litellm for multi-provider LLM support * recover the default config yaml * Use litellm.acompletion for native async support * fix tob * Rename llm_complete/allm_complete to llm_completion/llm_acompletion, remove unused llm_complete_stream * Pin litellm to version 1.82.0 * resolve comments * args from cli is used to overrides config.yaml * Fix get_page_tokens hardcoded model default Pass opt.model to get_page_tokens so tokenization respects the configured model instead of always using gpt-4o-2024-11-20. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Remove explicit openai dependency from requirements.txt openai is no longer directly imported; it comes in as a transitive dependency of litellm. Pinning it explicitly risks version conflicts. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Restore openai==1.101.0 pin in requirements.txt litellm==1.82.0 and openai-agents have conflicting openai version requirements, but openai==1.101.0 works at runtime for both. The pin is necessary to prevent litellm from pulling in openai>=2.x which would break openai-agents. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Remove explicit openai dependency from requirements.txt openai is not directly used; it comes in as a transitive dependency of litellm. No openai-agents in this branch so no pin needed. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix an litellm error log * resolve comments --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 4b4b20f commit 2403be8

File tree

5 files changed

+78
-104
lines changed

5 files changed

+78
-104
lines changed

pageindex/config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
model: "gpt-4o-2024-11-20"
2+
# model: "anthropic/claude-sonnet-4-6"
23
toc_check_page_num: 20
34
max_page_num_each_node: 10
45
max_token_num_each_node: 20000

pageindex/page_index.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ async def check_title_appearance(item, page_list, start_index=1, model=None):
3636
}}
3737
Directly return the final JSON structure. Do not output anything else."""
3838

39-
response = await ChatGPT_API_async(model=model, prompt=prompt)
39+
response = await llm_acompletion(model=model, prompt=prompt)
4040
response = extract_json(response)
4141
if 'answer' in response:
4242
answer = response['answer']
@@ -64,7 +64,7 @@ async def check_title_appearance_in_start(title, page_text, model=None, logger=N
6464
}}
6565
Directly return the final JSON structure. Do not output anything else."""
6666

67-
response = await ChatGPT_API_async(model=model, prompt=prompt)
67+
response = await llm_acompletion(model=model, prompt=prompt)
6868
response = extract_json(response)
6969
if logger:
7070
logger.info(f"Response: {response}")
@@ -116,7 +116,7 @@ def toc_detector_single_page(content, model=None):
116116
Directly return the final JSON structure. Do not output anything else.
117117
Please note: abstract,summary, notation list, figure list, table list, etc. are not table of contents."""
118118

119-
response = ChatGPT_API(model=model, prompt=prompt)
119+
response = llm_completion(model=model, prompt=prompt)
120120
# print('response', response)
121121
json_content = extract_json(response)
122122
return json_content['toc_detected']
@@ -135,7 +135,7 @@ def check_if_toc_extraction_is_complete(content, toc, model=None):
135135
Directly return the final JSON structure. Do not output anything else."""
136136

137137
prompt = prompt + '\n Document:\n' + content + '\n Table of contents:\n' + toc
138-
response = ChatGPT_API(model=model, prompt=prompt)
138+
response = llm_completion(model=model, prompt=prompt)
139139
json_content = extract_json(response)
140140
return json_content['completed']
141141

@@ -153,7 +153,7 @@ def check_if_toc_transformation_is_complete(content, toc, model=None):
153153
Directly return the final JSON structure. Do not output anything else."""
154154

155155
prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc
156-
response = ChatGPT_API(model=model, prompt=prompt)
156+
response = llm_completion(model=model, prompt=prompt)
157157
json_content = extract_json(response)
158158
return json_content['completed']
159159

@@ -165,7 +165,7 @@ def extract_toc_content(content, model=None):
165165
166166
Directly return the full table of contents content. Do not output anything else."""
167167

168-
response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
168+
response, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True)
169169

170170
if_complete = check_if_toc_transformation_is_complete(content, response, model)
171171
if if_complete == "yes" and finish_reason == "finished":
@@ -176,7 +176,7 @@ def extract_toc_content(content, model=None):
176176
{"role": "assistant", "content": response},
177177
]
178178
prompt = f"""please continue the generation of table of contents , directly output the remaining part of the structure"""
179-
new_response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt, chat_history=chat_history)
179+
new_response, finish_reason = llm_completion(model=model, prompt=prompt, chat_history=chat_history, return_finish_reason=True)
180180
response = response + new_response
181181
if_complete = check_if_toc_transformation_is_complete(content, response, model)
182182

@@ -193,7 +193,7 @@ def extract_toc_content(content, model=None):
193193
{"role": "assistant", "content": response},
194194
]
195195
prompt = f"""please continue the generation of table of contents , directly output the remaining part of the structure"""
196-
new_response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt, chat_history=chat_history)
196+
new_response, finish_reason = llm_completion(model=model, prompt=prompt, chat_history=chat_history, return_finish_reason=True)
197197
response = response + new_response
198198
if_complete = check_if_toc_transformation_is_complete(content, response, model)
199199

@@ -215,7 +215,7 @@ def detect_page_index(toc_content, model=None):
215215
}}
216216
Directly return the final JSON structure. Do not output anything else."""
217217

218-
response = ChatGPT_API(model=model, prompt=prompt)
218+
response = llm_completion(model=model, prompt=prompt)
219219
json_content = extract_json(response)
220220
return json_content['page_index_given_in_toc']
221221

@@ -264,7 +264,7 @@ def toc_index_extractor(toc, content, model=None):
264264
Directly return the final JSON structure. Do not output anything else."""
265265

266266
prompt = toc_extractor_prompt + '\nTable of contents:\n' + str(toc) + '\nDocument pages:\n' + content
267-
response = ChatGPT_API(model=model, prompt=prompt)
267+
response = llm_completion(model=model, prompt=prompt)
268268
json_content = extract_json(response)
269269
return json_content
270270

@@ -292,15 +292,20 @@ def toc_transformer(toc_content, model=None):
292292
Directly return the final JSON structure, do not output anything else. """
293293

294294
prompt = init_prompt + '\n Given table of contents\n:' + toc_content
295-
last_complete, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
295+
last_complete, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True)
296296
if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model)
297297
if if_complete == "yes" and finish_reason == "finished":
298298
last_complete = extract_json(last_complete)
299299
cleaned_response=convert_page_to_int(last_complete['table_of_contents'])
300300
return cleaned_response
301301

302302
last_complete = get_json_content(last_complete)
303+
attempt = 0
304+
max_attempts = 5
303305
while not (if_complete == "yes" and finish_reason == "finished"):
306+
attempt += 1
307+
if attempt > max_attempts:
308+
raise Exception('Failed to complete toc transformation after maximum retries')
304309
position = last_complete.rfind('}')
305310
if position != -1:
306311
last_complete = last_complete[:position+2]
@@ -316,7 +321,7 @@ def toc_transformer(toc_content, model=None):
316321
317322
Please continue the json structure, directly output the remaining part of the json structure."""
318323

319-
new_complete, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
324+
new_complete, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True)
320325

321326
if new_complete.startswith('```json'):
322327
new_complete = get_json_content(new_complete)
@@ -477,7 +482,7 @@ def add_page_number_to_toc(part, structure, model=None):
477482
Directly return the final JSON structure. Do not output anything else."""
478483

479484
prompt = fill_prompt_seq + f"\n\nCurrent Partial Document:\n{part}\n\nGiven Structure\n{json.dumps(structure, indent=2)}\n"
480-
current_json_raw = ChatGPT_API(model=model, prompt=prompt)
485+
current_json_raw = llm_completion(model=model, prompt=prompt)
481486
json_result = extract_json(current_json_raw)
482487

483488
for item in json_result:
@@ -499,7 +504,7 @@ def remove_first_physical_index_section(text):
499504
return text
500505

501506
### add verify completeness
502-
def generate_toc_continue(toc_content, part, model="gpt-4o-2024-11-20"):
507+
def generate_toc_continue(toc_content, part, model=None):
503508
print('start generate_toc_continue')
504509
prompt = """
505510
You are an expert in extracting hierarchical tree structure.
@@ -527,7 +532,7 @@ def generate_toc_continue(toc_content, part, model="gpt-4o-2024-11-20"):
527532
Directly return the additional part of the final JSON structure. Do not output anything else."""
528533

529534
prompt = prompt + '\nGiven text\n:' + part + '\nPrevious tree structure\n:' + json.dumps(toc_content, indent=2)
530-
response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
535+
response, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True)
531536
if finish_reason == 'finished':
532537
return extract_json(response)
533538
else:
@@ -561,7 +566,7 @@ def generate_toc_init(part, model=None):
561566
Directly return the final JSON structure. Do not output anything else."""
562567

563568
prompt = prompt + '\nGiven text\n:' + part
564-
response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
569+
response, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True)
565570

566571
if finish_reason == 'finished':
567572
return extract_json(response)
@@ -732,7 +737,7 @@ def check_toc(page_list, opt=None):
732737

733738

734739
################### fix incorrect toc #########################################################
735-
def single_toc_item_index_fixer(section_title, content, model="gpt-4o-2024-11-20"):
740+
async def single_toc_item_index_fixer(section_title, content, model=None):
736741
toc_extractor_prompt = """
737742
You are given a section title and several pages of a document, your job is to find the physical index of the start page of the section in the partial document.
738743
@@ -746,7 +751,7 @@ def single_toc_item_index_fixer(section_title, content, model="gpt-4o-2024-11-20
746751
Directly return the final JSON structure. Do not output anything else."""
747752

748753
prompt = toc_extractor_prompt + '\nSection Title:\n' + str(section_title) + '\nDocument pages:\n' + content
749-
response = ChatGPT_API(model=model, prompt=prompt)
754+
response = await llm_acompletion(model=model, prompt=prompt)
750755
json_content = extract_json(response)
751756
return convert_physical_index_to_int(json_content['physical_index'])
752757

@@ -815,7 +820,7 @@ async def process_and_check_item(incorrect_item):
815820
continue
816821
content_range = ''.join(page_contents)
817822

818-
physical_index_int = single_toc_item_index_fixer(incorrect_item['title'], content_range, model)
823+
physical_index_int = await single_toc_item_index_fixer(incorrect_item['title'], content_range, model)
819824

820825
# Check if the result is correct
821826
check_item = incorrect_item.copy()
@@ -1069,7 +1074,7 @@ def page_index_main(doc, opt=None):
10691074
raise ValueError("Unsupported input type. Expected a PDF file path or BytesIO object.")
10701075

10711076
print('Parsing PDF...')
1072-
page_list = get_page_tokens(doc)
1077+
page_list = get_page_tokens(doc, model=opt.model)
10731078

10741079
logger.info({'total_page_number': len(page_list)})
10751080
logger.info({'total_token': sum([page[1] for page in page_list])})

pageindex/utils.py

Lines changed: 31 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
import tiktoken
2-
import openai
1+
import litellm
32
import logging
43
import os
54
from datetime import datetime
@@ -17,95 +16,65 @@
1716
from pathlib import Path
1817
from types import SimpleNamespace as config
1918

20-
CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
19+
# Backward compatibility: support CHATGPT_API_KEY as alias for OPENAI_API_KEY
20+
if not os.getenv("OPENAI_API_KEY") and os.getenv("CHATGPT_API_KEY"):
21+
os.environ["OPENAI_API_KEY"] = os.getenv("CHATGPT_API_KEY")
22+
23+
litellm.drop_params = True
2124

2225
def count_tokens(text, model=None):
2326
if not text:
2427
return 0
25-
enc = tiktoken.encoding_for_model(model)
26-
tokens = enc.encode(text)
27-
return len(tokens)
28+
return litellm.token_counter(model=model, text=text)
29+
2830

29-
def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
31+
def llm_completion(model, prompt, chat_history=None, return_finish_reason=False):
3032
max_retries = 10
31-
client = openai.OpenAI(api_key=api_key)
33+
messages = list(chat_history) + [{"role": "user", "content": prompt}] if chat_history else [{"role": "user", "content": prompt}]
3234
for i in range(max_retries):
3335
try:
34-
if chat_history:
35-
messages = chat_history
36-
messages.append({"role": "user", "content": prompt})
37-
else:
38-
messages = [{"role": "user", "content": prompt}]
39-
40-
response = client.chat.completions.create(
36+
response = litellm.completion(
4137
model=model,
4238
messages=messages,
4339
temperature=0,
4440
)
45-
if response.choices[0].finish_reason == "length":
46-
return response.choices[0].message.content, "max_output_reached"
47-
else:
48-
return response.choices[0].message.content, "finished"
49-
41+
content = response.choices[0].message.content
42+
if return_finish_reason:
43+
finish_reason = "max_output_reached" if response.choices[0].finish_reason == "length" else "finished"
44+
return content, finish_reason
45+
return content
5046
except Exception as e:
5147
print('************* Retrying *************')
5248
logging.error(f"Error: {e}")
5349
if i < max_retries - 1:
54-
time.sleep(1) # Wait for 1秒 before retrying
50+
time.sleep(1)
5551
else:
5652
logging.error('Max retries reached for prompt: ' + prompt)
57-
return "", "error"
53+
if return_finish_reason:
54+
return "", "error"
55+
return ""
5856

5957

6058

61-
def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
59+
async def llm_acompletion(model, prompt):
6260
max_retries = 10
63-
client = openai.OpenAI(api_key=api_key)
61+
messages = [{"role": "user", "content": prompt}]
6462
for i in range(max_retries):
6563
try:
66-
if chat_history:
67-
messages = chat_history
68-
messages.append({"role": "user", "content": prompt})
69-
else:
70-
messages = [{"role": "user", "content": prompt}]
71-
72-
response = client.chat.completions.create(
64+
response = await litellm.acompletion(
7365
model=model,
7466
messages=messages,
7567
temperature=0,
7668
)
77-
7869
return response.choices[0].message.content
7970
except Exception as e:
8071
print('************* Retrying *************')
8172
logging.error(f"Error: {e}")
8273
if i < max_retries - 1:
83-
time.sleep(1) # Wait for 1秒 before retrying
84-
else:
85-
logging.error('Max retries reached for prompt: ' + prompt)
86-
return "Error"
87-
88-
89-
async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
90-
max_retries = 10
91-
messages = [{"role": "user", "content": prompt}]
92-
for i in range(max_retries):
93-
try:
94-
async with openai.AsyncOpenAI(api_key=api_key) as client:
95-
response = await client.chat.completions.create(
96-
model=model,
97-
messages=messages,
98-
temperature=0,
99-
)
100-
return response.choices[0].message.content
101-
except Exception as e:
102-
print('************* Retrying *************')
103-
logging.error(f"Error: {e}")
104-
if i < max_retries - 1:
105-
await asyncio.sleep(1) # Wait for 1s before retrying
74+
await asyncio.sleep(1)
10675
else:
10776
logging.error('Max retries reached for prompt: ' + prompt)
108-
return "Error"
77+
return ""
10978

11079

11180
def get_json_content(response):
@@ -410,15 +379,14 @@ def add_preface_if_needed(data):
410379

411380

412381

413-
def get_page_tokens(pdf_path, model="gpt-4o-2024-11-20", pdf_parser="PyPDF2"):
414-
enc = tiktoken.encoding_for_model(model)
382+
def get_page_tokens(pdf_path, model=None, pdf_parser="PyPDF2"):
415383
if pdf_parser == "PyPDF2":
416384
pdf_reader = PyPDF2.PdfReader(pdf_path)
417385
page_list = []
418386
for page_num in range(len(pdf_reader.pages)):
419387
page = pdf_reader.pages[page_num]
420388
page_text = page.extract_text()
421-
token_length = len(enc.encode(page_text))
389+
token_length = litellm.token_counter(model=model, text=page_text)
422390
page_list.append((page_text, token_length))
423391
return page_list
424392
elif pdf_parser == "PyMuPDF":
@@ -430,7 +398,7 @@ def get_page_tokens(pdf_path, model="gpt-4o-2024-11-20", pdf_parser="PyPDF2"):
430398
page_list = []
431399
for page in doc:
432400
page_text = page.get_text()
433-
token_length = len(enc.encode(page_text))
401+
token_length = litellm.token_counter(model=model, text=page_text)
434402
page_list.append((page_text, token_length))
435403
return page_list
436404
else:
@@ -533,7 +501,7 @@ def remove_structure_text(data):
533501
def check_token_limit(structure, limit=110000):
534502
list = structure_to_list(structure)
535503
for node in list:
536-
num_tokens = count_tokens(node['text'], model='gpt-4o')
504+
num_tokens = count_tokens(node['text'], model=None)
537505
if num_tokens > limit:
538506
print(f"Node ID: {node['node_id']} has {num_tokens} tokens")
539507
print("Start Index:", node['start_index'])
@@ -609,7 +577,7 @@ async def generate_node_summary(node, model=None):
609577
610578
Directly return the description, do not include any other text.
611579
"""
612-
response = await ChatGPT_API_async(model, prompt)
580+
response = await llm_acompletion(model, prompt)
613581
return response
614582

615583

@@ -654,7 +622,7 @@ def generate_doc_description(structure, model=None):
654622
655623
Directly return the description, do not include any other text.
656624
"""
657-
response = ChatGPT_API(model, prompt)
625+
response = llm_completion(model, prompt)
658626
return response
659627

660628

requirements.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
openai==1.101.0
1+
litellm==1.82.0
22
pymupdf==1.26.4
33
PyPDF2==3.0.1
44
python-dotenv==1.1.0
5-
tiktoken==0.11.0
65
pyyaml==6.0.2

0 commit comments

Comments
 (0)