Skip to content

Commit a130a0c

Browse files
MaxwellJryao and Super User authored
fix(charxiv): lazy-init OpenAI client and make model version configurable (#1252)
- Replace module-level OpenAI client with lazy _get_client() to avoid SSLContext pickling errors in datasets.map() multiprocessing
- Remove num_proc=1 from dataset.map() (no longer needed)
- Add model parameter to get_reasoning_result_gpt() so MODEL_VERSION env var is respected for both descriptive and reasoning grading

Co-authored-by: Super User <root@TENCENT64.site>
1 parent 540724a commit a130a0c

2 files changed

Lines changed: 18 additions & 6 deletions

File tree

lmms_eval/tasks/charxiv/reasoning_utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
)
1010

1111

12-
def get_reasoning_result_gpt(client, prompt, max_retries=10):
12+
def get_reasoning_result_gpt(client, prompt, model="gpt-4o-2024-05-13", max_retries=10):
1313
curr_retries = 0
1414
max_tokens = 256
1515
while curr_retries < max_retries:
@@ -22,7 +22,7 @@ def get_reasoning_result_gpt(client, prompt, max_retries=10):
2222
"content": prompt,
2323
}
2424
],
25-
model="gpt-4o-2024-05-13",
25+
model=model,
2626
response_format={"type": "json_object"},
2727
n=1,
2828
max_tokens=max_tokens,

lmms_eval/tasks/charxiv/utils.py

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,17 @@
2222
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "YOUR_OPENAI_BASE_URL")
2323
MODEL_VERSION = os.getenv("MODEL_VERSION", "YOUR_MODEL_VERSION")
2424

25-
client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL)
25+
# Lazy-initialize the OpenAI client to avoid creating an SSLContext at import time.
26+
# An SSLContext cannot be pickled, which breaks dataset.map() multiprocessing even
27+
# with num_proc=1 on newer versions of the `datasets` library.
28+
_client = None
29+
30+
31+
def _get_client():
32+
global _client
33+
if _client is None:
34+
_client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL)
35+
return _client
2636

2737

2838
def charxiv_reasoning_doc_to_text_cot(doc, lmms_eval_specific_kwargs=None):
@@ -53,7 +63,9 @@ def _process_row(example, indice):
5363
example["descriptive_a"] = example[f"descriptive_a{q_number}"]
5464
return {"qid": qid, **example}
5565

56-
dataset = dataset.map(_process_row, with_indices=True, num_proc=1)
66+
# Use num_proc=None (single-process, no pickling) to avoid SSLContext
67+
# serialization errors from the lazy OpenAI client in module globals.
68+
dataset = dataset.map(_process_row, with_indices=True)
5769
return dataset
5870

5971

@@ -99,7 +111,7 @@ def charxiv_descriptive_aggregate_results(results):
99111
queries = build_descriptive_grading_queries(groups)
100112
combined_queries = []
101113
for query in tqdm(queries):
102-
result = get_descriptive_result_gpt(client, query["grading_query"], len(query["resp_keys"]), model=MODEL_VERSION)
114+
result = get_descriptive_result_gpt(_get_client(), query["grading_query"], len(query["resp_keys"]), model=MODEL_VERSION)
103115
# query contains resp_keys, grading_query, extract_answer and score
104116
combined_queries.append({**query, **result})
105117
queries = combined_queries
@@ -131,7 +143,7 @@ def charxiv_reasoning_aggregate_results(results):
131143
resps[result["resp_key"]] = result["resp_value"]
132144
queries = build_reasoning_grading_queries(data, resps)
133145
for figure_id, query in tqdm(queries.items()):
134-
ext, scr = get_reasoning_result_gpt(client, query["grading_query"])
146+
ext, scr = get_reasoning_result_gpt(_get_client(), query["grading_query"], model=MODEL_VERSION)
135147
queries[figure_id]["extracted_answer"] = ext
136148
queries[figure_id]["score"] = scr
137149
queries[figure_id].pop("grading_query")

0 commit comments

Comments
 (0)