Skip to content

Commit 20cda67

Browse files
committed
context management for servers
1 parent dd7b2b2 commit 20cda67

2 files changed

Lines changed: 80 additions & 64 deletions

File tree

yaqd-core/yaqd_core/_protocol.py

Lines changed: 67 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,20 @@ def __init__(self, daemon, *args, **kwargs):
1717

1818
def connection_lost(self, exc):
1919
peername = self.transport.get_extra_info("peername")
20-
self.logger.info(f"Connection lost from {peername} to {self._daemon.name}")
20+
self.logger.info(f"Connection lost from {peername}")
2121
self.task.cancel()
2222
self._daemon._connection_lost(peername)
2323

2424
def connection_made(self, transport):
2525
"""Process an incomming connection."""
2626
peername = transport.get_extra_info("peername")
27-
self.logger.info(f"Connection made from {peername} to {self._daemon.name}")
27+
self.logger.info(f"Connection made from {peername}")
2828
self.transport = transport
2929
self.unpacker = avrorpc.Unpacker(self._avro_protocol)
3030
self._daemon._connection_made(peername)
31-
self.task = asyncio.get_running_loop().create_task(self.process_requests())
31+
self.task = self._daemon._loop.create_task(self.process_requests())
32+
self._daemon._tasks.append(self.task)
33+
self.task.add_done_callback(self._daemon._tasks.remove)
3234

3335
def data_received(self, data):
3436
"""Process an incomming request."""
@@ -38,61 +40,68 @@ def data_received(self, data):
3840
self.unpacker.feed(data)
3941

4042
async def process_requests(self):
41-
async for hs, meta, name, params in self.unpacker:
42-
if hs is not None:
43-
out = bytes(hs)
44-
out = struct.pack(">L", len(out)) + out
45-
self.transport.write(out)
46-
if hs.match == "NONE":
47-
name = ""
43+
try:
44+
async for hs, meta, name, params in self.unpacker:
45+
if hs is not None:
46+
out = bytes(hs)
47+
out = struct.pack(">L", len(out)) + out
48+
self.transport.write(out)
49+
if hs.match == "NONE":
50+
name = ""
4851

49-
out_meta = io.BytesIO()
50-
fastavro.schemaless_writer(
51-
out_meta, {"type": "map", "values": "bytes"}, meta
52-
)
53-
length = out_meta.tell()
54-
self.transport.write(struct.pack(">L", length) + out_meta.getvalue())
55-
self.logger.debug(f"Wrote meta, {meta}, {out_meta.getvalue()}")
56-
try:
57-
response_out = io.BytesIO()
58-
response = None
59-
response_schema = "null"
60-
if name:
61-
fun = getattr(self._daemon, name)
62-
if params is None:
63-
params = []
64-
response = fun(*params)
65-
response_schema = fastavro.parse_schema(
66-
self._avro_protocol["messages"][name].get("response", "null"),
67-
expand=True,
68-
named_schemas=self._named_types,
52+
out_meta = io.BytesIO()
53+
fastavro.schemaless_writer(
54+
out_meta, {"type": "map", "values": "bytes"}, meta
55+
)
56+
length = out_meta.tell()
57+
self.transport.write(struct.pack(">L", length) + out_meta.getvalue())
58+
self.logger.debug(f"Wrote meta, {meta}, {out_meta.getvalue()}")
59+
try:
60+
response_out = io.BytesIO()
61+
response = None
62+
response_schema = "null"
63+
if name:
64+
fun = getattr(self._daemon, name)
65+
if params is None:
66+
params = []
67+
response = fun(*params)
68+
response_schema = fastavro.parse_schema(
69+
self._avro_protocol["messages"][name].get("response", "null"),
70+
expand=True,
71+
named_schemas=self._named_types,
72+
)
73+
# Needed twice for nested types... Probably can be fixed upstream
74+
response_schema = fastavro.parse_schema(
75+
response_schema,
76+
expand=True,
77+
named_schemas=self._named_types,
78+
)
79+
fastavro.schemaless_writer(response_out, response_schema, response)
80+
except Exception as e:
81+
self.logger.error(f"Caught exception: {type(e)} in message {name}")
82+
self.logger.debug(traceback.format_exc())
83+
self.transport.write(struct.pack(">L", 1) + b"\1")
84+
error_out = io.BytesIO()
85+
fastavro.schemaless_writer(error_out, ["string"], repr(e))
86+
length = error_out.tell()
87+
self.transport.write(struct.pack(">L", length) + error_out.getvalue())
88+
else:
89+
self.transport.write(struct.pack(">L", 1) + b"\0")
90+
self.logger.debug(f"Wrote non-error flag")
91+
length = response_out.tell()
92+
self.transport.write(
93+
struct.pack(">L", length) + response_out.getvalue()
6994
)
70-
# Needed twice for nested types... Probably can be fixed upstream
71-
response_schema = fastavro.parse_schema(
72-
response_schema,
73-
expand=True,
74-
named_schemas=self._named_types,
95+
self.logger.debug(
96+
f"Wrote response {response}, {response_out.getvalue()}"
7597
)
76-
fastavro.schemaless_writer(response_out, response_schema, response)
77-
except Exception as e:
78-
self.logger.error(f"Caught exception: {type(e)} in message {name}")
79-
self.logger.debug(traceback.format_exc())
80-
self.transport.write(struct.pack(">L", 1) + b"\1")
81-
error_out = io.BytesIO()
82-
fastavro.schemaless_writer(error_out, ["string"], repr(e))
83-
length = error_out.tell()
84-
self.transport.write(struct.pack(">L", length) + error_out.getvalue())
85-
else:
86-
self.transport.write(struct.pack(">L", 1) + b"\0")
87-
self.logger.debug(f"Wrote non-error flag")
88-
length = response_out.tell()
89-
self.transport.write(
90-
struct.pack(">L", length) + response_out.getvalue()
91-
)
92-
self.logger.debug(
93-
f"Wrote response {response}, {response_out.getvalue()}"
94-
)
95-
self.transport.write(struct.pack(">L", 0))
96-
if name == "shutdown":
97-
self.logger.debug("Closing transport")
98-
self.transport.close()
98+
self.transport.write(struct.pack(">L", 0))
99+
if name == "shutdown":
100+
self.logger.debug("Closing transport")
101+
self.transport.close()
102+
except asyncio.CancelledError as e:
103+
self.logger.debug("task cancellation caught")
104+
await self.unpacker.__aexit__(None, None, None)
105+
self.transport.close()
106+
self.logger.debug(f"file closed? {self.unpacker._file.closed}")
107+
raise e

yaqd-core/yaqd_core/avrorpc/unpacker.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,20 @@ async def __anext__(self):
6767
except (ValueError, struct.error):
6868
await self.new_data.wait()
6969

70+
async def __aexit__(self, exc_type, exc_val, exc_tb):
71+
logger.info("closing")
72+
await asyncio.sleep(0)
73+
self._file.close()
74+
self.buf.close()
75+
7076
def feed(self, data: bytes):
71-
# Must support random access, if it does not, must be fed externally (e.g. TCP)
72-
pos = self._file.tell()
73-
self._file.seek(0, 2)
74-
self._file.write(data)
75-
self._file.seek(pos)
76-
self.new_data.set()
77+
if not self._file.closed:
78+
# Must support random access, if it does not, must be fed externally (e.g. TCP)
79+
pos = self._file.tell()
80+
self._file.seek(0, 2)
81+
self._file.write(data)
82+
self._file.seek(pos)
83+
self.new_data.set()
7784

7885
async def _read_object(self, schema):
7986
schema = fastavro.parse_schema(

0 commit comments

Comments
 (0)