Skip to content

Commit 70e4cbc

Browse files
authored
switch openai to lazy import (#131)
1 parent 9bfa6e5 commit 70e4cbc

1 file changed

Lines changed: 5 additions & 4 deletions

File tree

angelslim/compressor/speculative/train/data/data_generation.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88
import tqdm
99
from datasets import load_dataset
10-
from openai import OpenAI
10+
11+
from angelslim.utils.lazy_imports import openai
1112

1213
from .data_utils import convert_sharegpt_data, convert_ultrachat_data
1314

@@ -36,7 +37,7 @@ def _initialize_clients(self, base_port: int, max_clients: int) -> None:
3637
"""Initialize available OpenAI clients."""
3738
for i in range(max_clients):
3839
base_url = f"http://localhost:{base_port + i}/v1"
39-
client = OpenAI(base_url=base_url, api_key="EMPTY")
40+
client = openai.OpenAI(base_url=base_url, api_key="EMPTY")
4041

4142
try:
4243
model_id = client.models.list().data[0].id
@@ -51,7 +52,7 @@ def _initialize_clients(self, base_port: int, max_clients: int) -> None:
5152

5253
logger.info(f"Initialized {len(self.clients)} clients")
5354

54-
def get_client(self, idx: int) -> OpenAI:
55+
def get_client(self, idx: int):
5556
"""Get a client using round-robin load balancing."""
5657
return self.clients[idx % len(self.clients)]
5758

@@ -101,7 +102,7 @@ def _convert_messages(
101102
return converted_messages, messages
102103

103104
def _generate_response(
104-
self, client: OpenAI, messages: List[Dict], **kwargs
105+
self, client, messages: List[Dict], **kwargs
105106
) -> Optional[str]:
106107
"""
107108
Generate a response using the OpenAI API.

0 commit comments

Comments
 (0)