1- import asyncio
21import ssl
32import types
43import typing
54
5+ import trio
66import certifi
77
88from ._streams import Stream
1313
1414class NetworkStream (Stream ):
1515 def __init__ (
16- self , reader : asyncio . StreamReader , writer : asyncio . StreamWriter , address : str = ''
16+ self , trio_stream : trio . abc . Stream , address : str = ''
1717 ) -> None :
18- self ._reader = reader
19- self ._writer = writer
18+ self ._trio_stream = trio_stream
2019 self ._address = address
21- self ._tls = False
2220 self ._closed = False
2321
2422 async def read (self , size : int = - 1 ) -> bytes :
2523 if size < 0 :
2624 size = 64 * 1024
27- return await self ._reader . read (size )
25+ return await self ._trio_stream . receive_some (size )
2826
2927 async def write (self , buffer : bytes ) -> None :
30- self ._writer .write (buffer )
31- await self ._writer .drain ()
28+ await self ._trio_stream .send_all (buffer )
3229
3330 async def close (self ) -> None :
34- if not self ._closed :
35- self ._writer .close ()
36- await self ._writer .wait_closed ()
31+ # Close the NetworkStream.
32+ # If the stream is already closed this is a checkpointed no-op.
33+ try :
34+ await self ._trio_stream .aclose ()
35+ finally :
3736 self ._closed = True
3837
3938 def __repr__ (self ):
4039 description = ""
41- description += " TLS" if self ._tls else ""
4240 description += " CLOSED" if self ._closed else ""
43- return f"<NetworkStream [{ self ._address !r } { description } ]>"
41+ return f"<NetworkStream [{ self ._address } { description } ]>"
4442
4543 def __del__ (self ):
4644 if not self ._closed :
4745 import warnings
48- warnings .warn ("NetworkStream was garbage collected without being closed." )
46+ warnings .warn (f" { self !r } was garbage collected without being closed." )
4947
5048 # Context managed usage...
5149 async def __aenter__ (self ) -> "NetworkStream" :
@@ -61,13 +59,17 @@ async def __aexit__(
6159
6260
6361class NetworkServer :
64- def __init__ (self , host : str , port : int , server : asyncio . Server ):
62+ def __init__ (self , host : str , port : int , handler , listeners : list [ trio . SocketListener ] ):
6563 self .host = host
6664 self .port = port
67- self ._server = server
65+ self ._handler = handler
66+ self ._listeners = listeners
6867
6968 # Context managed usage...
7069 async def __aenter__ (self ) -> "NetworkServer" :
70+ self ._nursery_manager = trio .open_nursery ()
71+ self ._nursery = await self ._nursery_manager .__aenter__ ()
72+ self ._nursery .start_soon (trio .serve_listeners , self ._handler , self ._listeners )
7173 return self
7274
7375 async def __aexit__ (
@@ -76,8 +78,8 @@ async def __aexit__(
7678 exc_value : BaseException | None = None ,
7779 traceback : types .TracebackType | None = None ,
7880 ):
79- self ._server . close ()
80- await self ._server . wait_closed ( )
81+ self ._nursery . cancel_scope . cancel ()
82+ await self ._nursery_manager . __aexit__ ( exc_type , exc_value , traceback )
8183
8284
8385class NetworkBackend :
@@ -92,29 +94,42 @@ async def connect(self, host: str, port: int) -> NetworkStream:
9294 """
9395 Connect to the given address, returning a Stream instance.
9496 """
97+ # Create the TCP stream
9598 address = f"{ host } :{ port } "
96- reader , writer = await asyncio . open_connection (host , port )
97- return NetworkStream (reader , writer , address = address )
99+ trio_stream = await trio . open_tcp_stream (host , port )
100+ return NetworkStream (trio_stream , address = address )
98101
99102 async def connect_tls (self , host : str , port : int , hostname : str = '' ) -> NetworkStream :
100103 """
101104 Connect to the given address, returning a Stream instance.
102105 """
106+ # Create the TCP stream
103107 address = f"{ host } :{ port } "
104- reader , writer = await asyncio .open_connection (host , port )
105- await writer .start_tls (self ._ssl_ctx , server_hostname = hostname )
106- return NetworkStream (reader , writer , address = address )
108+ trio_stream = await trio .open_tcp_stream (host , port )
109+
110+ # Establish SSL over TCP
111+ hostname = hostname or host
112+ ssl_stream = trio .SSLStream (trio_stream , ssl_context = self ._ssl_ctx , server_hostname = hostname )
113+ await ssl_stream .do_handshake ()
114+
115+ return NetworkStream (ssl_stream , address = address )
107116
108117 async def serve (self , host : str , port : int , handler : typing .Callable [[NetworkStream ], None ]) -> NetworkServer :
109- async def callback (reader , writer ):
110- stream = NetworkStream (reader , writer )
111- await handler (stream )
118+ async def callback (trio_stream ):
119+ stream = NetworkStream (trio_stream , address = f"{ host } :{ port } " )
120+ try :
121+ await handler (stream )
122+ finally :
123+ await stream .close ()
112124
113- server = await asyncio .start_server (callback , host , port )
114- return NetworkServer (host , port , server )
125+ listeners = await trio .open_tcp_listeners (port = port , host = host )
126+ return NetworkServer (host , port , callback , listeners )
127+
128+ def __repr__ (self ):
129+ return f"<NetworkBackend [trio]>"
115130
116131
117- Semaphore = asyncio .Semaphore
118- Lock = asyncio .Lock
119- timeout = asyncio . timeout
120- sleep = asyncio .sleep
132+ Semaphore = trio .Semaphore
133+ Lock = trio .Lock
134+ timeout = trio . move_on_after
135+ sleep = trio .sleep
0 commit comments