Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/google/adk/tools/load_web_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

_ALLOWED_URL_SCHEMES = frozenset({'http', 'https'})
_DEFAULT_PORT_BY_SCHEME = {'http': 80, 'https': 443}
# Default timeout in seconds for HTTP requests.
_DEFAULT_TIMEOUT_SECONDS = 10
_ResolvedAddress = ipaddress.IPv4Address | ipaddress.IPv6Address


Expand Down Expand Up @@ -230,6 +232,7 @@ def _fetch_direct_response(
url,
allow_redirects=False,
proxies={'http': None, 'https': None},
timeout=_DEFAULT_TIMEOUT_SECONDS,
)
except requests.RequestException as exc:
last_error = exc
Expand All @@ -253,7 +256,9 @@ def _fetch_response(url: str) -> requests.Response:
# localhost-style names can be rejected locally without breaking proxy use.
if parsed_ip_literal is not None and _is_blocked_address(parsed_ip_literal):
raise ValueError(f'Blocked host: {target.hostname}')
return requests.get(url, allow_redirects=False)
return requests.get(
url, allow_redirects=False, timeout=_DEFAULT_TIMEOUT_SECONDS
)

if parsed_ip_literal is not None:
if _is_blocked_address(parsed_ip_literal):
Expand Down Expand Up @@ -285,7 +290,7 @@ def load_web_page(url: str) -> str:

try:
response = _fetch_response(url)
except ValueError:
except (ValueError, requests.RequestException):
return _failed_to_fetch_message(url)

# Set allow_redirects=False to prevent SSRF attacks via redirection.
Expand Down
122 changes: 121 additions & 1 deletion tests/unittests/tools/test_load_web_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,9 @@ def test_load_web_page_uses_proxy_for_unresolved_public_hostnames(monkeypatch):

assert result == 'This page has enough words to keep.'
mock_get.assert_called_once_with(
'https://does-not-resolve.invalid', allow_redirects=False
'https://does-not-resolve.invalid',
allow_redirects=False,
timeout=load_web_page_module._DEFAULT_TIMEOUT_SECONDS,
)
mock_send.assert_not_called()

Expand Down Expand Up @@ -279,3 +281,121 @@ def _send(
'https://93.184.216.35',
]
mock_get.assert_not_called()


def test_load_web_page_passes_timeout_to_pinned_session(monkeypatch):
  """The pinned-connection path forwards the default timeout to the adapter."""
  _clear_proxy_env(monkeypatch)
  # Resolve the hostname to a fixed public address so the pinned path is used.
  addrinfo = [(
      socket.AF_INET,
      socket.SOCK_STREAM,
      socket.IPPROTO_TCP,
      '',
      ('93.184.216.34', 0),
  )]
  monkeypatch.setattr(
      load_web_page_module.socket,
      'getaddrinfo',
      mock.Mock(return_value=addrinfo),
  )
  soup = mock.Mock(
      get_text=mock.Mock(return_value='This page has enough words to keep.')
  )
  monkeypatch.setattr('bs4.BeautifulSoup', mock.Mock(return_value=soup))

  seen_timeouts: list[object] = []

  def _fake_send(
      self,
      request,
      stream=False,
      timeout=None,
      verify=True,
      cert=None,
      proxies=None,
  ):
    # Record only the timeout; every other argument is irrelevant here.
    del self, request, stream, verify, cert, proxies
    seen_timeouts.append(timeout)
    return _create_response(
        '<html><body><p>This page has enough words to keep.</p></body></html>'
    )

  monkeypatch.setattr(load_web_page_module.HTTPAdapter, 'send', _fake_send)

  load_web_page('https://example.com')

  assert seen_timeouts == [load_web_page_module._DEFAULT_TIMEOUT_SECONDS]


def test_load_web_page_passes_timeout_to_proxied_get(monkeypatch):
  """The proxy fallback path passes the default timeout to requests.get."""
  monkeypatch.setenv('HTTPS_PROXY', 'http://proxy.example.test:8080')
  monkeypatch.setenv('NO_PROXY', '')
  # Any local DNS lookup would mean the proxy path was not taken.
  monkeypatch.setattr(
      load_web_page_module.socket,
      'getaddrinfo',
      mock.Mock(side_effect=AssertionError('unexpected local DNS lookup')),
  )
  soup = mock.Mock(
      get_text=mock.Mock(return_value='This page has enough words to keep.')
  )
  monkeypatch.setattr('bs4.BeautifulSoup', mock.Mock(return_value=soup))
  canned_response = _create_response(
      '<html><body><p>This page has enough words to keep.</p></body></html>'
  )
  fake_get = mock.Mock(return_value=canned_response)
  monkeypatch.setattr(load_web_page_module.requests, 'get', fake_get)

  load_web_page('https://does-not-resolve.invalid')

  fake_get.assert_called_once_with(
      'https://does-not-resolve.invalid',
      allow_redirects=False,
      timeout=load_web_page_module._DEFAULT_TIMEOUT_SECONDS,
  )


def test_load_web_page_returns_failure_on_timeout(monkeypatch):
  """A request timeout surfaces as the generic fetch-failure message."""
  _clear_proxy_env(monkeypatch)
  # Resolve to a fixed public address so the pinned-connection path runs.
  addrinfo = [(
      socket.AF_INET,
      socket.SOCK_STREAM,
      socket.IPPROTO_TCP,
      '',
      ('93.184.216.34', 0),
  )]
  monkeypatch.setattr(
      load_web_page_module.socket,
      'getaddrinfo',
      mock.Mock(return_value=addrinfo),
  )

  def _timeout_send(
      self,
      request,
      stream=False,
      timeout=None,
      verify=True,
      cert=None,
      proxies=None,
  ):
    del self, request, stream, timeout, verify, cert, proxies
    raise requests.exceptions.Timeout('boom')

  monkeypatch.setattr(load_web_page_module.HTTPAdapter, 'send', _timeout_send)

  result = load_web_page('https://example.com')

  assert result == 'Failed to fetch url: https://example.com'