-
Notifications
You must be signed in to change notification settings - Fork 165
Expand file tree
/
Copy pathsync.py
More file actions
249 lines (218 loc) · 8.35 KB
/
sync.py
File metadata and controls
249 lines (218 loc) · 8.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
from __future__ import annotations
import functools
import socket
import ssl
import sys
import typing
from .._exceptions import (
ConnectError,
ConnectTimeout,
ExceptionMapping,
ReadError,
ReadTimeout,
WriteError,
WriteTimeout,
map_exceptions,
)
from .._utils import is_socket_readable
from .base import SOCKET_OPTION, NetworkBackend, NetworkStream
class TLSinTLSStream(NetworkStream): # pragma: no cover
"""
Because the standard `SSLContext.wrap_socket` method does
not work for `SSLSocket` objects, we need this class
to implement TLS stream using an underlying `SSLObject`
instance in order to support TLS on top of TLS.
"""
# Defined in RFC 8449
TLS_RECORD_SIZE = 16384
def __init__(
self,
sock: socket.socket,
ssl_context: ssl.SSLContext,
server_hostname: str | None = None,
timeout: float | None = None,
):
self._sock = sock
self._incoming = ssl.MemoryBIO()
self._outgoing = ssl.MemoryBIO()
self.ssl_obj = ssl_context.wrap_bio(
incoming=self._incoming,
outgoing=self._outgoing,
server_hostname=server_hostname,
)
self._sock.settimeout(timeout)
self._perform_io(self.ssl_obj.do_handshake)
def _perform_io(
self,
func: typing.Callable[..., typing.Any],
) -> typing.Any:
ret = None
while True:
errno = None
try:
ret = func()
except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as e:
errno = e.errno
self._sock.sendall(self._outgoing.read())
if errno == ssl.SSL_ERROR_WANT_READ:
buf = self._sock.recv(self.TLS_RECORD_SIZE)
if buf:
self._incoming.write(buf)
else:
self._incoming.write_eof()
if errno is None:
return ret
def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
with map_exceptions(exc_map):
self._sock.settimeout(timeout)
return typing.cast(
bytes, self._perform_io(functools.partial(self.ssl_obj.read, max_bytes))
)
def write(self, buffer: bytes, timeout: float | None = None) -> None:
exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
with map_exceptions(exc_map):
self._sock.settimeout(timeout)
# Use a memoryview to avoid O(n²) copying when buffer is sliced on each
# iteration. Plain bytes slicing (buffer = buffer[n:]) creates a full copy;
# memoryview slicing is zero-copy, critical for large upload payloads.
view = memoryview(buffer)
while view:
nsent = self._perform_io(functools.partial(self.ssl_obj.write, view))
view = view[nsent:]
def close(self) -> None:
self._sock.close()
def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: str | None = None,
timeout: float | None = None,
) -> NetworkStream:
raise NotImplementedError()
def get_extra_info(self, info: str) -> typing.Any:
if info == "ssl_object":
return self.ssl_obj
if info == "client_addr":
return self._sock.getsockname()
if info == "server_addr":
return self._sock.getpeername()
if info == "socket":
return self._sock
if info == "is_readable":
return is_socket_readable(self._sock)
return None
class SyncStream(NetworkStream):
def __init__(self, sock: socket.socket) -> None:
self._sock = sock
def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
with map_exceptions(exc_map):
self._sock.settimeout(timeout)
return self._sock.recv(max_bytes)
def write(self, buffer: bytes, timeout: float | None = None) -> None:
if not buffer:
return
exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
with map_exceptions(exc_map):
# Use a memoryview to avoid O(n²) copying when buffer is sliced on each
# iteration. Plain bytes slicing (buffer = buffer[n:]) creates a full copy;
# memoryview slicing is zero-copy, critical for large upload payloads.
view = memoryview(buffer)
while view:
self._sock.settimeout(timeout)
n = self._sock.send(view)
view = view[n:]
def close(self) -> None:
self._sock.close()
def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: str | None = None,
timeout: float | None = None,
) -> NetworkStream:
exc_map: ExceptionMapping = {
socket.timeout: ConnectTimeout,
OSError: ConnectError,
}
with map_exceptions(exc_map):
try:
if isinstance(self._sock, ssl.SSLSocket): # pragma: no cover
# If the underlying socket has already been upgraded
# to the TLS layer (i.e. is an instance of SSLSocket),
# we need some additional smarts to support TLS-in-TLS.
return TLSinTLSStream(
self._sock, ssl_context, server_hostname, timeout
)
else:
self._sock.settimeout(timeout)
sock = ssl_context.wrap_socket(
self._sock, server_hostname=server_hostname
)
except Exception as exc: # pragma: nocover
self.close()
raise exc
return SyncStream(sock)
def get_extra_info(self, info: str) -> typing.Any:
if info == "ssl_object" and isinstance(self._sock, ssl.SSLSocket):
return self._sock._sslobj # type: ignore
if info == "client_addr":
return self._sock.getsockname()
if info == "server_addr":
return self._sock.getpeername()
if info == "socket":
return self._sock
if info == "is_readable":
return is_socket_readable(self._sock)
return None
class SyncBackend(NetworkBackend):
def connect_tcp(
self,
host: str,
port: int,
timeout: float | None = None,
local_address: str | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> NetworkStream:
# Note that we automatically include `TCP_NODELAY`
# in addition to any other custom socket options.
if socket_options is None:
socket_options = [] # pragma: no cover
address = (host, port)
source_address = None if local_address is None else (local_address, 0)
exc_map: ExceptionMapping = {
socket.timeout: ConnectTimeout,
OSError: ConnectError,
}
with map_exceptions(exc_map):
sock = socket.create_connection(
address,
timeout,
source_address=source_address,
)
for option in socket_options:
sock.setsockopt(*option) # pragma: no cover
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
return SyncStream(sock)
def connect_unix_socket(
self,
path: str,
timeout: float | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> NetworkStream: # pragma: nocover
if sys.platform == "win32":
raise RuntimeError(
"Attempted to connect to a UNIX socket on a Windows system."
)
if socket_options is None:
socket_options = []
exc_map: ExceptionMapping = {
socket.timeout: ConnectTimeout,
OSError: ConnectError,
}
with map_exceptions(exc_map):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
for option in socket_options:
sock.setsockopt(*option)
sock.settimeout(timeout)
sock.connect(path)
return SyncStream(sock)