|
| 1 | +"""The grpc-web channel and multicallables. |
| 2 | +
|
| 3 | +:class:`GrpcWebChannel` implements the small slice of the ``grpc.aio`` channel interface |
| 4 | +that ``weaviate``'s generated stub and ``ConnectionV4`` actually use — ``unary_unary``, |
| 5 | +``stream_stream`` and ``close`` — by framing requests as grpc-web and POSTing them via a |
| 6 | +pluggable async sender. It subclasses the shim's ``grpc.aio.Channel`` (:class:`AioChannel`) |
| 7 | +so the ``isinstance(..., grpc.aio.Channel)`` assertions in ``connect/v4.py`` hold. |
| 8 | +
|
| 9 | +Only unary RPCs are supported (Search, Aggregate, TenantsGet, BatchObjects, |
| 10 | +BatchReferences, BatchDelete, and the unary health check). ``stream_stream`` (the bidi |
| 11 | +``BatchStream`` used by opt-in server-side batching) cannot work over grpc-web/fetch and |
| 12 | +raises a clear error. |
| 13 | +""" |
| 14 | + |
| 15 | +import base64 |
| 16 | +import urllib.parse |
| 17 | +from typing import Any, Callable, Dict, Optional |
| 18 | + |
| 19 | +from ._framing import encode_message, split_response |
| 20 | +from ._sender import Sender, pyfetch_sender |
| 21 | +from ._shim import AioChannel, AioRpcError, StatusCode, status_from_int |
| 22 | + |
| 23 | +# Module-level default sender; overridable for tests / non-browser runtimes. |
| 24 | +_default_sender: Sender = pyfetch_sender |
| 25 | + |
| 26 | + |
| 27 | +def set_sender(sender: Sender) -> None: |
| 28 | + """Override the default async sender used by new channels (tests/integration).""" |
| 29 | + global _default_sender |
| 30 | + _default_sender = sender |
| 31 | + |
| 32 | + |
| 33 | +def get_sender() -> Sender: |
| 34 | + return _default_sender |
| 35 | + |
| 36 | + |
| 37 | +def _encode_timeout(seconds: float) -> str: |
| 38 | + """Encode a timeout as a grpc-timeout header value (``<positive int><unit>``).""" |
| 39 | + millis = max(1, int(seconds * 1000)) |
| 40 | + if millis < 100_000_000: |
| 41 | + return f"{millis}m" |
| 42 | + return f"{max(1, int(seconds))}S" |
| 43 | + |
| 44 | + |
| 45 | +def _fold_metadata(headers: Dict[str, str], metadata: Any) -> None: |
| 46 | + """Fold gRPC call metadata (``[(key, value), ...]``) into fetch headers. |
| 47 | +
|
| 48 | + Binary ``-bin`` keys are base64-encoded as grpc-web requires. |
| 49 | + """ |
| 50 | + if not metadata: |
| 51 | + return |
| 52 | + for key, value in metadata: |
| 53 | + name = key.lower() |
| 54 | + if name.endswith("-bin"): |
| 55 | + raw = value if isinstance(value, (bytes, bytearray)) else str(value).encode() |
| 56 | + headers[name] = base64.b64encode(raw).decode("ascii") |
| 57 | + else: |
| 58 | + headers[name] = value if isinstance(value, str) else str(value) |
| 59 | + |
| 60 | + |
| 61 | +def _header_lookup(headers: Dict[str, str], name: str) -> Optional[str]: |
| 62 | + target = name.lower() |
| 63 | + for key, value in headers.items(): |
| 64 | + if key.lower() == target: |
| 65 | + return value |
| 66 | + return None |
| 67 | + |
| 68 | + |
| 69 | +class _UnaryUnaryMultiCallable: |
| 70 | + """Awaitable multicallable bound by ``WeaviateStub.__init__``. |
| 71 | +
|
| 72 | + Called as ``await mc(request, metadata=..., timeout=...)`` (and, for the health |
| 73 | + check, as ``mc(request, timeout=...)`` with no metadata). |
| 74 | + """ |
| 75 | + |
| 76 | + def __init__( |
| 77 | + self, |
| 78 | + channel: "GrpcWebChannel", |
| 79 | + path: str, |
| 80 | + request_serializer: Callable[[Any], bytes], |
| 81 | + response_deserializer: Callable[[bytes], Any], |
| 82 | + ) -> None: |
| 83 | + self._channel = channel |
| 84 | + self._path = path |
| 85 | + self._serialize = request_serializer |
| 86 | + self._deserialize = response_deserializer |
| 87 | + |
| 88 | + async def __call__( |
| 89 | + self, |
| 90 | + request: Any, |
| 91 | + *, |
| 92 | + metadata: Any = None, |
| 93 | + timeout: Optional[float] = None, |
| 94 | + credentials: Any = None, |
| 95 | + wait_for_ready: Any = None, |
| 96 | + compression: Any = None, |
| 97 | + ) -> Any: |
| 98 | + payload = self._serialize(request) |
| 99 | + return await self._channel._unary(self._path, payload, self._deserialize, metadata, timeout) |
| 100 | + |
| 101 | + |
| 102 | +class _UnsupportedStreamMultiCallable: |
| 103 | + """Placeholder for ``stream_stream`` (bidirectional streaming). |
| 104 | +
|
| 105 | + Calling it raises immediately, before the ``async for`` in ``connect/v4.py:1243`` |
| 106 | + begins iterating. |
| 107 | + """ |
| 108 | + |
| 109 | + def __init__(self, path: str) -> None: |
| 110 | + self._path = path |
| 111 | + |
| 112 | + def __call__(self, *args: Any, **kwargs: Any) -> Any: |
| 113 | + raise RuntimeError( |
| 114 | + f"Bidirectional streaming RPC {self._path!r} (server-side batching / " |
| 115 | + "BatchStream) is not supported over grpc-web/fetch. Use insert_many(), or " |
| 116 | + "batch.dynamic() / fixed_size() / rate_limit(), instead of batch.stream()." |
| 117 | + ) |
| 118 | + |
| 119 | + |
| 120 | +class GrpcWebChannel(AioChannel): |
| 121 | + """grpc-web/fetch implementation of the async grpc channel slice the client uses.""" |
| 122 | + |
| 123 | + def __init__( |
| 124 | + self, |
| 125 | + target: Optional[str], |
| 126 | + secure: bool, |
| 127 | + options: Any = None, |
| 128 | + sender: Optional[Sender] = None, |
| 129 | + ) -> None: |
| 130 | + if not target: |
| 131 | + raise ValueError("GrpcWebChannel requires a target (host:port)") |
| 132 | + scheme = "https" if secure else "http" |
| 133 | + self._base_url = f"{scheme}://{target}" |
| 134 | + self._sender: Sender = sender or get_sender() |
| 135 | + |
| 136 | + def unary_unary( |
| 137 | + self, |
| 138 | + method: str, |
| 139 | + request_serializer: Callable[[Any], bytes], |
| 140 | + response_deserializer: Callable[[bytes], Any], |
| 141 | + _registered_method: bool = False, |
| 142 | + ) -> _UnaryUnaryMultiCallable: |
| 143 | + return _UnaryUnaryMultiCallable(self, method, request_serializer, response_deserializer) |
| 144 | + |
| 145 | + def stream_stream( |
| 146 | + self, |
| 147 | + method: str, |
| 148 | + request_serializer: Callable[[Any], bytes], |
| 149 | + response_deserializer: Callable[[bytes], Any], |
| 150 | + _registered_method: bool = False, |
| 151 | + ) -> _UnsupportedStreamMultiCallable: |
| 152 | + return _UnsupportedStreamMultiCallable(method) |
| 153 | + |
| 154 | + async def close(self, grace: Optional[float] = None) -> None: |
| 155 | + # Nothing to tear down: each call is an independent fetch. |
| 156 | + return None |
| 157 | + |
| 158 | + async def _unary( |
| 159 | + self, |
| 160 | + path: str, |
| 161 | + payload: bytes, |
| 162 | + deserialize: Callable[[bytes], Any], |
| 163 | + metadata: Any, |
| 164 | + timeout: Optional[float], |
| 165 | + ) -> Any: |
| 166 | + headers: Dict[str, str] = { |
| 167 | + "content-type": "application/grpc-web+proto", |
| 168 | + "accept": "application/grpc-web+proto", |
| 169 | + "x-grpc-web": "1", |
| 170 | + "x-user-agent": "weaviate-python-grpc-web", |
| 171 | + } |
| 172 | + _fold_metadata(headers, metadata) |
| 173 | + if timeout is not None: |
| 174 | + headers["grpc-timeout"] = _encode_timeout(timeout) |
| 175 | + |
| 176 | + url = self._base_url + path |
| 177 | + status, resp_headers, body = await self._sender( |
| 178 | + url, headers, encode_message(payload), timeout |
| 179 | + ) |
| 180 | + return self._handle_response(status, resp_headers, body, deserialize) |
| 181 | + |
| 182 | + @staticmethod |
| 183 | + def _handle_response( |
| 184 | + http_status: int, |
| 185 | + resp_headers: Dict[str, str], |
| 186 | + body: bytes, |
| 187 | + deserialize: Callable[[bytes], Any], |
| 188 | + ) -> Any: |
| 189 | + messages, trailers = split_response(body) if body else ([], {}) |
| 190 | + |
| 191 | + raw_status = trailers.get("grpc-status") |
| 192 | + if raw_status is None: |
| 193 | + raw_status = _header_lookup(resp_headers, "grpc-status") |
| 194 | + raw_message = ( |
| 195 | + trailers.get("grpc-message") or _header_lookup(resp_headers, "grpc-message") or "" |
| 196 | + ) |
| 197 | + message = urllib.parse.unquote(raw_message) |
| 198 | + |
| 199 | + if raw_status is None: |
| 200 | + if http_status != 200: |
| 201 | + raise AioRpcError( |
| 202 | + code=_status_from_http(http_status), |
| 203 | + details=f"HTTP {http_status} from grpc-web endpoint", |
| 204 | + ) |
| 205 | + code = StatusCode.OK |
| 206 | + else: |
| 207 | + code = status_from_int(int(raw_status)) |
| 208 | + |
| 209 | + if code is not StatusCode.OK: |
| 210 | + raise AioRpcError(code=code, details=message) |
| 211 | + if not messages: |
| 212 | + raise AioRpcError( |
| 213 | + code=StatusCode.INTERNAL, |
| 214 | + details="grpc-web response contained no message frame", |
| 215 | + ) |
| 216 | + return deserialize(messages[0]) |
| 217 | + |
| 218 | + |
| 219 | +def _status_from_http(http_status: int) -> StatusCode: |
| 220 | + """Map an HTTP status to a gRPC status when no grpc-status is present. |
| 221 | +
|
| 222 | + Mirrors the grpc-web spec's HTTP-to-gRPC code mapping. |
| 223 | + """ |
| 224 | + return { |
| 225 | + 400: StatusCode.INTERNAL, |
| 226 | + 401: StatusCode.UNAUTHENTICATED, |
| 227 | + 403: StatusCode.PERMISSION_DENIED, |
| 228 | + 404: StatusCode.UNIMPLEMENTED, |
| 229 | + 429: StatusCode.UNAVAILABLE, |
| 230 | + 502: StatusCode.UNAVAILABLE, |
| 231 | + 503: StatusCode.UNAVAILABLE, |
| 232 | + 504: StatusCode.UNAVAILABLE, |
| 233 | + }.get(http_status, StatusCode.UNKNOWN) |
0 commit comments