Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 20 additions & 67 deletions Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,65 +96,6 @@ def _set_reuseport(sock):
'SO_REUSEPORT defined but not implemented.')


def _ipaddr_info(host, port, family, type, proto, flowinfo=0, scopeid=0):
# Try to skip getaddrinfo if "host" is already an IP. Users might have
# handled name resolution in their own code and pass in resolved IPs.
if not hasattr(socket, 'inet_pton'):
return

if proto not in {0, socket.IPPROTO_TCP, socket.IPPROTO_UDP} or \
host is None:
return None

if type == socket.SOCK_STREAM:
proto = socket.IPPROTO_TCP
elif type == socket.SOCK_DGRAM:
proto = socket.IPPROTO_UDP
else:
return None

if port is None:
port = 0
elif isinstance(port, bytes) and port == b'':
port = 0
elif isinstance(port, str) and port == '':
port = 0
else:
# If port's a service name like "http", don't skip getaddrinfo.
try:
port = int(port)
except (TypeError, ValueError):
return None

if family == socket.AF_UNSPEC:
afs = [socket.AF_INET]
if _HAS_IPv6:
afs.append(socket.AF_INET6)
else:
afs = [family]

if isinstance(host, bytes):
host = host.decode('idna')
if '%' in host:
# Linux's inet_pton doesn't accept an IPv6 zone index after host,
# like '::1%lo0'.
return None

for af in afs:
try:
socket.inet_pton(af, host)
# The host has already been resolved.
if _HAS_IPv6 and af == socket.AF_INET6:
return af, type, proto, '', (host, port, flowinfo, scopeid)
else:
return af, type, proto, '', (host, port)
except OSError:
pass

# "host" is not an IP address.
return None


def _interleave_addrinfos(addrinfos, first_address_family_count=1):
"""Interleave list of addrinfo tuples by family."""
# Group addresses by family
Expand Down Expand Up @@ -861,6 +802,15 @@ async def getaddrinfo(self, host, port, *,
else:
getaddr_func = socket.getaddrinfo

try:
return getaddr_func(
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

host, port, family, type, proto,
flags | socket.AI_NUMERICHOST | socket.AI_NUMERICSERV,
)
except socket.gaierror as e:
if e.errno != socket.EAI_NONAME:
raise

return await self.run_in_executor(
None, getaddr_func, host, port, family, type, proto, flags)

Expand Down Expand Up @@ -1396,14 +1346,17 @@ async def create_datagram_endpoint(self, protocol_factory,
async def _ensure_resolved(self, address, *,
family=0, type=socket.SOCK_STREAM,
proto=0, flags=0, loop):
host, port = address[:2]
info = _ipaddr_info(host, port, family, type, proto, *address[2:])
if info is not None:
# "host" is already a resolved IP.
return [info]
else:
return await loop.getaddrinfo(host, port, family=family, type=type,
proto=proto, flags=flags)
host, port, *rest = address
if not rest:
return await loop.getaddrinfo(
host, port, family=family, type=type, proto=proto, flags=flags,
)
return [
(*first, (host, port, *rest)) for *first, (host, port, *_) in
await loop.getaddrinfo(
host, port, family=family, type=type, proto=proto, flags=flags,
)
]

async def _create_server_getaddrinfo(self, host, port, family, flags):
infos = await self._ensure_resolved((host, port), family=family,
Expand Down
126 changes: 2 additions & 124 deletions Lib/test/test_asyncio/test_base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def mock_socket_module():
m_socket = mock.MagicMock(spec=socket)
for name in (
'AF_INET', 'AF_INET6', 'AF_UNSPEC', 'IPPROTO_TCP', 'IPPROTO_UDP',
'SOCK_STREAM', 'SOCK_DGRAM', 'SOL_SOCKET', 'SO_REUSEADDR', 'inet_pton'
'SOCK_STREAM', 'SOCK_DGRAM', 'SOL_SOCKET', 'SO_REUSEADDR', 'inet_pton',
'gaierror', 'AI_NUMERICHOST', 'AI_NUMERICSERV', 'EAI_NONAME',
):
if hasattr(socket, name):
setattr(m_socket, name, getattr(socket, name))
Expand All @@ -50,103 +51,6 @@ def patch_socket(f):
new_callable=mock_socket_module)(f)


class BaseEventTests(test_utils.TestCase):

def test_ipaddr_info(self):
UNSPEC = socket.AF_UNSPEC
INET = socket.AF_INET
INET6 = socket.AF_INET6
STREAM = socket.SOCK_STREAM
DGRAM = socket.SOCK_DGRAM
TCP = socket.IPPROTO_TCP
UDP = socket.IPPROTO_UDP

self.assertEqual(
(INET, STREAM, TCP, '', ('1.2.3.4', 1)),
base_events._ipaddr_info('1.2.3.4', 1, INET, STREAM, TCP))

self.assertEqual(
(INET, STREAM, TCP, '', ('1.2.3.4', 1)),
base_events._ipaddr_info(b'1.2.3.4', 1, INET, STREAM, TCP))

self.assertEqual(
(INET, STREAM, TCP, '', ('1.2.3.4', 1)),
base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, STREAM, TCP))

self.assertEqual(
(INET, DGRAM, UDP, '', ('1.2.3.4', 1)),
base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, DGRAM, UDP))

# Socket type STREAM implies TCP protocol.
self.assertEqual(
(INET, STREAM, TCP, '', ('1.2.3.4', 1)),
base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, STREAM, 0))

# Socket type DGRAM implies UDP protocol.
self.assertEqual(
(INET, DGRAM, UDP, '', ('1.2.3.4', 1)),
base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, DGRAM, 0))

# No socket type.
self.assertIsNone(
base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, 0, 0))

if socket_helper.IPV6_ENABLED:
# IPv4 address with family IPv6.
self.assertIsNone(
base_events._ipaddr_info('1.2.3.4', 1, INET6, STREAM, TCP))

self.assertEqual(
(INET6, STREAM, TCP, '', ('::3', 1, 0, 0)),
base_events._ipaddr_info('::3', 1, INET6, STREAM, TCP))

self.assertEqual(
(INET6, STREAM, TCP, '', ('::3', 1, 0, 0)),
base_events._ipaddr_info('::3', 1, UNSPEC, STREAM, TCP))

# IPv6 address with family IPv4.
self.assertIsNone(
base_events._ipaddr_info('::3', 1, INET, STREAM, TCP))

# IPv6 address with zone index.
self.assertIsNone(
base_events._ipaddr_info('::3%lo0', 1, INET6, STREAM, TCP))

def test_port_parameter_types(self):
# Test obscure kinds of arguments for "port".
INET = socket.AF_INET
STREAM = socket.SOCK_STREAM
TCP = socket.IPPROTO_TCP

self.assertEqual(
(INET, STREAM, TCP, '', ('1.2.3.4', 0)),
base_events._ipaddr_info('1.2.3.4', None, INET, STREAM, TCP))

self.assertEqual(
(INET, STREAM, TCP, '', ('1.2.3.4', 0)),
base_events._ipaddr_info('1.2.3.4', b'', INET, STREAM, TCP))

self.assertEqual(
(INET, STREAM, TCP, '', ('1.2.3.4', 0)),
base_events._ipaddr_info('1.2.3.4', '', INET, STREAM, TCP))

self.assertEqual(
(INET, STREAM, TCP, '', ('1.2.3.4', 1)),
base_events._ipaddr_info('1.2.3.4', '1', INET, STREAM, TCP))

self.assertEqual(
(INET, STREAM, TCP, '', ('1.2.3.4', 1)),
base_events._ipaddr_info('1.2.3.4', b'1', INET, STREAM, TCP))

@patch_socket
def test_ipaddr_info_no_inet_pton(self, m_socket):
del m_socket.inet_pton
self.assertIsNone(base_events._ipaddr_info('1.2.3.4', 1,
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP))


class BaseEventLoopTests(test_utils.TestCase):

def setUp(self):
Expand Down Expand Up @@ -1827,32 +1731,6 @@ def test_create_datagram_endpoint_nosoreuseport(self, m_socket):

self.assertRaises(ValueError, self.loop.run_until_complete, coro)

@patch_socket
def test_create_datagram_endpoint_ip_addr(self, m_socket):
def getaddrinfo(*args, **kw):
self.fail('should not have called getaddrinfo')

m_socket.getaddrinfo = getaddrinfo
m_socket.socket.return_value.bind = bind = mock.Mock()
self.loop._add_reader = mock.Mock()
self.loop._add_reader._is_coroutine = False

reuseport_supported = hasattr(socket, 'SO_REUSEPORT')
coro = self.loop.create_datagram_endpoint(
lambda: MyDatagramProto(loop=self.loop),
local_addr=('1.2.3.4', 0),
reuse_port=reuseport_supported)

t, p = self.loop.run_until_complete(coro)
try:
bind.assert_called_with(('1.2.3.4', 0))
m_socket.socket.assert_called_with(family=m_socket.AF_INET,
proto=m_socket.IPPROTO_UDP,
type=m_socket.SOCK_DGRAM)
finally:
t.close()
test_utils.run_briefly(self.loop) # allow transport to close

def test_accept_connection_retry(self):
sock = mock.Mock()
sock.accept.side_effect = BlockingIOError()
Expand Down
4 changes: 3 additions & 1 deletion Lib/test/test_asyncio/test_selector_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,9 @@ def test_sock_connect_resolve_using_socket_params(self, m_gai):
con = self.loop.create_task(self.loop.sock_connect(sock, addr))
self.loop.run_until_complete(con)
m_gai.assert_called_with(
addr[0], addr[1], sock.family, sock.type, sock.proto, 0)
addr[0], addr[1], sock.family, sock.type, sock.proto,
mock.ANY,
)

self.loop.run_until_complete(con)
sock.connect.assert_called_with(('127.0.0.1', 0))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
skip getaddrinfo thread if host is already resolved, using socket.AI_NUMERIC