Skip to content

Commit 4bd3e4f

Browse files
authored
Merge pull request #224 from mindflayer/chore/review-pr
Add allowed locations (whitelist) for STRICT mode
2 parents 6d861c2 + 324845f commit 4bd3e4f

8 files changed

Lines changed: 139 additions & 15 deletions

File tree

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ jobs:
4040
make services-up
4141
- name: Setup hostname
4242
run: |
43-
export CONTAINER_ID=$(docker-compose ps -q proxy)
43+
export CONTAINER_ID=$(docker compose ps -q proxy)
4444
export CONTAINER_IP=$(docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' $CONTAINER_ID)
4545
echo "$CONTAINER_IP httpbin.local" | sudo tee -a /etc/hosts
4646
- name: Test

README.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,20 @@ NEW!!! Sometimes you just want your tests to fail when they attempt to use the n
169169
with pytest.raises(StrictMocketException):
170170
requests.get("https://duckduckgo.com/")
171171
172+
You can specify exceptions as a list of hosts or host-port pairs.
173+
174+
.. code-block:: python
175+
176+
with Mocketizer(strict_mode=True, strict_mode_allowed=["localhost", ("intake.ourmetrics.net", 443)]):
177+
...
178+
179+
# OR
180+
181+
@mocketize(strict_mode=True, strict_mode_allowed=["localhost", ("intake.ourmetrics.net", 443)])
182+
def test_get():
183+
...
184+
185+
172186
How to be sure that all the Entry instances have been served?
173187
=============================================================
174188
Add this instruction at the end of the test execution:

mocket/async_mocket.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,16 @@
33

44

55
async def wrapper(
6-
test, truesocket_recording_dir=None, strict_mode=False, *args, **kwargs
6+
test,
7+
truesocket_recording_dir=None,
8+
strict_mode=False,
9+
strict_mode_allowed=None,
10+
*args,
11+
**kwargs,
712
):
8-
async with Mocketizer.factory(test, truesocket_recording_dir, strict_mode, args):
13+
async with Mocketizer.factory(
14+
test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args
15+
):
916
return await test(*args, **kwargs)
1017

1118

mocket/mocket.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
urllib3_wrap_socket = None
2323

2424
from .compat import basestring, byte_type, decode_from_bytes, encode_to_bytes, text_type
25-
from .exceptions import StrictMocketException
2625
from .utils import (
2726
SSL_PROTOCOL,
2827
MocketMode,
@@ -333,8 +332,8 @@ def recv(self, buffersize, flags=None):
333332
raise exc
334333

335334
def true_sendall(self, data, *args, **kwargs):
336-
if MocketMode().STRICT:
337-
raise StrictMocketException("Mocket tried to use the real `socket` module.")
335+
if not MocketMode().is_allowed((self._host, self._port)):
336+
MocketMode.raise_not_allowed()
338337

339338
req = decode_from_bytes(data)
340339
# make request unique again
@@ -642,6 +641,9 @@ def __init__(self, location, responses):
642641
r = self.response_cls(r)
643642
self.responses.append(r)
644643

644+
def __repr__(self):
645+
return "{}(location={})".format(self.__class__.__name__, self.location)
646+
645647
@staticmethod
646648
def can_handle(data):
647649
return True
@@ -670,11 +672,18 @@ def __init__(
670672
namespace=None,
671673
truesocket_recording_dir=None,
672674
strict_mode=False,
675+
strict_mode_allowed=None,
673676
):
674677
self.instance = instance
675678
self.truesocket_recording_dir = truesocket_recording_dir
676679
self.namespace = namespace or text_type(id(self))
677680
MocketMode().STRICT = strict_mode
681+
if strict_mode:
682+
MocketMode().STRICT_ALLOWED = strict_mode_allowed or []
683+
elif strict_mode_allowed:
684+
raise ValueError(
685+
"Allowed locations are only accepted when STRICT mode is active."
686+
)
678687

679688
def enter(self):
680689
Mocket.enable(
@@ -709,7 +718,7 @@ def check_and_call(self, method_name):
709718
method()
710719

711720
@staticmethod
712-
def factory(test, truesocket_recording_dir, strict_mode, args):
721+
def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args):
713722
instance = args[0] if args else None
714723
namespace = None
715724
if truesocket_recording_dir:
@@ -726,11 +735,21 @@ def factory(test, truesocket_recording_dir, strict_mode, args):
726735
namespace=namespace,
727736
truesocket_recording_dir=truesocket_recording_dir,
728737
strict_mode=strict_mode,
738+
strict_mode_allowed=strict_mode_allowed,
729739
)
730740

731741

732-
def wrapper(test, truesocket_recording_dir=None, strict_mode=False, *args, **kwargs):
733-
with Mocketizer.factory(test, truesocket_recording_dir, strict_mode, args):
742+
def wrapper(
743+
test,
744+
truesocket_recording_dir=None,
745+
strict_mode=False,
746+
strict_mode_allowed=None,
747+
*args,
748+
**kwargs,
749+
):
750+
with Mocketizer.factory(
751+
test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args
752+
):
734753
return test(*args, **kwargs)
735754

736755

mocket/mockhttp.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,11 @@ def headers(self):
6565
@property
6666
def querystring(self):
6767
parts = self._protocol.url.split("?", 1)
68-
if len(parts) == 2:
69-
return parse_qs(unquote(parts[1]), keep_blank_values=True)
70-
return {}
68+
return (
69+
parse_qs(unquote(parts[1]), keep_blank_values=True)
70+
if len(parts) == 2
71+
else {}
72+
)
7173

7274
@property
7375
def body(self):
@@ -175,6 +177,18 @@ def __init__(self, uri, method, responses, match_querystring=True):
175177
self._sent_data = b""
176178
self._match_querystring = match_querystring
177179

180+
def __repr__(self):
181+
return (
182+
"{}(method={!r}, schema={!r}, location={!r}, path={!r}, query={!r})".format(
183+
self.__class__.__name__,
184+
self.method,
185+
self.schema,
186+
self.location,
187+
self.path,
188+
self.query,
189+
)
190+
)
191+
178192
def collect(self, data):
179193
consume_response = True
180194

mocket/plugins/httpretty/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,8 @@ def register_uri(
7070
responses=None,
7171
match_querystring=False,
7272
priority=0,
73-
**headers
73+
**headers,
7474
):
75-
7675
headers = httprettifier_headers(headers)
7776

7877
if adding_headers is not None:
@@ -101,7 +100,6 @@ def force_headers(self):
101100

102101

103102
class MocketHTTPretty:
104-
105103
Response = Response
106104

107105
def __getattr__(self, name):

mocket/utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
import io
33
import os
44
import ssl
5+
from typing import Tuple, Union
56

67
from .compat import decode_from_bytes, encode_to_bytes
8+
from .exceptions import StrictMocketException
79

810
SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2
911

@@ -47,6 +49,33 @@ def get_mocketize(wrapper_):
4749
class MocketMode:
4850
__shared_state = {}
4951
STRICT = None
52+
STRICT_ALLOWED = None
5053

5154
def __init__(self):
5255
self.__dict__ = self.__shared_state
56+
57+
def is_allowed(self, location: Union[str, Tuple[str, int]]) -> bool:
58+
"""
59+
Checks if (`host`, `port`) or at least `host`
60+
are allowed locationsto perform real `socket` calls
61+
"""
62+
if not self.STRICT:
63+
return True
64+
host, _ = location
65+
return location in self.STRICT_ALLOWED or host in self.STRICT_ALLOWED
66+
67+
@staticmethod
68+
def raise_not_allowed():
69+
from .mocket import Mocket
70+
71+
current_entries = [
72+
(location, "\n ".join(map(str, entries)))
73+
for location, entries in Mocket._entries.items()
74+
]
75+
formatted_entries = "\n".join(
76+
[f" {location}:\n {entries}" for location, entries in current_entries]
77+
)
78+
raise StrictMocketException(
79+
"Mocket tried to use the real `socket` module while STRICT mode was active.\n"
80+
f"Registered entries:\n{formatted_entries}"
81+
)

tests/main/test_mode.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from mocket import Mocketizer, mocketize
55
from mocket.exceptions import StrictMocketException
6+
from mocket.mockhttp import Entry, Response
7+
from mocket.utils import MocketMode
68

79

810
@mocketize(strict_mode=True)
@@ -26,3 +28,44 @@ def test_intermittent_strict_mode():
2628

2729
with Mocketizer(strict_mode=False):
2830
requests.get(url)
31+
32+
33+
@pytest.mark.skipif('os.getenv("SKIP_TRUE_HTTP", False)')
34+
def test_strict_mode_exceptions():
35+
url = "http://httpbin.local/ip"
36+
37+
with Mocketizer(strict_mode=True, strict_mode_allowed=["httpbin.local"]):
38+
requests.get(url)
39+
40+
with Mocketizer(strict_mode=True, strict_mode_allowed=[("httpbin.local", 80)]):
41+
requests.get(url)
42+
43+
44+
def test_strict_mode_error_message():
45+
url = "http://httpbin.local/ip"
46+
47+
Entry.register(Entry.GET, "http://httpbin.local/user.agent", Response(status=404))
48+
49+
with Mocketizer(strict_mode=True):
50+
with pytest.raises(StrictMocketException) as exc_info:
51+
requests.get(url)
52+
assert (
53+
str(exc_info.value)
54+
== """
55+
Mocket tried to use the real `socket` module while STRICT mode was active.
56+
Registered entries:
57+
('httpbin.local', 80):
58+
Entry(method='GET', schema='http', location=('httpbin.local', 80), path='/user.agent', query='')
59+
""".strip()
60+
)
61+
62+
63+
def test_strict_mode_false_with_allowed_hosts():
64+
with pytest.raises(ValueError):
65+
Mocketizer(strict_mode=False, strict_mode_allowed=["foobar.local"])
66+
67+
68+
def test_strict_mode_false_always_allowed():
69+
with Mocketizer(strict_mode=False):
70+
assert MocketMode().is_allowed("foobar.com")
71+
assert MocketMode().is_allowed(("foobar.com", 443))

0 commit comments

Comments
 (0)