Skip to content

Commit b0f1ea5

Browse files
fix: fix list models (#46)
1 parent ac199b2 commit b0f1ea5

2 files changed

Lines changed: 17 additions & 4 deletions

File tree

serving_utils/client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,11 @@ def parse_predict_response(response):
180180
return results
181181

182182
def list_models(self):
183-
stub = list_models_pb2_grpc.ListModelsStub(self._channel)
183+
try:
184+
_, conn = next(iter(self._pool))
185+
except StopIteration:
186+
raise EmptyPool("no connections")
187+
stub = list_models_pb2_grpc.ListModelsStub(conn.sync_channel)
184188
response = stub.ListModels(list_models_pb2.ListModelsRequest())
185189
return response.models
186190

tests/test_integration.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import time
44
import pytest
55

6+
import grpc
67
import numpy as np
78
from serving_utils import Client, PredictInput
89

@@ -32,21 +33,29 @@ async def test_client():
3233
continue
3334
else:
3435
break
35-
# test client predict correct result
36-
# fake model is generated from `train_for_test.py`
36+
3737
clients = []
3838
for serving_port in test_serving_ports:
3939
clients.append(Client(
4040
host="localhost",
4141
port=serving_port,
4242
))
4343

44+
# fake model is generated from `train_for_test.py`
45+
model_name = 'test_model'
46+
47+
# test client list_models
48+
for client in clients:
49+
with pytest.raises(grpc.RpcError) as e:
50+
client.list_models()
51+
assert e.code() == grpc.StatusCode.UNIMPLEMENTED
52+
53+
# test client predict correct result
4454
req_data = [
4555
PredictInput(name='a', value=np.int16(2)),
4656
PredictInput(name='b', value=np.int16(3)),
4757
]
4858
output_names = ['c']
49-
model_name = 'test_model'
5059
expected_output = {'c': 8} # c = a + 2 * b
5160
for client in clients:
5261
actual_output = client.predict(

0 commit comments

Comments
 (0)