Skip to content

Commit ad0031c

Browse files
committed
Merge branch '3.6.x'
2 parents cf9f74b + b3166e8 commit ad0031c

File tree

2 files changed

+270
-7
lines changed

2 files changed

+270
-7
lines changed

framework/fel/python/plugins/fel_langchain_tools/langchain_tools.py

Lines changed: 71 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# -- encoding: utf-8 --
2-
# Copyright (c) 2024 Huawei Technologies Co., Ltd. All Rights Reserved.
2+
# Copyright (c) 2024-2026 Huawei Technologies Co., Ltd. All Rights Reserved.
33
# This file is a part of the ModelEngine Project.
44
# Licensed under the MIT License. See License.txt in the project root for license information.
55
# ======================================================================================================================
6+
import ipaddress
67
import json
7-
from urllib.parse import quote_plus
8+
from urllib.parse import quote_plus, urlparse
89

910
from langchain.agents import AgentExecutor
1011
from langchain_community.agent_toolkits import JsonToolkit, create_json_agent
@@ -31,6 +32,69 @@
3132
from .langchain_registers import register_function_tools, register_api_tools
3233

3334

35+
# SSRF blocklist: network ranges the HTTP request tools must never reach.
_BLOCKED_NETWORKS = [
    ipaddress.ip_network("127.0.0.0/8"),  # IPv4 loopback
    ipaddress.ip_network("10.0.0.0/8"),  # RFC 1918 private
    ipaddress.ip_network("172.16.0.0/12"),  # RFC 1918 private
    ipaddress.ip_network("192.168.0.0/16"),  # RFC 1918 private
    ipaddress.ip_network("169.254.0.0/16"),  # IPv4 link-local
    ipaddress.ip_network("0.0.0.0/8"),  # "this network"
    ipaddress.ip_network("100.64.0.0/10"),  # RFC 6598 shared address space
    ipaddress.ip_network("::1/128"),  # IPv6 loopback
    ipaddress.ip_network("::/128"),  # IPv6 unspecified address
    ipaddress.ip_network("fc00::/7"),  # IPv6 unique-local
    ipaddress.ip_network("fe80::/10"),  # IPv6 link-local
]
_BLOCKED_HOSTNAMES = {"localhost", "0.0.0.0"}


def _parse_ip_literal(host: str):
    """Return the IP address that *host* spells out, or ``None`` for a DNS name.

    Legacy ``inet_aton`` spellings of IPv4 addresses (``2130706433``, ``127.1``,
    ``0x7f.0.0.1``, octal parts) are tried first, because ``ipaddress`` rejects
    them and they would otherwise be mistaken for domain names.
    """
    import socket  # function-scope import: the module header does not import socket

    try:
        return ipaddress.IPv4Address(socket.inet_aton(host))
    except (OSError, ValueError):
        # Not an IPv4 literal in any inet_aton form; fall through to ipaddress
        # (this is the path IPv6 literals take).
        pass
    try:
        return ipaddress.ip_address(host)
    except ValueError:
        return None


def _validate_url(url: str) -> None:
    """Reject URLs whose host is a known SSRF target.

    Args:
        url: The URL about to be requested.

    Raises:
        ValueError: With an ``Invalid URL`` message when no host can be
            extracted, or a ``Request blocked`` message when the host is a
            blocked name or an IP literal inside ``_BLOCKED_NETWORKS``.

    NOTE: domain names are deliberately let through without DNS resolution, so
    a name that resolves to an internal address is not caught here.
    """
    try:
        hostname = urlparse(url).hostname
    except ValueError as err:
        # urlparse itself raises on malformed bracketed hosts.
        raise ValueError(f"Invalid URL: {url}") from err
    if not hostname:
        raise ValueError(f"Invalid URL: {url}")

    # Normalise before matching: drop an IPv6 zone id ("%eth0") and the
    # trailing dot of a fully-qualified name ("localhost.").
    host = hostname.lower().partition("%")[0].rstrip(".")
    if not host:
        raise ValueError(f"Invalid URL: {url}")
    if host in _BLOCKED_HOSTNAMES or host.endswith(".localhost"):
        raise ValueError(f"Request blocked: URL '{url}' targets a restricted host ({hostname})")

    ip = _parse_ip_literal(host)
    if ip is None:
        # A domain name: allowed without resolving it (see NOTE above).
        return

    # Unwrap IPv4-mapped IPv6 (::ffff:a.b.c.d) so the IPv4 ranges apply to it.
    mapped_ipv4 = getattr(ip, "ipv4_mapped", None)
    if mapped_ipv4:
        ip = mapped_ipv4

    for network in _BLOCKED_NETWORKS:
        if ip in network:
            raise ValueError(f"Request blocked: URL '{url}' targets a restricted network ({network})")
72+
73+
74+
class SafeRequestsWrapper(TextRequestsWrapper):
    """HTTP request wrapper that runs ``_validate_url`` before every request.

    Each verb validates the target URL and then delegates to the parent
    implementation. The ``a``-prefixed async variants are overridden as well,
    so the async path gets the same check as the sync one.

    NOTE(review): only the URL passed in is checked. Redirects followed by the
    underlying HTTP client are not re-validated here; confirm whether that
    needs separate handling.
    """

    def get(self, url: str, **kwargs) -> str:
        _validate_url(url)
        return super().get(url, **kwargs)

    def post(self, url: str, data: dict, **kwargs) -> str:
        _validate_url(url)
        return super().post(url, data, **kwargs)

    def patch(self, url: str, data: dict, **kwargs) -> str:
        _validate_url(url)
        return super().patch(url, data, **kwargs)

    def put(self, url: str, data: dict, **kwargs) -> str:
        _validate_url(url)
        return super().put(url, data, **kwargs)

    def delete(self, url: str, **kwargs) -> str:
        _validate_url(url)
        return super().delete(url, **kwargs)

    # Async counterparts: without these overrides the inherited a* methods
    # would send the request with no SSRF validation at all.
    async def aget(self, url: str, **kwargs) -> str:
        _validate_url(url)
        return await super().aget(url, **kwargs)

    async def apost(self, url: str, data: dict, **kwargs) -> str:
        _validate_url(url)
        return await super().apost(url, data, **kwargs)

    async def apatch(self, url: str, data: dict, **kwargs) -> str:
        _validate_url(url)
        return await super().apatch(url, data, **kwargs)

    async def aput(self, url: str, data: dict, **kwargs) -> str:
        _validate_url(url)
        return await super().aput(url, data, **kwargs)

    async def adelete(self, url: str, **kwargs) -> str:
        _validate_url(url)
        return await super().adelete(url, **kwargs)
96+
97+
3498
# 从app_engine加密传输敏感信息
3599
def get_db(sql_url: str, sql_table: str, sql_name: str, sql_pwd: str) -> SQLDatabase:
36100
return SQLDatabase.from_uri(
@@ -113,35 +177,35 @@ def langchain_sql_agent(kwargs) -> AgentExecutor:
113177

114178
def langchain_request_get(kwargs) -> BaseTool:
    """Build the HTTP GET tool on top of the SSRF-validating request wrapper.

    Args:
        kwargs: Tool configuration passed by the caller; not read by this builder.

    Returns:
        A ``RequestsGetTool`` whose requests go through ``SafeRequestsWrapper``.
    """
    return RequestsGetTool(
        requests_wrapper=SafeRequestsWrapper(headers={}),
        allow_dangerous_requests=True,
    )
119183

120184

121185
def langchain_request_post(kwargs) -> BaseTool:
    """Build the HTTP POST tool on top of the SSRF-validating request wrapper.

    Args:
        kwargs: Tool configuration passed by the caller; not read by this builder.

    Returns:
        A ``RequestsPostTool`` whose requests go through ``SafeRequestsWrapper``.
    """
    return RequestsPostTool(
        requests_wrapper=SafeRequestsWrapper(headers={}),
        allow_dangerous_requests=True,
    )
126190

127191

128192
def langchain_request_patch(kwargs) -> BaseTool:
    """Build the HTTP PATCH tool on top of the SSRF-validating request wrapper.

    Args:
        kwargs: Tool configuration passed by the caller; not read by this builder.

    Returns:
        A ``RequestsPatchTool`` whose requests go through ``SafeRequestsWrapper``.
    """
    return RequestsPatchTool(
        requests_wrapper=SafeRequestsWrapper(headers={}),
        allow_dangerous_requests=True,
    )
133197

134198

135199
def langchain_request_delete(kwargs) -> BaseTool:
    """Build the HTTP DELETE tool on top of the SSRF-validating request wrapper.

    Args:
        kwargs: Tool configuration passed by the caller; not read by this builder.

    Returns:
        A ``RequestsDeleteTool`` whose requests go through ``SafeRequestsWrapper``.
    """
    return RequestsDeleteTool(
        requests_wrapper=SafeRequestsWrapper(headers={}),
        allow_dangerous_requests=True,
    )
140204

141205

142206
def langchain_request_put(kwargs) -> BaseTool:
    """Build the HTTP PUT tool on top of the SSRF-validating request wrapper.

    Args:
        kwargs: Tool configuration passed by the caller; not read by this builder.

    Returns:
        A ``RequestsPutTool`` whose requests go through ``SafeRequestsWrapper``.
    """
    return RequestsPutTool(
        requests_wrapper=SafeRequestsWrapper(headers={}),
        allow_dangerous_requests=True,
    )
147211

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
# -- encoding: utf-8 --
2+
# Copyright (c) 2024-2026 Huawei Technologies Co., Ltd. All Rights Reserved.
3+
# This file is a part of the ModelEngine Project.
4+
# Licensed under the MIT License. See License.txt in the project root for license information.
5+
# ======================================================================================================================
6+
import importlib
7+
import os
8+
import sys
9+
import types
10+
import unittest
11+
from unittest.mock import patch
12+
13+
14+
class _DummyTextRequestsWrapper:
    """Network-free stand-in for ``TextRequestsWrapper``.

    Every verb echoes its arguments back as a ``"<verb>:<url>[:<data>]"``
    string instead of performing a request.
    """

    def __init__(self, headers=None):
        # A falsy headers value (None, empty dict) becomes a fresh empty dict.
        self.headers = headers if headers else {}

    def get(self, url: str, **kwargs):
        return "get:{}".format(url)

    def post(self, url: str, data: dict, **kwargs):
        return "post:{}:{}".format(url, data)

    def patch(self, url: str, data: dict, **kwargs):
        return "patch:{}:{}".format(url, data)

    def put(self, url: str, data: dict, **kwargs):
        return "put:{}:{}".format(url, data)

    def delete(self, url: str, **kwargs):
        return "delete:{}".format(url)
32+
33+
34+
class _DummyRequestsTool:
    """Stand-in for the LangChain ``Requests*Tool`` classes.

    Records the two constructor arguments the tests inspect and ignores any
    other keyword arguments.
    """

    def __init__(self, requests_wrapper=None, allow_dangerous_requests=False, **_ignored):
        self.requests_wrapper, self.allow_dangerous_requests = (
            requests_wrapper,
            allow_dangerous_requests,
        )
38+
39+
40+
class _DummySQLDatabase:
    """Stand-in for ``SQLDatabase`` that never opens a connection."""

    @classmethod
    def from_uri(cls, _uri):
        """Return a new instance; the URI argument is ignored."""
        instance = cls()
        return instance
44+
45+
46+
def _build_stub_modules():
    """Create stub modules for the third-party packages the plugin imports.

    Returns:
        dict: Dotted module name mapped to a ``types.ModuleType`` stub, in a
        form that can be handed to ``patch.dict(sys.modules, ...)``.
    """

    def _stub(dotted_name, is_package=False, **members):
        # Build one stub module; packages get an empty __path__.
        module = types.ModuleType(dotted_name)
        if is_package:
            module.__path__ = []
        for member_name, member in members.items():
            setattr(module, member_name, member)
        return module

    def _empty_class(class_name):
        # Placeholder class with no behaviour, named like the real one.
        return type(class_name, (), {})

    def _create_json_agent(**kwargs):
        return kwargs

    def _create_sql_agent(*args, **kwargs):
        return (args, kwargs)

    def _register_noop(*args, **kwargs):
        return None

    stubs = (
        _stub("langchain", is_package=True),
        _stub("langchain.agents", AgentExecutor=_empty_class("AgentExecutor")),
        _stub("langchain_community", is_package=True),
        _stub(
            "langchain_community.agent_toolkits",
            JsonToolkit=_empty_class("JsonToolkit"),
            create_json_agent=_create_json_agent,
            create_sql_agent=_create_sql_agent,
        ),
        _stub("langchain_community.tools", is_package=True),
        _stub("langchain_community.tools.json", is_package=True),
        _stub("langchain_community.tools.json.tool", JsonSpec=_empty_class("JsonSpec")),
        _stub("langchain_community.tools.requests", is_package=True),
        _stub(
            "langchain_community.tools.requests.tool",
            RequestsGetTool=_DummyRequestsTool,
            RequestsPostTool=_DummyRequestsTool,
            RequestsPatchTool=_DummyRequestsTool,
            RequestsPutTool=_DummyRequestsTool,
            RequestsDeleteTool=_DummyRequestsTool,
        ),
        _stub("langchain_community.tools.sql_database", is_package=True),
        _stub(
            "langchain_community.tools.sql_database.tool",
            InfoSQLDatabaseTool=_empty_class("InfoSQLDatabaseTool"),
            ListSQLDatabaseTool=_empty_class("ListSQLDatabaseTool"),
            QuerySQLCheckerTool=_empty_class("QuerySQLCheckerTool"),
            QuerySQLDataBaseTool=_empty_class("QuerySQLDataBaseTool"),
        ),
        _stub("langchain_community.utilities", is_package=True),
        _stub(
            "langchain_community.utilities.requests",
            TextRequestsWrapper=_DummyTextRequestsWrapper,
        ),
        _stub(
            "langchain_community.utilities.sql_database",
            SQLDatabase=_DummySQLDatabase,
        ),
        _stub("langchain_core", is_package=True),
        _stub("langchain_core.tools", BaseTool=object),
        _stub("langchain_openai", ChatOpenAI=_empty_class("ChatOpenAI")),
        _stub(
            "plugins.fel_langchain_tools.langchain_registers",
            register_function_tools=_register_noop,
            register_api_tools=_register_noop,
        ),
    )
    # Each stub's __name__ is the dotted name it was created with.
    return {module.__name__: module for module in stubs}
122+
123+
124+
def _load_module_with_stubs():
    """Import the plugin module with its third-party dependencies stubbed out.

    The parent directory of this test file is put on ``sys.path`` if absent,
    any cached copy of the target module is dropped, and the import runs while
    ``sys.modules`` is patched with the stubs from ``_build_stub_modules``.

    Returns:
        The freshly imported ``plugins.fel_langchain_tools.langchain_tools`` module.
    """
    target = "plugins.fel_langchain_tools.langchain_tools"
    tests_dir = os.path.dirname(__file__)
    fel_python_root = os.path.abspath(os.path.join(tests_dir, ".."))
    if fel_python_root not in sys.path:
        sys.path.insert(0, fel_python_root)

    # Forget a previously imported copy so the module body executes again.
    sys.modules.pop(target, None)
    stubs = _build_stub_modules()
    with patch.dict(sys.modules, stubs):
        loaded = importlib.import_module(target)
    return loaded
132+
133+
134+
class TestLangchainToolsSSRF(unittest.TestCase):
    """SSRF-protection tests for the LangChain HTTP request tools."""

    @classmethod
    def setUpClass(cls):
        # Import the plugin once for the whole class, against stubbed dependencies.
        cls.module = _load_module_with_stubs()

    def test_validate_url_blocks_ssrf_targets(self):
        """Loopback, private, link-local and IPv4-mapped targets are rejected."""
        blocked_urls = [
            "http://169.254.169.254/latest/meta-data/",
            "http://127.0.0.1:8080/admin",
            "http://10.0.0.1/internal",
            "http://192.168.1.1/config",
            "http://[::ffff:169.254.169.254]/latest/meta-data/",
            "http://[::ffff:127.0.0.1]/",
            "http://[::ffff:10.0.0.1]/",
            "http://localhost/health",
            "http://0.0.0.0/status",
            "http://[::1]/",
            "http://[fc00::1]/",
            "http://[fe80::1]/",
        ]

        for url in blocked_urls:
            with self.subTest(url=url):
                with self.assertRaisesRegex(ValueError, "Request blocked"):
                    self.module._validate_url(url)

    def test_validate_url_allows_external_targets(self):
        """Public hosts and public IP literals pass validation without raising."""
        allowed_urls = [
            "https://api.example.com/data",
            "http://8.8.8.8/health",
            "https://httpbin.org/get",
        ]

        for url in allowed_urls:
            with self.subTest(url=url):
                self.module._validate_url(url)

    def test_validate_url_rejects_invalid_url(self):
        """A URL with no host is reported as invalid."""
        with self.assertRaisesRegex(ValueError, "Invalid URL"):
            self.module._validate_url("http:///path-only")

    def test_http_tool_builders_use_safe_wrapper(self):
        """Every HTTP tool builder wires in ``SafeRequestsWrapper``."""
        builders = [
            self.module.langchain_request_get,
            self.module.langchain_request_post,
            self.module.langchain_request_patch,
            self.module.langchain_request_delete,
            self.module.langchain_request_put,
        ]
        for builder in builders:
            with self.subTest(builder=builder.__name__):
                tool = builder({})
                self.assertIsInstance(tool.requests_wrapper, self.module.SafeRequestsWrapper)
                self.assertTrue(tool.allow_dangerous_requests)

    def test_safe_wrapper_blocks_and_allows_requests(self):
        """Each overridden verb blocks internal targets and delegates external ones."""
        wrapper = self.module.SafeRequestsWrapper(headers={})
        blocked_url = "http://127.0.0.1/api"
        allowed_url = "https://api.example.com/data"
        payload = {"key": "value"}

        with self.assertRaisesRegex(ValueError, "Request blocked"):
            wrapper.get("http://127.0.0.1/api")

        result = wrapper.get("https://api.example.com/data")
        self.assertEqual(result, "get:https://api.example.com/data")

        # DELETE takes no body, like GET.
        with self.subTest(method="delete"):
            with self.assertRaisesRegex(ValueError, "Request blocked"):
                wrapper.delete(blocked_url)
            self.assertEqual(wrapper.delete(allowed_url), f"delete:{allowed_url}")

        # Body-carrying verbs: the stub parent echoes "<verb>:<url>:<data>".
        for method in ("post", "patch", "put"):
            with self.subTest(method=method):
                send = getattr(wrapper, method)
                with self.assertRaisesRegex(ValueError, "Request blocked"):
                    send(blocked_url, payload)
                self.assertEqual(send(allowed_url, payload), f"{method}:{allowed_url}:{payload}")
196+
197+
198+
if __name__ == "__main__":
199+
unittest.main()

0 commit comments

Comments
 (0)