|
1 | 1 | import json |
2 | 2 | import time |
3 | 3 | from concurrent import futures |
4 | | -from typing import Generator, Mapping |
| 4 | +from typing import AsyncGenerator, Generator, Mapping |
5 | 5 |
|
6 | 6 | import grpc |
7 | 7 | import pytest |
@@ -141,6 +141,16 @@ def weaviate_client( |
141 | 141 | client.close() |
142 | 142 |
|
143 | 143 |
|
| 144 | +@pytest.fixture(scope="function") |
| 145 | +async def weaviate_client_async( |
| 146 | + weaviate_mock: HTTPServer, start_grpc_server: grpc.Server |
| 147 | +) -> AsyncGenerator[weaviate.WeaviateAsyncClient, None]: |
| 148 | + client = weaviate.use_async_with_local(port=MOCK_PORT, host=MOCK_IP, grpc_port=MOCK_PORT_GRPC) |
| 149 | + await client.connect() |
| 150 | + yield client |
| 151 | + await client.close() |
| 152 | + |
| 153 | + |
144 | 154 | @pytest.fixture(scope="function") |
145 | 155 | def weaviate_timeouts_client( |
146 | 156 | weaviate_timeouts_mock: HTTPServer, start_grpc_server: grpc.Server |
@@ -368,3 +378,39 @@ def forbidden( |
368 | 378 | service = MockForbiddenWeaviateService() |
369 | 379 | weaviate_pb2_grpc.add_WeaviateServicer_to_server(service, start_grpc_server) |
370 | 380 | return weaviate_client.collections.use("ForbiddenCollection") |
| 381 | + |
| 382 | + |
| 383 | +class MockWeaviateService(weaviate_pb2_grpc.WeaviateServicer): |
| 384 | + def BatchStream( |
| 385 | + self, |
| 386 | + request_iterator: Generator[batch_pb2.BatchStreamRequest, None, None], |
| 387 | + context: grpc.ServicerContext, |
| 388 | + ) -> Generator[batch_pb2.BatchStreamReply, None, None]: |
| 389 | + while True: |
| 390 | + if context.is_active(): |
| 391 | + time.sleep(0.1) |
| 392 | + else: |
| 393 | + raise grpc.RpcError(grpc.StatusCode.DEADLINE_EXCEEDED, "Deadline exceeded") |
| 394 | + |
| 395 | + |
| 396 | +@pytest.fixture(scope="function") |
| 397 | +def stream_cancel( |
| 398 | + weaviate_client: weaviate.WeaviateClient, |
| 399 | + weaviate_mock: HTTPServer, |
| 400 | + start_grpc_server: grpc.Server, |
| 401 | +) -> Generator[weaviate.collections.Collection, None, None]: |
| 402 | + name = "StreamCancelCollection" |
| 403 | + weaviate_mock.expect_request(f"/v1/schema/{name}").respond_with_response( |
| 404 | + Response(status=404) |
| 405 | + ) # skips __create_batch_reset vectorizer logic |
| 406 | + weaviate_pb2_grpc.add_WeaviateServicer_to_server(MockWeaviateService(), start_grpc_server) |
| 407 | + client = weaviate.connect_to_local( |
| 408 | + port=MOCK_PORT, |
| 409 | + host=MOCK_IP, |
| 410 | + grpc_port=MOCK_PORT_GRPC, |
| 411 | + additional_config=weaviate.classes.init.AdditionalConfig( |
| 412 | + timeout=weaviate.classes.init.Timeout(insert=1) |
| 413 | + ), |
| 414 | + ) |
| 415 | + yield client.collections.use(name) |
| 416 | + client.close() |
0 commit comments