Skip to content

Commit 0447e93

Browse files
GWealecopybara-github
authored andcommitted
fix: Fix SSRF and local-file access in load_web_page
Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 905226545
1 parent 1cdd1e7 commit 0447e93

2 files changed

Lines changed: 540 additions & 3 deletions

File tree

src/google/adk/tools/load_web_page.py

Lines changed: 259 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,260 @@
1616

1717
"""Tool for web browse."""
1818

19+
from dataclasses import dataclass
20+
import ipaddress
21+
import socket
22+
from typing import Any
23+
from urllib.parse import ParseResult
24+
from urllib.parse import urlparse
25+
1926
import requests
27+
from requests.adapters import HTTPAdapter
28+
from requests.utils import get_environ_proxies
29+
from requests.utils import select_proxy
30+
31+
_ALLOWED_URL_SCHEMES = frozenset({'http', 'https'})
32+
_DEFAULT_PORT_BY_SCHEME = {'http': 80, 'https': 443}
33+
_ResolvedAddress = ipaddress.IPv4Address | ipaddress.IPv6Address
34+
35+
36+
@dataclass(frozen=True)
class _RequestTarget:
  """Immutable bundle of the URL pieces the fetch helpers work with."""

  # Result of `urlparse` applied to the URL being fetched.
  parsed_url: ParseResult
  # URL scheme, lower-cased.
  scheme: str
  # Host component of the URL (`ParseResult.hostname`).
  hostname: str
  # Value to send as the HTTP `Host` header.
  host_header: str
42+
43+
44+
class _PinnedAddressAdapter(HTTPAdapter):
  """Routes a request to a vetted IP while preserving the original host.

  Every request sent through this adapter has its URL replaced by
  `rewritten_url` and its `Host` header replaced by `host_header`. For HTTPS
  the original hostname is also handed to the connection pool so the TLS
  settings refer to that name instead of the host in the rewritten URL.
  """

  def __init__(
      self,
      *,
      rewritten_url: str,
      host_header: str,
      hostname: str,
  ) -> None:
    """Stores the pinned destination and the identity presented to it.

    Args:
      rewritten_url: URL that replaces the URL of every request sent through
        this adapter.
      host_header: Value written to the `Host` header of every request.
      hostname: Name passed to the connection pool as `assert_hostname` and
        `server_hostname` for HTTPS requests.
    """
    super().__init__()
    self._rewritten_url = rewritten_url
    self._host_header = host_header
    self._hostname = hostname

  def build_connection_pool_key_attributes(
      self,
      request: requests.PreparedRequest,
      verify: bool | str,
      cert: tuple[str, str] | str | None = None,
  ) -> tuple[dict[str, Any], dict[str, Any]]:
    """Extends the parent's pool settings with the original hostname.

    Args:
      request: Request the connection pool is being selected for.
      verify: Forwarded unchanged to the parent implementation.
      cert: Forwarded unchanged to the parent implementation.

    Returns:
      The `(host_params, pool_kwargs)` pair built by the parent class. For
      HTTPS, `pool_kwargs` additionally carries `assert_hostname` and
      `server_hostname` set to the hostname given at construction.
    """
    host_params, pool_kwargs = super().build_connection_pool_key_attributes(
        request, verify, cert
    )
    if host_params['scheme'] == 'https':
      # `send` rewrites the request URL, so the name to use for TLS is
      # supplied explicitly rather than left to be derived from the URL.
      pool_kwargs['assert_hostname'] = self._hostname
      pool_kwargs['server_hostname'] = self._hostname
    return host_params, pool_kwargs

  def send(
      self,
      request: requests.PreparedRequest,
      stream: bool = False,
      timeout: Any = None,
      verify: bool | str = True,
      cert: tuple[str, str] | str | None = None,
      proxies: dict[str, str | None] | None = None,
  ) -> requests.Response:
    """Sends a copy of `request` aimed at the pinned URL.

    The caller's request object is left untouched; a copy receives the stored
    `Host` header and rewritten URL before being handed to the parent class.
    All remaining arguments are forwarded unchanged.

    Returns:
      The response produced by `HTTPAdapter.send`.
    """
    prepared_request = request.copy()
    prepared_request.headers['Host'] = self._host_header
    prepared_request.url = self._rewritten_url
    return super().send(
        prepared_request,
        stream=stream,
        timeout=timeout,
        verify=verify,
        cert=cert,
        proxies=proxies,
    )
93+
94+
95+
def _failed_to_fetch_message(url: str) -> str:
  """Returns the text reported when `url` could not be fetched."""
  return f'Failed to fetch url: {url}'
97+
98+
99+
def _format_host(hostname: str) -> str:
  """Returns `hostname` in the form used inside a URL netloc.

  A host containing a colon is wrapped in square brackets; any other host is
  returned unchanged.
  """
  if ':' in hostname:
    return f'[{hostname}]'
  return hostname
103+
104+
105+
def _default_port_for_scheme(scheme: str) -> int:
  """Returns the default port registered for `scheme`.

  Raises:
    KeyError: If `scheme` has no entry in `_DEFAULT_PORT_BY_SCHEME`.
  """
  return _DEFAULT_PORT_BY_SCHEME[scheme]
107+
108+
109+
def _build_host_header(
    *, hostname: str, scheme: str, explicit_port: int | None
) -> str:
  """Builds the `Host` header value for a request.

  Args:
    hostname: Host component of the URL.
    scheme: URL scheme, used to look up the default port.
    explicit_port: Port written in the URL, or None when the URL has none.

  Returns:
    The formatted host alone when no port was given or the port equals the
    scheme's default; otherwise the formatted host followed by `:<port>`.
  """
  formatted_hostname = _format_host(hostname)
  if explicit_port is None or explicit_port == _default_port_for_scheme(scheme):
    return formatted_hostname
  return f'{formatted_hostname}:{explicit_port}'
116+
117+
118+
def _parse_request_target(url: str) -> _RequestTarget:
  """Parses `url` and validates the parts needed to issue a request.

  Args:
    url: URL supplied by the caller.

  Returns:
    A `_RequestTarget` holding the parsed URL, its lower-cased scheme, its
    hostname and the matching `Host` header value.

  Raises:
    ValueError: If the scheme is not in `_ALLOWED_URL_SCHEMES`, the URL has
      no hostname, or the port is invalid.
  """
  parsed_url = urlparse(url)
  scheme = parsed_url.scheme.lower()
  if scheme not in _ALLOWED_URL_SCHEMES:
    raise ValueError(f'Unsupported url scheme: {url}')

  hostname = parsed_url.hostname
  if not hostname:
    raise ValueError(f'URL is missing a hostname: {url}')

  # `ParseResult.port` validates lazily and raises ValueError on a bad port;
  # re-raise with a message that names the offending URL.
  try:
    explicit_port = parsed_url.port
  except ValueError as exc:
    raise ValueError(f'Invalid url port: {url}') from exc

  return _RequestTarget(
      parsed_url=parsed_url,
      scheme=scheme,
      hostname=hostname,
      host_header=_build_host_header(
          hostname=hostname,
          scheme=scheme,
          explicit_port=explicit_port,
      ),
  )
143+
144+
145+
def _parse_ip_literal(hostname: str) -> _ResolvedAddress | None:
  """Returns `hostname` as an IP address object, or None if it is not one."""
  try:
    return ipaddress.ip_address(hostname)
  except ValueError:
    # Not an IPv4/IPv6 literal; the caller treats it as a DNS name.
    return None
150+
151+
152+
def _is_blocked_hostname(hostname: str) -> bool:
  """Returns True for `localhost` and any name under `.localhost`.

  The comparison ignores case and trailing dots.
  """
  normalized_hostname = hostname.rstrip('.').lower()
  return normalized_hostname == 'localhost' or normalized_hostname.endswith(
      '.localhost'
  )
157+
158+
159+
def _is_blocked_address(address: _ResolvedAddress) -> bool:
  """Returns True when `address` must not be contacted.

  An address is blocked unless `ipaddress` reports it as globally reachable.
  Multicast addresses are blocked as well: the multicast ranges are not part
  of the special-purpose registries that `is_global` is based on, and they
  are never a valid destination for an HTTP request.
  """
  return not address.is_global or address.is_multicast
161+
162+
163+
def _resolve_host_addresses(hostname: str) -> tuple[_ResolvedAddress, ...]:
  """Returns the IPv4/IPv6 addresses that `hostname` refers to.

  An IP literal is returned as a one-element tuple without any lookup. Any
  other name is resolved with `socket.getaddrinfo`; results are returned in
  the order reported and may contain duplicates.

  Raises:
    ValueError: If the lookup fails or yields no IPv4/IPv6 address.
  """
  literal_address = _parse_ip_literal(hostname)
  if literal_address is not None:
    return (literal_address,)

  try:
    address_info = socket.getaddrinfo(
        hostname,
        None,
        type=socket.SOCK_STREAM,
        proto=socket.IPPROTO_TCP,
    )
  except (socket.gaierror, UnicodeError) as exc:
    raise ValueError(f'Unable to resolve host: {hostname}') from exc

  # Keep only IPv4/IPv6 results; the first sockaddr element is the address.
  resolved_addresses = tuple(
      ipaddress.ip_address(sockaddr[0])
      for family, _, _, _, sockaddr in address_info
      if family in (socket.AF_INET, socket.AF_INET6)
  )
  if not resolved_addresses:
    raise ValueError(f'Unable to resolve host: {hostname}')

  return resolved_addresses
189+
190+
191+
def _get_proxy_url(url: str) -> str | None:
  """Returns the environment-configured proxy for `url`, or None.

  The proxy mapping comes from `requests.utils.get_environ_proxies` and the
  entry matching `url` is picked with `requests.utils.select_proxy`.
  """
  proxies = get_environ_proxies(url)
  return select_proxy(url, proxies)
194+
195+
196+
def _resolve_direct_addresses(hostname: str) -> tuple[_ResolvedAddress, ...]:
  """Resolves `hostname` and rejects it if any address is blocked.

  Returns:
    The resolved addresses with duplicates removed, in resolution order.

  Raises:
    ValueError: If resolution fails or at least one resolved address is
      blocked.
  """
  # `dict.fromkeys` drops duplicates while keeping first-seen order.
  resolved_addresses = tuple(dict.fromkeys(_resolve_host_addresses(hostname)))
  # A single blocked address rejects the whole host, even if others are fine.
  if any(_is_blocked_address(address) for address in resolved_addresses):
    raise ValueError(f'Blocked host: {hostname}')
  return resolved_addresses
201+
202+
203+
def _rewrite_url_host(parsed_url: ParseResult, hostname: str) -> str:
  """Returns `parsed_url` as a string with its host replaced by `hostname`.

  The netloc is rebuilt from `hostname` and the explicit port only, so any
  userinfo present in the original netloc is not carried over. Scheme, path,
  params, query and fragment are kept.

  Args:
    parsed_url: Parsed form of the original URL.
    hostname: Replacement host, e.g. the textual form of an IP address.
  """
  explicit_port = parsed_url.port
  formatted_hostname = _format_host(hostname)
  if explicit_port is None:
    rewritten_netloc = formatted_hostname
  else:
    rewritten_netloc = f'{formatted_hostname}:{explicit_port}'
  return parsed_url._replace(netloc=rewritten_netloc).geturl()
211+
212+
213+
def _fetch_direct_response(
    *,
    url: str,
    target: _RequestTarget,
    resolved_addresses: tuple[_ResolvedAddress, ...],
) -> requests.Response:
  """Fetches `url` without a proxy, trying each vetted address in turn.

  Each attempt uses a fresh session whose adapter pins the connection to one
  address from `resolved_addresses`. The first response obtained is returned;
  a `requests.RequestException` moves on to the next address.

  Args:
    url: Original URL, passed to `session.get` unchanged.
    target: Parsed form of `url`.
    resolved_addresses: Addresses to try, in order.

  Returns:
    The response from the first address that answers. Redirects are not
    followed.

  Raises:
    requests.RequestException: The error from the last address tried, or a
      generic error when `resolved_addresses` is empty.
  """
  last_error: requests.RequestException | None = None
  for address in resolved_addresses:
    session = requests.Session()
    adapter = _PinnedAddressAdapter(
        rewritten_url=_rewrite_url_host(target.parsed_url, str(address)),
        host_header=target.host_header,
        hostname=target.hostname,
    )
    # Mounting on the scheme prefix makes this adapter handle the request.
    session.mount(f'{target.scheme}://', adapter)
    try:
      # NOTE(review): no timeout is passed, so a stalled server can block
      # this call indefinitely.
      return session.get(
          url,
          allow_redirects=False,
          # Explicit None entries keep this request off any proxy.
          proxies={'http': None, 'https': None},
      )
    except requests.RequestException as exc:
      last_error = exc
    finally:
      session.close()

  if last_error is not None:
    raise last_error
  # Only reached when there were no addresses to try.
  raise requests.RequestException(f'Unable to fetch url: {url}')
242+
243+
244+
def _fetch_response(url: str) -> requests.Response:
  """Fetches `url` after applying the SSRF checks.

  Localhost-style names and blocked IP literals are always rejected. When a
  proxy is configured for the URL the request then goes through it as-is;
  otherwise the host is resolved locally, every resolved address is vetted,
  and the request is pinned to those addresses.

  Args:
    url: URL supplied by the caller.

  Returns:
    The HTTP response. Redirects are not followed.

  Raises:
    ValueError: If the URL is malformed, uses an unsupported scheme, cannot
      be resolved, or points at a blocked host.
    requests.RequestException: If the HTTP request itself fails.
  """
  target = _parse_request_target(url)

  if _is_blocked_hostname(target.hostname):
    raise ValueError(f'Blocked host: {target.hostname}')

  parsed_ip_literal = _parse_ip_literal(target.hostname)
  if _get_proxy_url(url):
    # Proxies resolve the target hostname remotely, so only literal IPs and
    # localhost-style names can be rejected locally without breaking proxy use.
    if parsed_ip_literal is not None and _is_blocked_address(parsed_ip_literal):
      raise ValueError(f'Blocked host: {target.hostname}')
    return requests.get(url, allow_redirects=False)

  # Direct connection to an IP literal: vet it and skip name resolution.
  if parsed_ip_literal is not None:
    if _is_blocked_address(parsed_ip_literal):
      raise ValueError(f'Blocked host: {target.hostname}')
    return _fetch_direct_response(
        url=url,
        target=target,
        resolved_addresses=(parsed_ip_literal,),
    )

  # Direct connection to a DNS name: resolve once, vet every address, and
  # connect only to the addresses that were vetted.
  resolved_addresses = _resolve_direct_addresses(target.hostname)
  return _fetch_direct_response(
      url=url,
      target=target,
      resolved_addresses=resolved_addresses,
  )
20273

21274

22275
def load_web_page(url: str) -> str:
@@ -30,14 +283,17 @@ def load_web_page(url: str) -> str:
30283
"""
31284
from bs4 import BeautifulSoup
32285

33-
# Set allow_redirects=False to prevent SSRF attacks via redirection.
34-
response = requests.get(url, allow_redirects=False)
286+
try:
287+
response = _fetch_response(url)
288+
except ValueError:
289+
return _failed_to_fetch_message(url)
35290

291+
# Redirects are disabled inside _fetch_response to prevent SSRF attacks via redirection.
36292
if response.status_code == 200:
37293
soup = BeautifulSoup(response.content, 'lxml')
38294
text = soup.get_text(separator='\n', strip=True)
39295
else:
40-
text = f'Failed to fetch url: {url}'
296+
text = _failed_to_fetch_message(url)
41297

42298
# Split the text into lines, filtering out very short lines
43299
# (e.g., single words or short subtitles)

0 commit comments

Comments
 (0)