
Commit caac723

claudevdm and Claude authored
Fix vllm logging, leaking connections and wait concurrently for futures. (#35053)
* Close openai client/connections. Fix logging. Wait async.
* Close sync client.
* Trigger postcommit.

Co-authored-by: Claude <cvandermerwe@google.com>
1 parent d166724 commit caac723

2 files changed

Lines changed: 45 additions & 53 deletions


Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to run.",
-  "modification": 11
+  "modification": 12
 }

sdks/python/apache_beam/ml/inference/vllm_inference.py

Lines changed: 44 additions & 52 deletions
@@ -39,6 +39,8 @@
 from openai import OpenAI

 try:
+  # VLLM logging config breaks beam logging.
+  os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
   import vllm  # pylint: disable=unused-import
   logging.info('vllm module successfully imported.')
 except ModuleNotFoundError:

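For context, a minimal standalone sketch of why the environment variable is set before the import (assuming vllm is installed; vllm applies its logging configuration when it is first imported, so setting the variable afterwards has no effect):

import logging
import os

# Application-level logging config (e.g. what the Beam worker sets up).
logging.basicConfig(level=logging.INFO)

# Must be set before the first `import vllm`, otherwise vllm installs its own
# logging configuration on top of the application's.
os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
import vllm  # pylint: disable=unused-import

logging.getLogger(__name__).info("vllm imported without reconfiguring logging")
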
@@ -127,7 +129,9 @@ def start_server(self, retries=3):
     ]
     for k, v in self._vllm_server_kwargs.items():
       server_cmd.append(f'--{k}')
-      server_cmd.append(v)
+      # Only add values for flags that have a value part.
+      if v is not None:
+        server_cmd.append(v)
     self._server_process, self._server_port = start_process(server_cmd)

     self.check_connectivity(retries)

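A short sketch of the flag-building pattern above, with illustrative kwargs and base command (not the handler's real defaults): a kwarg whose value is None is emitted as a bare flag, anything else as a `--flag value` pair.

def build_server_cmd(base_cmd, server_kwargs):
  cmd = list(base_cmd)
  for k, v in server_kwargs.items():
    cmd.append(f'--{k}')
    if v is not None:
      # Value-less flags (e.g. a hypothetical '--trust-remote-code') stay bare.
      cmd.append(str(v))
  return cmd

# build_server_cmd(['vllm', 'serve', 'facebook/opt-125m'],
#                  {'dtype': 'float16', 'trust-remote-code': None})
# -> ['vllm', 'serve', 'facebook/opt-125m',
#     '--dtype', 'float16', '--trust-remote-code']
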
@@ -138,27 +142,27 @@ def get_server_port(self) -> int:
     return self._server_port

   def check_connectivity(self, retries=3):
-    client = getVLLMClient(self._server_port)
-    while self._server_process.poll() is None:
-      try:
-        models = client.models.list().data
-        logging.info('models: %s' % models)
-        if len(models) > 0:
-          self._server_started = True
-          return
-      except:  # pylint: disable=bare-except
-        pass
-      # Sleep while bringing up the process
-      time.sleep(5)
-
-    if retries == 0:
-      self._server_started = False
-      raise Exception(
-          "Failed to start vLLM server, polling process exited with code " +
-          "%s. Next time a request is tried, the server will be restarted" %
-          self._server_process.poll())
-    else:
-      self.start_server(retries - 1)
+    with getVLLMClient(self._server_port) as client:
+      while self._server_process.poll() is None:
+        try:
+          models = client.models.list().data
+          logging.info('models: %s' % models)
+          if len(models) > 0:
+            self._server_started = True
+            return
+        except:  # pylint: disable=bare-except
+          pass
+        # Sleep while bringing up the process
+        time.sleep(5)
+
+      if retries == 0:
+        self._server_started = False
+        raise Exception(
+            "Failed to start vLLM server, polling process exited with code " +
+            "%s. Next time a request is tried, the server will be restarted" %
+            self._server_process.poll())
+      else:
+        self.start_server(retries - 1)


 class VLLMCompletionsModelHandler(ModelHandler[str,

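A sketch of the connection-leak fix above, assuming getVLLMClient returns a standard openai.OpenAI client pointed at the local vLLM server (the URL and api_key below are placeholders). Using the client as a context manager closes its underlying HTTP connection pool when the block exits, so repeated health checks no longer leave sockets open:

from openai import OpenAI

def probe_models(port):
  with OpenAI(base_url=f'http://localhost:{port}/v1', api_key='unused') as client:
    # Connections are released when the `with` block exits, even on error.
    return client.models.list().data
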
@@ -200,27 +204,21 @@ async def _async_run_inference(
       model: _VLLMModelServer,
       inference_args: Optional[dict[str, Any]] = None
   ) -> Iterable[PredictionResult]:
-    client = getAsyncVLLMClient(model.get_server_port())
     inference_args = inference_args or {}
-    async_predictions = []
-    for prompt in batch:
-      try:
-        completion = client.completions.create(
-            model=self._model_name, prompt=prompt, **inference_args)
-        async_predictions.append(completion)
-      except Exception as e:
-        model.check_connectivity()
-        raise e

-    predictions = []
-    for p in async_predictions:
+    async with getAsyncVLLMClient(model.get_server_port()) as client:
       try:
-        predictions.append(await p)
+        async_predictions = [
+            client.completions.create(
+                model=self._model_name, prompt=prompt, **inference_args)
+            for prompt in batch
+        ]
+        responses = await asyncio.gather(*async_predictions)
       except Exception as e:
         model.check_connectivity()
         raise e

-    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+    return [PredictionResult(x, y) for x, y in zip(batch, responses)]

   def run_inference(
       self,

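A sketch of the concurrency change above, assuming an openai.AsyncOpenAI client against the local server (URL, api_key, and model name are placeholders). The async client's create() calls return awaitables; building them all first and handing them to asyncio.gather lets the whole batch of requests run concurrently instead of being awaited one at a time, and the async with block closes the client afterwards:

import asyncio

from openai import AsyncOpenAI

async def complete_batch(port, model, prompts):
  async with AsyncOpenAI(
      base_url=f'http://localhost:{port}/v1', api_key='unused') as client:
    pending = [
        client.completions.create(model=model, prompt=p) for p in prompts
    ]
    # gather schedules all requests and waits for every response.
    return await asyncio.gather(*pending)

# responses = asyncio.run(
#     complete_batch(8000, 'facebook/opt-125m', ['Hello', 'Hi']))
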
@@ -301,25 +299,19 @@ async def _async_run_inference(
       model: _VLLMModelServer,
       inference_args: Optional[dict[str, Any]] = None
   ) -> Iterable[PredictionResult]:
-    client = getAsyncVLLMClient(model.get_server_port())
     inference_args = inference_args or {}
-    async_predictions = []
-    for messages in batch:
-      formatted = []
-      for message in messages:
-        formatted.append({"role": message.role, "content": message.content})
-      try:
-        completion = client.chat.completions.create(
-            model=self._model_name, messages=formatted, **inference_args)
-        async_predictions.append(completion)
-      except Exception as e:
-        model.check_connectivity()
-        raise e

-    predictions = []
-    for p in async_predictions:
+    async with getAsyncVLLMClient(model.get_server_port()) as client:
       try:
-        predictions.append(await p)
+        async_predictions = [
+            client.chat.completions.create(
+                model=self._model_name,
+                messages=[{
+                    "role": message.role, "content": message.content
+                } for message in messages],
+                **inference_args) for messages in batch
+        ]
+        predictions = await asyncio.gather(*async_predictions)
       except Exception as e:
         model.check_connectivity()
         raise e

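For completeness, a small sketch of the message formatting that is now inlined above. Each element of batch is a sequence of message objects exposing .role and .content (the handler's chat message type); the comprehension turns them into the plain dicts the OpenAI chat API expects. SimpleNamespace stands in for that message type here:

from types import SimpleNamespace

messages = [
    SimpleNamespace(role='system', content='You are a terse assistant.'),
    SimpleNamespace(role='user', content='Name a Beam runner.'),
]
formatted = [{'role': m.role, 'content': m.content} for m in messages]
# formatted == [{'role': 'system', 'content': 'You are a terse assistant.'},
#               {'role': 'user', 'content': 'Name a Beam runner.'}]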
