Skip to content

Commit fdbfb0a

Browse files
authored
Add :: Lock based thread-safe adapter (#20)
* Add thread-safe (`Lock` based) HTTP adapters for IPv4 and IPv6 support * add Lock adapters * add Lock adapters * add comprehensive test suite for Lock-based adapter * update test command * `CHANGELOG.md` updated * `README.md` updated * add docstring Added parameter documentation to filtered_getaddrinfo function.
1 parent 33191ed commit fdbfb0a

6 files changed

Lines changed: 337 additions & 3 deletions

File tree

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ jobs:
4242
pip install --upgrade --upgrade-strategy=only-if-needed -r test-requirements.txt
4343
- name: Test with pytest
4444
run: |
45-
python -m pytest tests/test_adapters.py --cov=ipforce --cov-report=term
45+
python -m pytest tests/ --cov=ipforce --cov-report=term --ignore=tests/test_ipv4.py --ignore=tests/test_ipv6.py
4646
- name: Upload coverage to Codecov
4747
uses: codecov/codecov-action@v4
4848
with:

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
66

77
## [Unreleased]
88
### Added
9+
- `IPv6LockAdapter` class
10+
- `IPv4LockAdapter` class
911
- Logo
1012
### Changed
1113
- `README.md` updated

README.md

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,22 @@ response = session.get('https://ifconfig.co/json')
9595
```
9696

9797
> [!WARNING]
98-
> Current adapters are NOT thread-safe! They modify the global `socket.getaddrinfo` function, which can cause issues in multi-threaded applications.
98+
> `IPv4TransportAdapter` / `IPv6TransportAdapter` are NOT thread-safe. They modify the global `socket.getaddrinfo` function, which can cause race conditions in multi-threaded applications. Use the thread-safe adapters below for concurrent usage.
99+
100+
### Thread-Safe: Lock-Based Adapters
101+
102+
A process-wide lock serializes access to `socket.getaddrinfo`, guaranteeing correctness under concurrent access.
103+
104+
```python
105+
import requests
106+
from ipforce import IPv4LockAdapter, IPv6LockAdapter
107+
108+
session = requests.Session()
109+
session.mount('http://', IPv4LockAdapter()) # or IPv6LockAdapter()
110+
session.mount('https://', IPv4LockAdapter()) # or IPv6LockAdapter()
111+
112+
response = session.get('https://ifconfig.co/json')
113+
```
99114

100115
## Issues & Bug Reports
101116

ipforce/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
"""ipforce modules."""
33
from .params import IPFORCE_VERSION
44
from .adapters import IPv4TransportAdapter, IPv6TransportAdapter
5+
from .adapters import IPv4LockAdapter, IPv6LockAdapter
56

67
__version__ = IPFORCE_VERSION
78

8-
__all__ = ["IPv4TransportAdapter", "IPv6TransportAdapter"]
9+
__all__ = [
10+
"IPv4TransportAdapter", "IPv6TransportAdapter",
11+
"IPv4LockAdapter", "IPv6LockAdapter",
12+
]

ipforce/adapters.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33
import socket
44
from typing import Any, List, Tuple
55
from requests.adapters import HTTPAdapter
6+
from threading import Lock
7+
8+
# ============================================================================
9+
# Base adapter (not thread-safe)
10+
# ============================================================================
611

712

813
class IPv4TransportAdapter(HTTPAdapter):
@@ -63,3 +68,70 @@ def ipv6_only_getaddrinfo(*gargs: list, **gkwargs: dict) -> List[Tuple]:
6368
finally:
6469
socket.getaddrinfo = original_getaddrinfo
6570
return response
71+
72+
73+
# ============================================================================
74+
# Lock-based thread-safe adapters
75+
#
76+
# A process-wide lock serializes access to the global socket.getaddrinfo
77+
# patch. Correct under all conditions, but serializes DNS resolution
78+
# across threads.
79+
# ============================================================================
80+
81+
_adapter_lock = Lock()
82+
83+
84+
class _BaseLockAdapter(HTTPAdapter):
85+
"""Base class for lock-based thread-safe adapters."""
86+
87+
_family = socket.AF_UNSPEC
88+
89+
def send(self, *args: list, **kwargs: dict) -> Any:
90+
"""
91+
Thread-safe send that acquires a lock before patching getaddrinfo.
92+
93+
:param args: additional list arguments for the send method
94+
:param kwargs: additional keyword arguments for the send method
95+
"""
96+
with _adapter_lock:
97+
original_getaddrinfo = socket.getaddrinfo
98+
family = self._family
99+
100+
def filtered_getaddrinfo(*gargs: list, **gkwargs: dict) -> List[Tuple]:
101+
"""Filter getaddrinfo results to the target address family.
102+
103+
:param gargs: additional list arguments for the original_getaddrinfo function
104+
:param gkwargs: additional keyword arguments for the original_getaddrinfo function
105+
"""
106+
results = original_getaddrinfo(*gargs, **gkwargs)
107+
return [r for r in results if r[0] == family]
108+
109+
socket.getaddrinfo = filtered_getaddrinfo
110+
try:
111+
return super().send(*args, **kwargs)
112+
finally:
113+
socket.getaddrinfo = original_getaddrinfo
114+
115+
116+
class IPv4LockAdapter(_BaseLockAdapter):
117+
"""Thread-safe HTTPAdapter that enforces IPv4 using a global lock.
118+
119+
All requests across all threads are serialized through a single lock,
120+
ensuring no race conditions on socket.getaddrinfo. Best suited for
121+
low-concurrency use cases where simplicity is preferred.
122+
"""
123+
124+
_family = socket.AF_INET
125+
126+
127+
class IPv6LockAdapter(_BaseLockAdapter):
128+
"""Thread-safe HTTPAdapter that enforces IPv6 using a global lock.
129+
130+
All requests across all threads are serialized through a single lock,
131+
ensuring no race conditions on socket.getaddrinfo. Best suited for
132+
low-concurrency use cases where simplicity is preferred.
133+
"""
134+
135+
_family = socket.AF_INET6
136+
137+

tests/test_lock_adapters.py

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
"""Unit and concurrency tests for Lock-based thread-safe adapters."""
2+
import contextlib
3+
import socket
4+
import threading
5+
import unittest
6+
from concurrent.futures import ThreadPoolExecutor, as_completed
7+
from unittest.mock import patch, MagicMock
8+
9+
from requests.adapters import HTTPAdapter
10+
11+
from ipforce.adapters import IPv4LockAdapter, IPv6LockAdapter, _adapter_lock
12+
13+
MIXED_ADDR_RESULTS = [
14+
(socket.AF_INET, socket.SOCK_STREAM, 6, '', ('192.168.1.1', 80)),
15+
(socket.AF_INET6, socket.SOCK_STREAM, 6, '', ('::1', 80)),
16+
(socket.AF_INET, socket.SOCK_STREAM, 6, '', ('10.0.0.1', 80)),
17+
(socket.AF_INET6, socket.SOCK_STREAM, 6, '', ('2001:db8::1', 80)),
18+
]
19+
20+
NUM_THREADS = 8
21+
SENDS_PER_THREAD = 20
22+
23+
24+
# ============================================================================
25+
# Unit tests
26+
# ============================================================================
27+
28+
29+
class TestIPv4LockAdapter(unittest.TestCase):
30+
"""Test cases for IPv4LockAdapter."""
31+
32+
def setUp(self):
33+
"""Set up test fixtures."""
34+
self.adapter = IPv4LockAdapter()
35+
36+
def test_ipv4_filtering_during_send(self):
37+
"""Test that IPv4LockAdapter filters only IPv4 addresses during send."""
38+
captured = []
39+
40+
def mock_super_send(*args, **kwargs):
41+
captured.extend(socket.getaddrinfo('example.com', 80))
42+
return MagicMock()
43+
44+
with patch('socket.getaddrinfo', return_value=MIXED_ADDR_RESULTS):
45+
with patch.object(HTTPAdapter, 'send', mock_super_send):
46+
self.adapter.send(MagicMock())
47+
48+
self.assertEqual(len(captured), 2)
49+
for result in captured:
50+
self.assertEqual(result[0], socket.AF_INET)
51+
52+
def test_cleanup_after_send(self):
53+
"""Test that getaddrinfo is restored after send."""
54+
original = socket.getaddrinfo
55+
56+
with patch.object(HTTPAdapter, 'send', return_value=MagicMock()):
57+
self.adapter.send(MagicMock())
58+
59+
self.assertEqual(socket.getaddrinfo, original)
60+
61+
def test_cleanup_on_exception(self):
62+
"""Test that getaddrinfo is restored even when send raises."""
63+
original = socket.getaddrinfo
64+
65+
with patch.object(HTTPAdapter, 'send', side_effect=Exception("error")):
66+
with self.assertRaises(Exception):
67+
self.adapter.send(MagicMock())
68+
69+
self.assertEqual(socket.getaddrinfo, original)
70+
71+
def test_lock_is_acquired_during_send(self):
72+
"""Test that the adapter lock is held during the send call."""
73+
lock_was_held = []
74+
75+
def mock_super_send(*args, **kwargs):
76+
lock_was_held.append(_adapter_lock.locked())
77+
return MagicMock()
78+
79+
with patch.object(HTTPAdapter, 'send', mock_super_send):
80+
self.adapter.send(MagicMock())
81+
82+
self.assertTrue(lock_was_held[0])
83+
84+
85+
class TestIPv6LockAdapter(unittest.TestCase):
86+
"""Test cases for IPv6LockAdapter."""
87+
88+
def setUp(self):
89+
"""Set up test fixtures."""
90+
self.adapter = IPv6LockAdapter()
91+
92+
def test_ipv6_filtering_during_send(self):
93+
"""Test that IPv6LockAdapter filters only IPv6 addresses during send."""
94+
captured = []
95+
96+
def mock_super_send(*args, **kwargs):
97+
captured.extend(socket.getaddrinfo('example.com', 80))
98+
return MagicMock()
99+
100+
with patch('socket.getaddrinfo', return_value=MIXED_ADDR_RESULTS):
101+
with patch.object(HTTPAdapter, 'send', mock_super_send):
102+
self.adapter.send(MagicMock())
103+
104+
self.assertEqual(len(captured), 2)
105+
for result in captured:
106+
self.assertEqual(result[0], socket.AF_INET6)
107+
108+
def test_cleanup_after_send(self):
109+
"""Test that getaddrinfo is restored after send."""
110+
original = socket.getaddrinfo
111+
112+
with patch.object(HTTPAdapter, 'send', return_value=MagicMock()):
113+
self.adapter.send(MagicMock())
114+
115+
self.assertEqual(socket.getaddrinfo, original)
116+
117+
118+
# ============================================================================
119+
# Concurrency tests
120+
# ============================================================================
121+
122+
123+
def _run_concurrent_lock_test(adapter, expected_family):
124+
"""Run a barrier-synchronised concurrency test for a lock adapter."""
125+
barrier = threading.Barrier(NUM_THREADS)
126+
lock = threading.Lock()
127+
results = []
128+
errors = []
129+
130+
mock_gai = MagicMock(return_value=MIXED_ADDR_RESULTS)
131+
132+
def mock_super_send(*args, **kwargs):
133+
captured = list(socket.getaddrinfo('example.com', 80))
134+
for r in captured:
135+
if r[0] != expected_family:
136+
with lock:
137+
errors.append(
138+
"Expected family {exp}, got {got}".format(exp=expected_family, got=r[0]),
139+
)
140+
with lock:
141+
results.append(len(captured))
142+
return MagicMock()
143+
144+
def worker(_idx):
145+
barrier.wait()
146+
for _ in range(SENDS_PER_THREAD):
147+
adapter.send(MagicMock())
148+
149+
with patch('socket.getaddrinfo', mock_gai):
150+
with patch.object(HTTPAdapter, 'send', mock_super_send):
151+
with ThreadPoolExecutor(max_workers=NUM_THREADS) as pool:
152+
futures = [pool.submit(worker, i) for i in range(NUM_THREADS)]
153+
for f in as_completed(futures):
154+
f.result()
155+
156+
return results, errors
157+
158+
159+
class TestLockAdapterConcurrency(unittest.TestCase):
160+
"""Verify IPv4LockAdapter / IPv6LockAdapter under thread contention."""
161+
162+
def test_concurrent_ipv4_sends(self):
163+
"""Multiple threads using IPv4LockAdapter simultaneously."""
164+
results, errors = _run_concurrent_lock_test(
165+
IPv4LockAdapter(), socket.AF_INET,
166+
)
167+
self.assertEqual(errors, [])
168+
self.assertEqual(len(results), NUM_THREADS * SENDS_PER_THREAD)
169+
170+
def test_concurrent_ipv6_sends(self):
171+
"""Multiple threads using IPv6LockAdapter simultaneously."""
172+
results, errors = _run_concurrent_lock_test(
173+
IPv6LockAdapter(), socket.AF_INET6,
174+
)
175+
self.assertEqual(errors, [])
176+
177+
def test_getaddrinfo_restored_after_concurrent_sends(self):
178+
"""Verify socket.getaddrinfo is pristine after concurrent lock-adapter sends."""
179+
original = socket.getaddrinfo
180+
adapter = IPv4LockAdapter()
181+
barrier = threading.Barrier(NUM_THREADS)
182+
183+
def worker(_idx):
184+
barrier.wait()
185+
for _ in range(SENDS_PER_THREAD):
186+
adapter.send(MagicMock())
187+
188+
with patch.object(HTTPAdapter, 'send', return_value=MagicMock()):
189+
with ThreadPoolExecutor(max_workers=NUM_THREADS) as pool:
190+
futures = [pool.submit(worker, i) for i in range(NUM_THREADS)]
191+
for f in as_completed(futures):
192+
f.result()
193+
194+
self.assertIs(socket.getaddrinfo, original)
195+
196+
def test_mixed_ipv4_ipv6_lock_adapters(self):
197+
"""IPv4 and IPv6 lock adapters running concurrently filter correctly."""
198+
lock4 = IPv4LockAdapter()
199+
lock6 = IPv6LockAdapter()
200+
barrier = threading.Barrier(NUM_THREADS)
201+
data_lock = threading.Lock()
202+
errors = []
203+
completed = []
204+
205+
mock_gai = MagicMock(return_value=MIXED_ADDR_RESULTS)
206+
207+
def mock_super_send(*args, **kwargs):
208+
results = socket.getaddrinfo('example.com', 80)
209+
families = set(r[0] for r in results)
210+
if len(families) > 1:
211+
with data_lock:
212+
errors.append("Mixed families in single send: {f}".format(f=families))
213+
return MagicMock()
214+
215+
def v4_worker():
216+
barrier.wait()
217+
for _ in range(SENDS_PER_THREAD):
218+
lock4.send(MagicMock())
219+
with data_lock:
220+
completed.append('v4')
221+
222+
def v6_worker():
223+
barrier.wait()
224+
for _ in range(SENDS_PER_THREAD):
225+
lock6.send(MagicMock())
226+
with data_lock:
227+
completed.append('v6')
228+
229+
with patch('socket.getaddrinfo', mock_gai):
230+
with patch.object(HTTPAdapter, 'send', mock_super_send):
231+
with ThreadPoolExecutor(max_workers=NUM_THREADS) as pool:
232+
half = NUM_THREADS // 2
233+
futures = (
234+
[pool.submit(v4_worker) for _ in range(half)] +
235+
[pool.submit(v6_worker) for _ in range(NUM_THREADS - half)]
236+
)
237+
for f in as_completed(futures):
238+
f.result()
239+
240+
self.assertEqual(errors, [])
241+
self.assertEqual(len(completed), NUM_THREADS)

0 commit comments

Comments
 (0)