|
22 | 22 | OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "YOUR_OPENAI_BASE_URL") |
23 | 23 | MODEL_VERSION = os.getenv("MODEL_VERSION", "YOUR_MODEL_VERSION") |
24 | 24 |
|
25 | | -client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL) |
| 25 | +# Lazy-initialize the OpenAI client to avoid creating an SSLContext at import time. |
| 26 | +# An SSLContext cannot be pickled, which breaks dataset.map() multiprocessing even |
| 27 | +# with num_proc=1 on newer versions of the `datasets` library. |
| 28 | +_client = None |
| 29 | + |
| 30 | + |
| 31 | +def _get_client(): |
| 32 | + global _client |
| 33 | + if _client is None: |
| 34 | + _client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL) |
| 35 | + return _client |
26 | 36 |
|
27 | 37 |
|
28 | 38 | def charxiv_reasoning_doc_to_text_cot(doc, lmms_eval_specific_kwargs=None): |
@@ -53,7 +63,9 @@ def _process_row(example, indice): |
53 | 63 | example["descriptive_a"] = example[f"descriptive_a{q_number}"] |
54 | 64 | return {"qid": qid, **example} |
55 | 65 |
|
56 | | - dataset = dataset.map(_process_row, with_indices=True, num_proc=1) |
| 66 | + # Leave num_proc unset (it defaults to None: single-process, no pickling) to |
| 67 | + # avoid SSLContext serialization errors from the OpenAI client in module globals. |
| 68 | + dataset = dataset.map(_process_row, with_indices=True) |
57 | 69 | return dataset |
58 | 70 |
|
59 | 71 |
|
@@ -99,7 +111,7 @@ def charxiv_descriptive_aggregate_results(results): |
99 | 111 | queries = build_descriptive_grading_queries(groups) |
100 | 112 | combined_queries = [] |
101 | 113 | for query in tqdm(queries): |
102 | | - result = get_descriptive_result_gpt(client, query["grading_query"], len(query["resp_keys"]), model=MODEL_VERSION) |
| 114 | + result = get_descriptive_result_gpt(_get_client(), query["grading_query"], len(query["resp_keys"]), model=MODEL_VERSION) |
103 | 115 | # query contains resp_keys, grading_query, extract_answer and score |
104 | 116 | combined_queries.append({**query, **result}) |
105 | 117 | queries = combined_queries |
@@ -131,7 +143,7 @@ def charxiv_reasoning_aggregate_results(results): |
131 | 143 | resps[result["resp_key"]] = result["resp_value"] |
132 | 144 | queries = build_reasoning_grading_queries(data, resps) |
133 | 145 | for figure_id, query in tqdm(queries.items()): |
134 | | - ext, scr = get_reasoning_result_gpt(client, query["grading_query"]) |
| 146 | + ext, scr = get_reasoning_result_gpt(_get_client(), query["grading_query"], model=MODEL_VERSION) |
135 | 147 | queries[figure_id]["extracted_answer"] = ext |
136 | 148 | queries[figure_id]["score"] = scr |
137 | 149 | queries[figure_id].pop("grading_query") |
|
0 commit comments