Skip to content

Commit 796ba7c

Browse files
authored
aiohttp finally works with both HTTP and HTTPS (#56)
* Draft for make tests fail with aiohttp 2.x version. * Update LICENSE * Mocket finally supports aiohttp >= 2 (but still trying to fix HTTPS). * Fix for HTTPS and asyncio/aiohttp.
1 parent 7408452 commit 796ba7c

5 files changed

Lines changed: 77 additions & 40 deletions

File tree

mocket/mocket.py

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
'true_getaddrinfo',
3535
'true_ssl_wrap_socket',
3636
'true_ssl_socket',
37+
'true_ssl_context',
38+
'true_inet_pton',
3739
'create_connection',
3840
'MocketSocket',
3941
'Mocket',
@@ -48,11 +50,8 @@
4850
true_getaddrinfo = socket.getaddrinfo
4951
true_ssl_wrap_socket = ssl.wrap_socket
5052
true_ssl_socket = ssl.SSLSocket
51-
try:
52-
true_ssl_context = ssl.SSLContext
53-
except AttributeError:
54-
# Python 2.6
55-
true_ssl_context = None
53+
true_ssl_context = ssl.SSLContext
54+
true_inet_pton = socket.inet_pton
5655

5756

5857
class SuperFakeSSLContext(object):
@@ -98,7 +97,6 @@ def wrap_bio(self, incoming, outcoming, *args, **kwargs):
9897
# FIXME: fake SSLObject implementation
9998
ssl_obj = MocketSocket()
10099
ssl_obj._host = kwargs['server_hostname']
101-
# ssl_obj.fd = outcoming
102100
return ssl_obj
103101

104102
def __getattr__(self, name):
@@ -119,8 +117,8 @@ class MocketSocket(object):
119117
family = None
120118
type = None
121119
proto = None
122-
_host = '127.0.0.1'
123-
_port = 80
120+
_host = None
121+
_port = None
124122
_address = None
125123
cipher = lambda s: ("ADH", "AES256", "SHA")
126124
compression = lambda s: ssl.OP_NO_COMPRESSION
@@ -177,8 +175,9 @@ def getsockname(self):
177175
return socket.gethostbyname(self._address[0]), self._address[1]
178176

179177
def getpeercert(self, *args, **kwargs):
180-
if not self._host:
181-
self._host, _ = self._address
178+
if not (self._host and self._port):
179+
self._address = self._host, self._port = Mocket._address
180+
182181
now = datetime.now()
183182
shift = now + timedelta(days=30 * 12)
184183
return {
@@ -205,21 +204,16 @@ def getpeercert(self, *args, **kwargs):
205204
def unwrap(self):
206205
return self
207206

208-
def write(self, c):
209-
return len(c)
207+
def write(self, data):
208+
return self.send(encode_to_bytes(data))
210209

211210
def fileno(self):
212-
if not self.fd.r_fd:
213-
self.fd.r_fd, self.fd.w_fd = os.pipe()
214-
return self.fd.r_fd
211+
Mocket.r_fd, Mocket.w_fd = os.pipe()
212+
return Mocket.r_fd
215213

216214
def connect(self, address):
217215
self._address = self._host, self._port = address
218-
219-
# def close(self):
220-
# if self.true_socket and self._connected:
221-
# self.true_socket.close()
222-
# self._closed = True
216+
Mocket._address = address
223217

224218
def makefile(self, mode='r', bufsize=-1):
225219
self._mode = mode
@@ -243,12 +237,17 @@ def sendall(self, data, *args, **kwargs):
243237
self.fd.truncate()
244238
self.fd.seek(0)
245239

240+
def read(self, buffersize):
241+
return self.fd.read(buffersize)
242+
246243
def recv(self, buffersize, flags=None):
244+
if Mocket.r_fd and Mocket.w_fd:
245+
return os.read(Mocket.r_fd, buffersize)
247246
return self.fd.read(buffersize)
248247

249248
def _connect(self): # pragma: no cover
250249
if not self._connected:
251-
self.true_socket.connect(self._address)
250+
self.true_socket.connect(Mocket._address)
252251
self._connected = True
253252

254253
def true_sendall(self, data, *args, **kwargs):
@@ -318,15 +317,17 @@ def true_sendall(self, data, *args, **kwargs):
318317

319318
def send(self, data, *args, **kwargs): # pragma: no cover
320319
entry = self.get_entry(data)
321-
if entry:
322-
if self._entry != entry:
323-
self.sendall(data, *args, **kwargs)
320+
if entry and self._entry != entry:
321+
self.sendall(data, *args, **kwargs)
324322
self._entry = entry
325323
return len(data)
326324

325+
# def __getattribute__(self, name):
326+
# return super(MocketSocket, self).__getattribute__(name)
327+
327328
def __getattr__(self, name):
328-
# useful when clients call methods on real
329-
# socket we do not provide on the fake one
329+
""" Useful when clients call methods on real
330+
socket we do not provide on the fake one. """
330331
return getattr(self.true_socket, name) # pragma: no cover
331332

332333

@@ -335,6 +336,8 @@ class Mocket(object):
335336
_requests = []
336337
_namespace = text_type(id(_entries))
337338
_truesocket_recording_dir = None
339+
r_fd = None
340+
w_fd = None
338341

339342
@classmethod
340343
def register(cls, *entries):
@@ -387,6 +390,10 @@ def enable(namespace=None, truesocket_recording_dir=None):
387390
lambda host, port, family=None, socktype=None, proto=None, flags=None: [(2, 1, 6, '', (host, port))]
388391
ssl.wrap_socket = ssl.__dict__['wrap_socket'] = FakeSSLContext.wrap_socket
389392
ssl.SSLContext = ssl.__dict__['SSLSocket'] = FakeSSLContext
393+
socket.inet_pton = socket.__dict__['inet_pton'] = lambda family, ip: byte_type(
394+
'\x7f\x00\x00\x01',
395+
'utf-8'
396+
)
390397

391398
@staticmethod
392399
def disable():
@@ -400,6 +407,7 @@ def disable():
400407
ssl.wrap_socket = ssl.__dict__['SSLSocket'] = true_ssl_wrap_socket
401408
ssl.SSLSocket = ssl.__dict__['wrap_socket'] = true_ssl_socket
402409
ssl.SSLContext = ssl.__dict__['SSLSocket'] = true_ssl_context
410+
socket.inet_pton = socket.__dict__['inet_pton'] = true_inet_pton
403411

404412
@classmethod
405413
def get_namespace(cls):

mocket/mockhttp.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,10 @@ def can_handle(self, data):
100100
requestline, _ = decode_from_bytes(data).split(CRLF, 1)
101101
method, path, version = self._parse_requestline(requestline)
102102
except ValueError:
103-
return self == Mocket._last_entry
103+
try:
104+
return self == Mocket._last_entry
105+
except AttributeError:
106+
return False
104107
uri = urlsplit(path)
105108
kw = dict(keep_blank_values=True)
106109
ch = uri.path == self.path and parse_qs(uri.query, **kw) == parse_qs(self.query, **kw) and method == self.method

mocket/utils.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33

44

55
class MocketSocketCore(io.BytesIO):
6-
r_fd = None
7-
w_fd = None
8-
96
def write(self, content):
107
super(MocketSocketCore, self).write(content)
118

12-
if self.r_fd and self.w_fd:
13-
os.write(self.w_fd, content)
9+
from mocket import Mocket
10+
11+
if Mocket.r_fd and Mocket.w_fd:
12+
os.write(Mocket.w_fd, content)

runtests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def runtests(args=None):
1717
if major == 3 and minor >= 5:
1818
python35 = True
1919

20-
pip.main(['install', 'aiohttp'])
20+
pip.main(['install', 'aiohttp', 'async_timeout'])
2121

2222
if not any(a for a in args[1:] if not a.startswith('-')):
2323
args.append('tests/main')

tests/tests35/test_http_aiohttp.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import aiohttp
22
import asyncio
3+
import async_timeout
34
from unittest import TestCase
45

56
from mocket.mocket import mocketize
@@ -8,21 +9,47 @@
89

910
class AioHttpEntryTestCase(TestCase):
1011
@mocketize
11-
def test_session(self):
12+
def test_http_session(self):
1213
url = 'http://httpbin.org/ip'
1314
body = "asd" * 100
1415
Entry.single_register(Entry.GET, url, body=body, status=404)
1516
Entry.single_register(Entry.POST, url, body=body*2, status=201)
1617

1718
async def main(l):
1819
async with aiohttp.ClientSession(loop=l) as session:
19-
async with session.get(url) as get_response:
20-
assert get_response.status == 404
21-
assert await get_response.text() == body
20+
with async_timeout.timeout(3):
21+
async with session.get(url) as get_response:
22+
assert get_response.status == 404
23+
assert await get_response.text() == body
2224

23-
async with session.post(url, data=body*6) as post_response:
24-
assert post_response.status == 201
25-
assert await post_response.text() == body*2
25+
with async_timeout.timeout(3):
26+
async with session.post(url, data=body * 6) as post_response:
27+
assert post_response.status == 201
28+
assert await post_response.text() == body * 2
2629

2730
loop = asyncio.get_event_loop()
31+
loop.set_debug(True)
32+
loop.run_until_complete(main(loop))
33+
34+
@mocketize
35+
def test_https_session(self):
36+
url = 'https://httpbin.org/ip'
37+
body = "asd" * 100
38+
Entry.single_register(Entry.GET, url, body=body, status=404)
39+
Entry.single_register(Entry.POST, url, body=body*2, status=201)
40+
41+
async def main(l):
42+
async with aiohttp.ClientSession(loop=l) as session:
43+
with async_timeout.timeout(3):
44+
async with session.get(url) as get_response:
45+
assert get_response.status == 404
46+
assert await get_response.text() == body
47+
48+
with async_timeout.timeout(3):
49+
async with session.post(url, data=body * 6) as post_response:
50+
assert post_response.status == 201
51+
assert await post_response.text() == body * 2
52+
53+
loop = asyncio.get_event_loop()
54+
loop.set_debug(True)
2855
loop.run_until_complete(main(loop))

0 commit comments

Comments
 (0)