Skip to content

Commit 04d4896

Browse files
committed
Add first version of request headers
1 parent 0e6d13e commit 04d4896

7 files changed

Lines changed: 241 additions & 54 deletions

File tree

src/purerpc/client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,17 @@ def __init__(self, service_name: str, channel: Channel):
2525
self.service_name = service_name
2626
self.channel = channel
2727

28-
async def rpc(self, method_name: str, request_type, response_type):
28+
async def rpc(self, method_name: str, request_type, response_type, metadata=None):
2929
if self.channel.grpc_socket is None:
3030
await self.channel.connect()
3131
message_type = request_type.DESCRIPTOR.full_name
32+
if metadata is None:
33+
metadata = ()
3234
stream = await self.channel.grpc_socket.start_request("http", self.service_name,
3335
method_name, message_type,
3436
"{}:{}".format(self.channel.host,
35-
self.channel.port))
37+
self.channel.port),
38+
custom_metadata=metadata)
3639
stream.expect_message_type(response_type)
3740
return stream
3841

src/purerpc/grpclib/connection.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import h2.exceptions
99
from h2.settings import SettingCodes
1010

11+
from .headers import HeaderDict, sanitize_headers
1112
from .status import Status
1213
from .config import GRPCConfiguration
1314
from .events import MessageReceived, RequestReceived, RequestEnded, ResponseReceived, ResponseEnded
@@ -75,13 +76,13 @@ def _request_received(self, event: h2.events.RequestReceived):
7576
if event.stream_ended:
7677
raise ProtocolError("Stream ended before data was sent")
7778
request = RequestReceived.parse_from_stream_id_and_headers_destructive(
78-
event.stream_id, dict(event.headers))
79+
event.stream_id, HeaderDict(event.headers))
7980
self.message_read_buffers[event.stream_id] = MessageReadBuffer(request.message_encoding,
8081
self.config.max_message_length)
8182
return [request]
8283

8384
def _response_received(self, event: h2.events.ResponseReceived):
84-
headers = dict(event.headers)
85+
headers = HeaderDict(event.headers)
8586
response_received = ResponseReceived.parse_from_stream_id_and_headers_destructive(
8687
event.stream_id, headers)
8788
if event.stream_ended:
@@ -97,7 +98,7 @@ def _response_received(self, event: h2.events.ResponseReceived):
9798

9899
def _trailers_received(self, event: h2.events.TrailersReceived):
99100
response_ended = ResponseEnded.parse_from_stream_id_and_headers_destructive(
100-
event.stream_id, dict(event.headers))
101+
event.stream_id, HeaderDict(event.headers))
101102
return [response_ended]
102103

103104
def _informational_response_received(self, event: h2.events.InformationalResponseReceived):
@@ -228,7 +229,6 @@ def start_request(self, stream_id: int, scheme: str, service_name: str, method_n
228229
(":path", "/{}/{}".format(service_name, method_name)),
229230
("te", "trailers"),
230231
("content-type", "application/grpc" + content_type_suffix),
231-
*custom_metadata
232232
]
233233
if authority is not None:
234234
headers.insert(3, (":authority", authority))
@@ -250,6 +250,7 @@ def start_request(self, stream_id: int, scheme: str, service_name: str, method_n
250250
headers.append(("grpc-accept-encoding", self.config._message_accept_encoding))
251251
if self.config._user_agent is not None:
252252
headers.append(("user-agent", self.config._user_agent))
253+
headers.extend(sanitize_headers(custom_metadata))
253254
self.h2_connection.send_headers(stream_id, headers, end_stream=False)
254255

255256
def end_request(self, stream_id: int):
@@ -259,12 +260,12 @@ def start_response(self, stream_id: int, content_type_suffix="", custom_metadata
259260
headers = [
260261
(":status", "200"),
261262
("content-type", "application/grpc" + content_type_suffix),
262-
*custom_metadata,
263263
]
264264
if self.config._message_encoding is not None:
265265
headers.append(("grpc-encoding", self.config._message_encoding))
266266
if self.config._message_accept_encoding is not None:
267267
headers.append(("grpc-accept-encoding", self.config._message_accept_encoding))
268+
headers.extend(sanitize_headers(custom_metadata))
268269
self.h2_connection.send_headers(stream_id, headers, end_stream=False)
269270

270271
def respond_status(self, stream_id: int, status: Status, content_type_suffix="",
@@ -273,19 +274,19 @@ def respond_status(self, stream_id: int, status: Status, content_type_suffix="",
273274
(":status", "200"),
274275
("content-type", "application/grpc" + content_type_suffix),
275276
("grpc-status", str(status.int_value)),
276-
*custom_metadata,
277277
]
278278
if status.status_message:
279279
# TODO: should be percent encoded
280280
trailers.append(("grpc-message", status.status_message))
281+
trailers.extend(sanitize_headers(custom_metadata))
281282
self.h2_connection.send_headers(stream_id, trailers, end_stream=True)
282283

283284
def end_response(self, stream_id: int, status: Status, custom_metadata=()):
284285
trailers = [
285286
("grpc-status", str(status.int_value)),
286-
*custom_metadata,
287287
]
288288
if status.status_message:
289289
# TODO: should be percent encoded
290290
trailers.append(("grpc-message", status.status_message))
291+
trailers.extend(sanitize_headers(custom_metadata))
291292
self.h2_connection.send_headers(stream_id, trailers, end_stream=True)

src/purerpc/grpclib/events.py

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import base64
21
import datetime
2+
3+
from .headers import HeaderDict
34
from .exceptions import ProtocolError
45
from .status import Status
56

@@ -22,10 +23,10 @@ def __init__(self, stream_id: int, scheme: str, service_name: str, method_name:
2223
self.message_encoding = None
2324
self.message_accept_encoding = None
2425
self.user_agent = None
25-
self.custom_metadata = {}
26+
self.custom_metadata = ()
2627

2728
@staticmethod
28-
def parse_from_stream_id_and_headers_destructive(stream_id: int, headers: dict):
29+
def parse_from_stream_id_and_headers_destructive(stream_id: int, headers: HeaderDict):
2930
if headers.pop(":method") != "POST":
3031
raise ProtocolError("Unsupported method {}".format(headers[":method"]))
3132

@@ -83,11 +84,8 @@ def parse_from_stream_id_and_headers_destructive(stream_id: int, headers: dict):
8384
if "grpc-message-type" in headers:
8485
event.message_type = headers.pop("grpc-message-type")
8586

86-
for header_name in list(headers.keys()):
87-
if header_name.endswith("-bin"):
88-
event.custom_metadata[header_name] = base64.b64decode(headers.pop(header_name))
89-
else:
90-
event.custom_metadata[header_name] = headers.pop(header_name)
87+
event.custom_metadata = tuple(header for header_name in list(headers.keys())
88+
for header in headers.extract_headers(header_name))
9189
return event
9290

9391

@@ -109,10 +107,10 @@ def __init__(self, stream_id: int, content_type: str):
109107
self.content_type = content_type
110108
self.message_encoding = None
111109
self.message_accept_encoding = None
112-
self.custom_metadata = {}
110+
self.custom_metadata = ()
113111

114112
@staticmethod
115-
def parse_from_stream_id_and_headers_destructive(stream_id: int, headers: dict):
113+
def parse_from_stream_id_and_headers_destructive(stream_id: int, headers: HeaderDict):
116114
if int(headers.pop(":status")) != 200:
117115
raise ProtocolError("http status is not 200")
118116

@@ -128,25 +126,19 @@ def parse_from_stream_id_and_headers_destructive(stream_id: int, headers: dict):
128126
if "grpc-accept-encoding" in headers:
129127
event.message_accept_encoding = headers.pop("grpc-accept-encoding").split(",")
130128

131-
for header_name in list(headers.keys()):
132-
if header_name in ["grpc-status", "grpc-message"]:
133-
# is not metadata, will be parsed in ResponseEnded
134-
continue
135-
if header_name.endswith("-bin"):
136-
event.custom_metadata[header_name] = base64.b64decode(headers.pop(header_name))
137-
else:
138-
event.custom_metadata[header_name] = headers.pop(header_name)
129+
event.custom_metadata = tuple(header for header_name in list(headers.keys())
130+
for header in headers.extract_headers(header_name))
139131
return event
140132

141133

142134
class ResponseEnded(Event):
143135
def __init__(self, stream_id: int, status: Status):
144136
self.stream_id = stream_id
145137
self.status = status
146-
self.custom_metadata = {}
138+
self.custom_metadata = ()
147139

148140
@staticmethod
149-
def parse_from_stream_id_and_headers_destructive(stream_id: int, headers: dict):
141+
def parse_from_stream_id_and_headers_destructive(stream_id: int, headers: HeaderDict):
150142
if "grpc-status" not in headers:
151143
raise ProtocolError("Expected grpc-status in trailers")
152144

@@ -159,9 +151,6 @@ def parse_from_stream_id_and_headers_destructive(stream_id: int, headers: dict):
159151

160152
event = ResponseEnded(stream_id, Status(status_code, status_message))
161153

162-
for header_name in list(headers.keys()):
163-
if header_name.endswith("-bin"):
164-
event.custom_metadata[header_name] = base64.b64decode(headers.pop(header_name))
165-
else:
166-
event.custom_metadata[header_name] = headers.pop(header_name)
154+
event.custom_metadata = tuple(header for header_name in list(headers.keys())
155+
for header in headers.extract_headers(header_name))
167156
return event

src/purerpc/grpclib/headers.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import base64
2+
3+
4+
class HeaderDict(dict):
5+
def __init__(self, values):
6+
super().__init__()
7+
for key, value in values:
8+
if key not in self:
9+
self[key] = [value]
10+
else:
11+
self[key].append(value)
12+
for key in self:
13+
if len(self[key]) == 1:
14+
self[key] = self[key][0]
15+
16+
def extract_headers(self, header_name: str):
17+
"""Returns all headers with name == header_name as list of tuples (name, value)"""
18+
if header_name.startswith("grpc-"):
19+
return ()
20+
else:
21+
value = self.pop(header_name)
22+
is_binary = header_name.endswith("-bin")
23+
24+
if not isinstance(value, list):
25+
value_list = [value]
26+
else:
27+
value_list = value
28+
29+
if is_binary:
30+
return ((header_name, b64decode(value)) for value_sublist in value_list for value
31+
in value_sublist.split(","))
32+
else:
33+
return ((header_name, value) for value in value_list)
34+
35+
36+
def sanitize_headers(headers):
37+
for name, value in headers:
38+
if isinstance(value, bytes) and not name.endswith("-bin"):
39+
raise ValueError(f"Got binary value for header name '{name}', but name does not end "
40+
f"with '-bin' suffix")
41+
if name.startswith("grpc-"):
42+
raise ValueError(f"Got header with name '{name}', but custom metadata headers should "
43+
f"not start with 'grpc-' prefix")
44+
if name.endswith("-bin"):
45+
yield name, b64encode(value)
46+
else:
47+
yield name, value
48+
49+
50+
def b64decode(data: str) -> bytes:
51+
# Apply missing padding
52+
missing_padding = len(data) % 4
53+
if missing_padding:
54+
data += "=" * (4 - missing_padding)
55+
return base64.b64decode(data)
56+
57+
58+
def b64encode(data: bytes) -> str:
59+
return base64.b64encode(data).rstrip(b"=").decode("utf-8")

src/purerpc/server.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
import inspect
33
import warnings
44
import collections
5+
import functools
56
from multiprocessing import Process
67

78
import curio
89
import curio.meta
910
import typing
1011
import logging
1112

13+
from .grpclib.events import RequestReceived
1214
from .grpclib.status import Status, StatusCode
1315
from .grpclib.exceptions import RpcFailedError
1416
from purerpc.grpc_proto import GRPCProtoStream, GRPCProtoSocket
@@ -29,24 +31,37 @@ def __init__(self, name):
2931
self.name = name
3032
self.methods = {}
3133

32-
def add_method(self, method_name: str, method_fn, rpc_signature: RPCSignature):
33-
self.methods[method_name] = BoundRPCMethod(method_fn, rpc_signature)
34+
def add_method(self, method_name: str, method_fn, rpc_signature: RPCSignature,
35+
method_signature: inspect.Signature = None):
36+
if method_signature is None:
37+
method_signature = inspect.signature(method_fn)
38+
if len(method_signature.parameters) == 1:
39+
def method_fn_with_headers(arg, request):
40+
return method_fn(arg)
41+
elif len(method_signature.parameters) == 2:
42+
if list(method_signature.parameters.values())[1].name == "request":
43+
method_fn_with_headers = method_fn
44+
else:
45+
raise ValueError("Expected second parameter 'request'")
46+
else:
47+
raise ValueError("Expected method_fn to have exactly one or two parameters")
48+
self.methods[method_name] = BoundRPCMethod(method_fn_with_headers, rpc_signature)
3449

3550
def rpc(self, method_name):
3651
def decorator(func):
3752
signature = inspect.signature(func)
3853
if signature.return_annotation == signature.empty:
3954
raise ValueError("Only annotated methods can be used with Service.rpc() decorator")
40-
if len(signature.parameters) != 1:
41-
raise ValueError("Only functions with one parameter can be used with Service.rpc("
42-
") decorator")
55+
if len(signature.parameters) not in (1, 2):
56+
raise ValueError("Only functions with one or two parameters can be used with "
57+
"Service.rpc() decorator")
4358
parameter = next(iter(signature.parameters.values()))
4459
if parameter.annotation == parameter.empty:
4560
raise ValueError("Only annotated methods can be used with Service.rpc() decorator")
4661

4762
rpc_signature = RPCSignature.from_annotations(parameter.annotation,
4863
signature.return_annotation)
49-
self.add_method(method_name, func, rpc_signature)
64+
self.add_method(method_name, func, rpc_signature, method_signature=signature)
5065
return func
5166
return decorator
5267

@@ -104,6 +119,10 @@ async def request_received(self, stream: GRPCProtoStream):
104119
await stream.start_response()
105120
event = await stream.receive_event()
106121

122+
if not isinstance(event, RequestReceived):
123+
await stream.close(Status(StatusCode.INTERNAL, status_message="Expected headers"))
124+
return
125+
107126
try:
108127
service = self.server.services[event.service_name]
109128
except KeyError:
@@ -125,7 +144,7 @@ async def request_received(self, stream: GRPCProtoStream):
125144

126145
# TODO: Should at least pass through GeneratorExit
127146
try:
128-
method_fn = bound_rpc_method.method_fn
147+
method_fn = functools.partial(bound_rpc_method.method_fn, request=event)
129148
cardinality = bound_rpc_method.signature.cardinality
130149
stream.expect_message_type(bound_rpc_method.signature.request_type)
131150
if cardinality == Cardinality.STREAM_STREAM:

src/purerpc/wrappers.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -80,40 +80,40 @@ def __init__(self, stream_fn):
8080

8181

8282
class ClientStubUnaryUnary(ClientStub):
83-
async def __call__(self, message):
84-
stream = await self._stream_fn()
83+
async def __call__(self, message, *, metadata=None):
84+
stream = await self._stream_fn(metadata=metadata)
8585
await send_single_message_client(stream, message)
8686
return await extract_message_from_singleton_stream(stream)
8787

8888

8989
class ClientStubUnaryStream(ClientStub):
90-
async def __call__(self, message):
91-
stream = await self._stream_fn()
90+
async def __call__(self, message, *, metadata=None):
91+
stream = await self._stream_fn(metadata=metadata)
9292
await send_single_message_client(stream, message)
9393
async for message in stream_to_async_iterator(stream):
9494
yield message
9595

9696

9797
class ClientStubStreamUnary(ClientStub):
98-
async def __call__(self, message_aiter):
99-
stream = await self._stream_fn()
98+
async def __call__(self, message_aiter, *, metadata=None):
99+
stream = await self._stream_fn(metadata=metadata)
100100
await curio.spawn(send_multiple_messages_client, stream, message_aiter, daemon=True)
101101
return await extract_message_from_singleton_stream(stream)
102102

103103

104104
class ClientStubStreamStream(ClientStub):
105-
async def call_aiter(self, message_aiter):
106-
stream = await self._stream_fn()
105+
async def call_aiter(self, message_aiter, metadata):
106+
stream = await self._stream_fn(metadata=metadata)
107107
if message_aiter is not None:
108108
await curio.spawn(send_multiple_messages_client, stream, message_aiter, daemon=True)
109109
async for message in stream_to_async_iterator(stream):
110110
yield message
111111

112-
async def call_stream(self):
113-
return await self._stream_fn()
112+
async def call_stream(self, metadata):
113+
return await self._stream_fn(metadata=metadata)
114114

115-
def __call__(self, message_aiter=None):
115+
def __call__(self, message_aiter=None, *, metadata=None):
116116
if message_aiter is None:
117-
return self.call_stream()
117+
return self.call_stream(metadata)
118118
else:
119-
return self.call_aiter(message_aiter)
119+
return self.call_aiter(message_aiter, metadata)

0 commit comments

Comments
 (0)