Skip to content

Commit c9043db

Browse files
committed
feat: add SafeHTTPAdapter for secure HTTP requests and enhance URL validation
1 parent adc87e3 commit c9043db

File tree

3 files changed

+170
-18
lines changed

3 files changed

+170
-18
lines changed

apps/knowledge/api/file.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,20 @@ def get_parameters():
5252
@staticmethod
5353
def get_response():
5454
return DefaultResultSerializer
55+
56+
class GetUrlContentAPI(APIMixin):
    """OpenAPI schema description for the "get URL content" endpoint."""

    @staticmethod
    def get_parameters():
        """Return the endpoint's parameters: one required ``url`` query string."""
        url_parameter = OpenApiParameter(
            name="url",
            description="文件url",
            type=OpenApiTypes.STR,
            location='query',
            required=True,
        )
        return [url_parameter]

    @staticmethod
    def get_response():
        """Return the serializer class that describes the response body."""
        return DefaultResultSerializer

apps/oss/serializers/file.py

Lines changed: 147 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import re
55
import socket
66
import urllib
7-
from urllib.parse import urlparse
7+
from urllib.parse import urlparse, urlunparse
88

99
import requests
1010
import uuid_utils.compat as uuid
@@ -166,6 +166,56 @@ def delete(self):
166166
return True
167167

168168

169+
from requests.adapters import HTTPAdapter
170+
171+
172+
class SafeHTTPAdapter(HTTPAdapter):
    """
    HTTP adapter that vets the request's target host before sending.

    Every address the hostname resolves to is checked, and the request is
    refused if any of them is private, loopback, reserved, link-local or
    multicast.

    NOTE(review): the lookup performed here is separate from whatever lookup
    the parent adapter performs when it actually connects, so the checked
    address is not pinned -- confirm whether that gap matters for the
    DNS-rebinding defence this adapter is meant to provide.
    """

    def send(self, request, **kwargs):
        """Validate the host in ``request.url``, then delegate to the parent ``send``."""
        target_host = urlparse(request.url).hostname
        if target_host:
            # Raises AppApiException when the host resolves to an unsafe address.
            self._validate_host_ip(target_host)
        return super().send(request, **kwargs)

    def _validate_host_ip(self, host: str):
        """
        Resolve *host* and reject it if any resolved address is unsafe.

        Raises:
            AppApiException: if an address is unsafe, or if resolution fails
                for any reason (the underlying error text is included).
        """
        try:
            # AF_UNSPEC asks for both IPv4 and IPv6 results.
            resolved = socket.getaddrinfo(host, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
            # Element 4 of each entry is the sockaddr; its first item is the
            # address string.
            if any(self._is_unsafe_ip(entry[4][0]) for entry in resolved):
                raise AppApiException(500, _('Access to internal IP addresses is blocked'))
        except AppApiException:
            # Our own rejection must not be re-wrapped as a resolution failure.
            raise
        except Exception as e:
            raise AppApiException(500, _('Failed to resolve host: {error}').format(error=str(e)))

    def _is_unsafe_ip(self, ip: str) -> bool:
        """
        Return True when *ip* falls in a range that must not be contacted.

        Anything that cannot be evaluated as an IP address is treated as
        unsafe (fail closed).
        """
        try:
            address = ipaddress.ip_address(ip)
            range_flags = (
                address.is_private,
                address.is_loopback,
                address.is_reserved,
                address.is_link_local,
                address.is_multicast,
            )
        except Exception:
            return True
        return any(range_flags)
217+
218+
169219
def get_url_content(url, application_id: str):
170220
application = Application.objects.filter(id=application_id).first()
171221
if application is None:
@@ -177,11 +227,21 @@ def get_url_content(url, application_id: str):
177227
file_limit = application.file_upload_setting.get('fileLimit') * 1024 * 1024
178228
parsed = validate_url(url)
179229

180-
response = requests.get(
181-
url,
182-
timeout=3,
183-
allow_redirects=False
184-
)
230+
# 创建带有安全检查的 session
231+
session = requests.Session()
232+
safe_adapter = SafeHTTPAdapter()
233+
session.mount('http://', safe_adapter)
234+
session.mount('https://', safe_adapter)
235+
236+
try:
237+
response = session.get(
238+
url,
239+
timeout=3,
240+
allow_redirects=False
241+
)
242+
finally:
243+
session.close()
244+
185245
final_host = urlparse(response.url).hostname
186246
if is_private_ip(final_host):
187247
raise ValueError("Blocked unsafe redirect to internal host")
@@ -220,24 +280,94 @@ def is_private_ip(host: str) -> bool:
220280
return True
221281

222282

223-
def validate_url(url: str):
224-
"""验证 URL 是否安全"""
283+
def validate_and_normalize_url(url: str) -> str:
    r"""
    Strictly validate *url* and return a normalized copy of it.

    The checks are meant to stop URL-parsing bypass tricks such as:

    - ``http://127.0.0.1:6666\@1.1.1.1/``  (backslash confusion)
    - ``http://127.0.0.1:6666@1.1.1.1/``   (userinfo posing as the host)
    - ``http://1.1.1.1#@127.0.0.1:6666/``  (fragment injection)

    Args:
        url: The URL supplied by the caller.

    Returns:
        The URL rebuilt from its parsed scheme, host, optional port, path,
        params and query. Userinfo and fragment are dropped.

    Raises:
        ValueError: if the URL is empty, contains a backslash, whitespace or
            an encoded NUL/CR/LF, uses a scheme other than http/https, has
            userinfo containing ``:`` or ``.``, has no hostname, has an
            invalid port, or ``is_private_ip`` reports the host as internal.
    """
    if not url:
        raise ValueError("URL is required")

    # 1. Reject characters that URL parsers are known to disagree about.
    #    Matching on the lower-cased URL also catches %0A / %0D.
    if re.search(r'\\|\s|%00|%0a|%0d', url.lower()):
        raise ValueError("URL contains dangerous characters")

    # 2. Parse, and 3. allow only http / https.
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https"):
        raise ValueError("Only http and https are allowed")

    # 4. Userinfo ("user@host") is a classic place to hide an address that
    #    looks like the real target, so refuse anything resembling host:port.
    netloc = parsed.netloc
    if '@' in netloc:
        auth_part = netloc.rsplit('@', 1)[0]
        if ':' in auth_part or '.' in auth_part:
            raise ValueError("Authentication part contains suspicious content")

    # 5. Let the parser extract the host: it strips userinfo and the port and
    #    understands bracketed IPv6 literals, which splitting on ':' does not.
    actual_host = parsed.hostname
    if not actual_host:
        raise ValueError("Invalid URL: missing hostname")

    # 6. One check covers both IP literals and domain names.
    if is_private_ip(actual_host):
        raise ValueError("Access to internal IP addresses is blocked")

    # 7. Rebuild the URL without userinfo or fragment. An IPv6 literal must
    #    get its brackets back, otherwise the port separator is ambiguous.
    clean_netloc = f"[{actual_host}]" if ':' in actual_host else actual_host
    port = parsed.port  # raises ValueError for a malformed or out-of-range port
    if port:
        clean_netloc = f"{clean_netloc}:{port}"

    return urlunparse((
        parsed.scheme,
        clean_netloc,
        parsed.path,
        parsed.params,
        parsed.query,
        '',  # drop the fragment to prevent fragment injection
    ))
369+
370+
371+
def validate_url(url: str):
    """
    Validate that *url* is safe to fetch (kept for backward compatibility).

    Before ``validate_and_normalize_url`` existed this function returned the
    ``urlparse`` result, and callers still bind the return value as a parsed
    URL. To keep that contract, the normalized URL is parsed again before it
    is returned; use ``.geturl()`` on the result to get the normalized string.

    Raises:
        ValueError: propagated from ``validate_and_normalize_url`` when the
            URL is rejected.
    """
    return urlparse(validate_and_normalize_url(url))

apps/oss/views/file.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
from rest_framework.views import APIView
66
from rest_framework.views import Request
77
from common.auth import TokenAuth, AllTokenAuth
8+
from common.constants.permission_constants import ChatAuth
89
from common.log.log import log
910
from common.result import result
10-
from knowledge.api.file import FileUploadAPI, FileGetAPI
11+
from knowledge.api.file import FileUploadAPI, FileGetAPI, GetUrlContentAPI
1112
from oss.serializers.file import FileSerializer, get_url_content
1213

1314

@@ -73,11 +74,15 @@ class GetUrlView(APIView):
7374
@extend_schema(
    methods=['GET'],
    summary=_('Get url'),
    parameters=GetUrlContentAPI.get_parameters(),
    description=_('Get url'),
    operation_id=_('Get url'),  # type: ignore
    tags=[_('Chat')]  # type: ignore
)
def get(self, request: Request, application_id: str):
    """
    Fetch the content behind the ``url`` query parameter for an application.

    A chat-authenticated caller whose credentials carry an application id may
    only use this endpoint for that same application; otherwise an error
    result is returned.
    """
    auth = request.auth
    # Only chat credentials carry an application binding worth comparing.
    bound_application_id = auth.application_id if isinstance(auth, ChatAuth) else None
    if bound_application_id and str(bound_application_id) != application_id:
        return result.error(_('No permission'))

    target_url = request.query_params.get('url')
    return result.success(get_url_content(target_url, application_id))

0 commit comments

Comments
 (0)