 from openai import OpenAI
 
 try:
+  # VLLM logging config breaks beam logging.
+  os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
   import vllm  # pylint: disable=unused-import
   logging.info('vllm module successfully imported.')
 except ModuleNotFoundError:
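The variable is exported before the import because vLLM decides whether to take over logging configuration when the module is first loaded. A minimal standalone sketch of the intended ordering; the basicConfig call is only a stand-in for Beam's own logging setup:

import logging
import os

# Must be set before importing vllm so it is already visible when vLLM
# initializes its loggers.
os.environ["VLLM_CONFIGURE_LOGGING"] = "0"

logging.basicConfig(level=logging.INFO)  # stand-in for Beam's logging setup

try:
  import vllm  # pylint: disable=unused-import
except ModuleNotFoundError:
  vllm = None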
@@ -127,7 +129,9 @@ def start_server(self, retries=3):
     ]
     for k, v in self._vllm_server_kwargs.items():
       server_cmd.append(f'--{k}')
-      server_cmd.append(v)
+      # Only add values for flags that take a value.
+      if v is not None:
+        server_cmd.append(v)
     self._server_process, self._server_port = start_process(server_cmd)
 
     self.check_connectivity(retries)
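For illustration, a self-contained sketch of how the kwargs map onto CLI arguments after this change; the helper name and the example flags are hypothetical, not part of the handler:

def build_server_cmd(base_cmd, vllm_server_kwargs):
  """Turns server kwargs into CLI arguments, skipping values for bare flags."""
  cmd = list(base_cmd)
  for k, v in vllm_server_kwargs.items():
    cmd.append(f'--{k}')
    # Valueless (boolean) flags map to None and contribute only '--<flag>'.
    if v is not None:
      cmd.append(v)
  return cmd

# build_server_cmd(['vllm', 'serve', 'facebook/opt-125m'],
#                  {'dtype': 'float16', 'trust-remote-code': None})
# -> ['vllm', 'serve', 'facebook/opt-125m',
#     '--dtype', 'float16', '--trust-remote-code']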
@@ -138,27 +142,27 @@ def get_server_port(self) -> int:
     return self._server_port
 
   def check_connectivity(self, retries=3):
-    client = getVLLMClient(self._server_port)
-    while self._server_process.poll() is None:
-      try:
-        models = client.models.list().data
-        logging.info('models: %s' % models)
-        if len(models) > 0:
-          self._server_started = True
-          return
-      except:  # pylint: disable=bare-except
-        pass
-      # Sleep while bringing up the process
-      time.sleep(5)
-
-    if retries == 0:
-      self._server_started = False
-      raise Exception(
-          "Failed to start vLLM server, polling process exited with code " +
-          "%s. Next time a request is tried, the server will be restarted" %
-          self._server_process.poll())
-    else:
-      self.start_server(retries - 1)
+    with getVLLMClient(self._server_port) as client:
+      while self._server_process.poll() is None:
+        try:
+          models = client.models.list().data
+          logging.info('models: %s' % models)
+          if len(models) > 0:
+            self._server_started = True
+            return
+        except:  # pylint: disable=bare-except
+          pass
+        # Sleep while bringing up the process
+        time.sleep(5)
+
+      if retries == 0:
+        self._server_started = False
+        raise Exception(
+            "Failed to start vLLM server, polling process exited with code " +
+            "%s. Next time a request is tried, the server will be restarted" %
+            self._server_process.poll())
+      else:
+        self.start_server(retries - 1)
 
 
 class VLLMCompletionsModelHandler(ModelHandler[str,
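The with-statement works because the OpenAI Python client can be used as a context manager, which closes its HTTP connection pool on exit. A rough sketch of what getVLLMClient is assumed to return here, with a placeholder port and API key:

from openai import OpenAI

def getVLLMClient(port):
  # Assumed shape: a client pointed at the locally launched vLLM server.
  return OpenAI(base_url=f'http://localhost:{port}/v1', api_key='unused')

with getVLLMClient(8000) as client:
  # An empty model list (or a connection error) means the server is not
  # ready yet; the connectivity check above keeps polling until it is.
  models = client.models.list().data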
@@ -200,27 +204,21 @@ async def _async_run_inference(
       model: _VLLMModelServer,
       inference_args: Optional[dict[str, Any]] = None
   ) -> Iterable[PredictionResult]:
-    client = getAsyncVLLMClient(model.get_server_port())
     inference_args = inference_args or {}
-    async_predictions = []
-    for prompt in batch:
-      try:
-        completion = client.completions.create(
-            model=self._model_name, prompt=prompt, **inference_args)
-        async_predictions.append(completion)
-      except Exception as e:
-        model.check_connectivity()
-        raise e
 
-    predictions = []
-    for p in async_predictions:
+    async with getAsyncVLLMClient(model.get_server_port()) as client:
       try:
-        predictions.append(await p)
+        async_predictions = [
+            client.completions.create(
+                model=self._model_name, prompt=prompt, **inference_args)
+            for prompt in batch
+        ]
+        responses = await asyncio.gather(*async_predictions)
       except Exception as e:
         model.check_connectivity()
         raise e
 
-    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+    return [PredictionResult(x, y) for x, y in zip(batch, responses)]
 
   def run_inference(
       self,
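With an async OpenAI client, each completions.create call builds an awaitable; collecting them first and awaiting them together with asyncio.gather lets the requests run concurrently while keeping results in batch order. A self-contained sketch of the same pattern, with placeholder host, model name, and arguments:

import asyncio
from openai import AsyncOpenAI

async def fan_out_completions(prompts, port, model_name):
  async with AsyncOpenAI(
      base_url=f'http://localhost:{port}/v1', api_key='unused') as client:
    # Each create() call only builds a coroutine here; gather() awaits them
    # concurrently and returns responses in the same order as prompts.
    pending = [
        client.completions.create(model=model_name, prompt=p, max_tokens=32)
        for p in prompts
    ]
    return await asyncio.gather(*pending)

# responses = asyncio.run(
#     fan_out_completions(['Hello'], 8000, 'facebook/opt-125m'))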
@@ -301,25 +299,19 @@ async def _async_run_inference(
       model: _VLLMModelServer,
       inference_args: Optional[dict[str, Any]] = None
   ) -> Iterable[PredictionResult]:
-    client = getAsyncVLLMClient(model.get_server_port())
     inference_args = inference_args or {}
-    async_predictions = []
-    for messages in batch:
-      formatted = []
-      for message in messages:
-        formatted.append({"role": message.role, "content": message.content})
-      try:
-        completion = client.chat.completions.create(
-            model=self._model_name, messages=formatted, **inference_args)
-        async_predictions.append(completion)
-      except Exception as e:
-        model.check_connectivity()
-        raise e
 
-    predictions = []
-    for p in async_predictions:
+    async with getAsyncVLLMClient(model.get_server_port()) as client:
       try:
-        predictions.append(await p)
+        async_predictions = [
+            client.chat.completions.create(
+                model=self._model_name,
+                messages=[{
+                    "role": message.role, "content": message.content
+                } for message in messages],
+                **inference_args) for messages in batch
+        ]
+        predictions = await asyncio.gather(*async_predictions)
       except Exception as e:
         model.check_connectivity()
         raise e