Skip to content

Commit f1ae3a4

Browse files
committed
refactor(auth): simplify URL resolution logic and enhance test coverage
The _get_openapi_base_url method was refactored to remove complex FC region logic and simplify URL selection. The implementation now prioritizes intranet URLs when available, falling back to internet URLs. This change improves code readability and maintainability while maintaining the same functional behavior. Additionally, comprehensive test suites were added for RAM signature helper functions, ControlAPI client methods, and exception handling to ensure robust authentication and error management. 测试套件已增加以验证 RAM 签名辅助函数、ControlAPI 客户端方法和异常处理, 确保身份验证和错误管理的健壮性。 Change-Id: Ia9d8ff6d2bfd37ec858413f13686fcee8fd6d912 Signed-off-by: OhYee <oyohyee@oyohyee.com>
1 parent 7ac57e5 commit f1ae3a4

File tree

5 files changed

+338
-27
lines changed

5 files changed

+338
-27
lines changed

agentrun/toolset/__toolset_async_template.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -94,20 +94,13 @@ def _get_openapi_auth_defaults(
9494
return headers, query
9595

9696
def _get_openapi_base_url(self) -> Optional[str]:
97-
import os
98-
99-
fc_region = os.getenv("FC_REGION")
100-
arn = pydash.get(self, "status.outputs.function_arn", "")
101-
102-
if fc_region and arn and pydash.get(arn.split(":"), "[2]"):
103-
# 在同一个 region,则使用内网地址
104-
return pydash.get(
105-
self,
106-
"status.outputs.urls.intranet_url",
107-
None,
108-
)
97+
intranet_url: Optional[str] = pydash.get(
98+
self, "status.outputs.urls.intranet_url", None
99+
)
100+
if intranet_url:
101+
return intranet_url
109102

110-
return None
103+
return pydash.get(self, "status.outputs.urls.internet_url", None)
111104

112105
async def get_async(self, config: Optional[Config] = None):
113106
if self.name is None:

agentrun/toolset/toolset.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -109,20 +109,13 @@ def _get_openapi_auth_defaults(
109109
return headers, query
110110

111111
def _get_openapi_base_url(self) -> Optional[str]:
112-
import os
113-
114-
fc_region = os.getenv("FC_REGION")
115-
arn = pydash.get(self, "status.outputs.function_arn", "")
116-
117-
if fc_region and arn and pydash.get(arn.split(":"), "[2]"):
118-
# 在同一个 region,则使用内网地址
119-
return pydash.get(
120-
self,
121-
"status.outputs.urls.intranet_url",
122-
None,
123-
)
112+
intranet_url: Optional[str] = pydash.get(
113+
self, "status.outputs.urls.intranet_url", None
114+
)
115+
if intranet_url:
116+
return intranet_url
124117

125-
return None
118+
return pydash.get(self, "status.outputs.urls.internet_url", None)
126119

127120
async def get_async(self, config: Optional[Config] = None):
128121
if self.name is None:

tests/unittests/ram_signature/test_signer.py

Lines changed: 163 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@
55

66
import pytest
77

8-
from agentrun.ram_signature.python.signer import get_agentrun_signed_headers
8+
from agentrun.utils.ram_signature.signer import (
9+
_canonical_headers,
10+
_canonical_uri,
11+
_percent_encode,
12+
get_agentrun_signed_headers,
13+
get_agentrun_signed_headers_with_debug,
14+
)
915

1016

1117
class TestRamSignatureStandalone:
@@ -290,3 +296,159 @@ def test_scenario_3_post_query_empty_body_content_type_json(self):
290296
)
291297
if ref is not None:
292298
assert sig == ref, "SDK 与官方包(ref) 签名应一致"
299+
300+
301+
class TestSignerHelperFunctions:
302+
"""测试签名辅助函数的边界情况"""
303+
304+
def test_percent_encode_none(self):
305+
"""测试 _percent_encode(None) 返回空字符串"""
306+
assert _percent_encode(None) == ""
307+
308+
def test_percent_encode_tilde(self):
309+
"""测试 _percent_encode 正确处理 ~ 字符"""
310+
assert "~" in _percent_encode("a~b")
311+
312+
def test_canonical_uri_empty(self):
313+
"""测试 _canonical_uri 空字符串返回 /"""
314+
assert _canonical_uri("") == "/"
315+
316+
def test_canonical_uri_none(self):
317+
"""测试 _canonical_uri None 返回 /"""
318+
assert _canonical_uri(None) == "/"
319+
320+
def test_canonical_uri_normal(self):
321+
"""测试 _canonical_uri 正常路径"""
322+
assert _canonical_uri("/path/to/resource") == "/path/to/resource"
323+
324+
def test_canonical_headers_skips_none_values(self):
325+
"""测试 _canonical_headers 跳过 value 为 None 的 header"""
326+
headers = {
327+
"host": "example.com",
328+
"x-acs-date": "2026-01-01T00:00:00Z",
329+
"x-acs-skip": None,
330+
}
331+
canon, signed = _canonical_headers(headers)
332+
assert "x-acs-skip" not in signed
333+
assert "host" in signed
334+
335+
336+
class TestSignerNaiveDatetime:
337+
"""测试 naive datetime(无时区信息)的处理"""
338+
339+
def test_naive_datetime_gets_utc(self):
340+
"""测试 naive datetime 被自动设置为 UTC"""
341+
naive_time = datetime(2026, 1, 1, 12, 0, 0)
342+
headers = get_agentrun_signed_headers(
343+
url="https://x.agentrun-data.cn-hangzhou.aliyuncs.com/path",
344+
access_key_id="ak",
345+
access_key_secret="sk",
346+
sign_time=naive_time,
347+
)
348+
assert headers["x-acs-date"] == "2026-01-01T12:00:00Z"
349+
350+
351+
class TestSignerWithDebug:
352+
"""测试 get_agentrun_signed_headers_with_debug 函数"""
353+
354+
def test_returns_headers_and_debug(self):
355+
"""测试返回 headers 和 debug 信息"""
356+
t = datetime(2026, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
357+
headers, debug = get_agentrun_signed_headers_with_debug(
358+
url="https://x.agentrun-data.cn-hangzhou.aliyuncs.com/path",
359+
access_key_id="ak",
360+
access_key_secret="sk",
361+
sign_time=t,
362+
)
363+
assert "Agentrun-Authorization" in headers
364+
assert "x-acs-date" in headers
365+
assert "canonical_request" in debug
366+
assert "string_to_sign" in debug
367+
assert "signing_key_hex" in debug
368+
assert "signature" in debug
369+
370+
def test_debug_signature_matches_headers(self):
371+
"""测试 debug 中的 signature 与 headers 中的一致"""
372+
t = datetime(2026, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
373+
headers, debug = get_agentrun_signed_headers_with_debug(
374+
url="https://x.agentrun-data.cn-hangzhou.aliyuncs.com/path",
375+
access_key_id="ak",
376+
access_key_secret="sk",
377+
sign_time=t,
378+
)
379+
auth = headers["Agentrun-Authorization"]
380+
sig_in_auth = auth.split("Signature=")[-1]
381+
assert sig_in_auth == debug["signature"]
382+
383+
def test_debug_matches_non_debug_version(self):
384+
"""测试 debug 版本与非 debug 版本签名一致"""
385+
t = datetime(2026, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
386+
opts = dict(
387+
url="https://x.agentrun-data.cn-hangzhou.aliyuncs.com/path?a=1",
388+
method="POST",
389+
access_key_id="ak",
390+
access_key_secret="sk",
391+
sign_time=t,
392+
)
393+
headers_normal = get_agentrun_signed_headers(**opts)
394+
headers_debug, _ = get_agentrun_signed_headers_with_debug(**opts)
395+
assert (
396+
headers_normal["Agentrun-Authorization"]
397+
== headers_debug["Agentrun-Authorization"]
398+
)
399+
400+
def test_debug_requires_ak_sk(self):
401+
"""测试 debug 版本也要求 ak/sk"""
402+
with pytest.raises(ValueError, match="Access Key ID and Secret"):
403+
get_agentrun_signed_headers_with_debug(
404+
url="https://x.agentrun-data.cn-hangzhou.aliyuncs.com/",
405+
access_key_id="",
406+
access_key_secret="sk",
407+
)
408+
409+
def test_debug_with_security_token(self):
410+
"""测试 debug 版本带 security_token"""
411+
t = datetime(2026, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
412+
headers, debug = get_agentrun_signed_headers_with_debug(
413+
url="https://x.agentrun-data.cn-hangzhou.aliyuncs.com/path",
414+
access_key_id="ak",
415+
access_key_secret="sk",
416+
security_token="sts-token",
417+
sign_time=t,
418+
)
419+
assert "x-acs-security-token" in headers
420+
assert "x-acs-security-token" in headers["Agentrun-Authorization"]
421+
422+
def test_debug_with_content_type(self):
423+
"""测试 debug 版本带 content_type"""
424+
t = datetime(2026, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
425+
headers, debug = get_agentrun_signed_headers_with_debug(
426+
url="https://x.agentrun-data.cn-hangzhou.aliyuncs.com/path",
427+
access_key_id="ak",
428+
access_key_secret="sk",
429+
content_type="application/json",
430+
sign_time=t,
431+
)
432+
assert "content-type" in headers["Agentrun-Authorization"]
433+
434+
def test_debug_with_query_params(self):
435+
"""测试 debug 版本带 query 参数"""
436+
t = datetime(2026, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
437+
headers, debug = get_agentrun_signed_headers_with_debug(
438+
url="https://x.agentrun-data.cn-hangzhou.aliyuncs.com/path?foo=bar&zoo=",
439+
access_key_id="ak",
440+
access_key_secret="sk",
441+
sign_time=t,
442+
)
443+
assert "foo=bar" in debug["canonical_request"]
444+
445+
def test_debug_naive_datetime(self):
446+
"""测试 debug 版本处理 naive datetime"""
447+
naive_time = datetime(2026, 1, 1, 12, 0, 0)
448+
headers, debug = get_agentrun_signed_headers_with_debug(
449+
url="https://x.agentrun-data.cn-hangzhou.aliyuncs.com/path",
450+
access_key_id="ak",
451+
access_key_secret="sk",
452+
sign_time=naive_time,
453+
)
454+
assert headers["x-acs-date"] == "2026-01-01T12:00:00Z"

tests/unittests/utils/test_control_api.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,3 +277,139 @@ def test_get_devs_client_with_read_timeout(self, mock_client_class):
277277
config_arg = call_args[0][0]
278278
assert config_arg.connect_timeout == 300
279279
assert config_arg.read_timeout == 60000
280+
281+
282+
class TestControlAPIGetBailianClient:
283+
"""测试 ControlAPI._get_bailian_client"""
284+
285+
@patch("agentrun.utils.control_api.BailianClient")
286+
def test_get_bailian_client_basic(self, mock_client_class):
287+
"""测试获取基本百炼客户端"""
288+
config = Config(
289+
access_key_id="ak",
290+
access_key_secret="sk",
291+
region_id="cn-hangzhou",
292+
)
293+
api = ControlAPI(config=config)
294+
295+
mock_client = MagicMock()
296+
mock_client_class.return_value = mock_client
297+
298+
result = api._get_bailian_client()
299+
300+
assert mock_client_class.called
301+
call_args = mock_client_class.call_args
302+
config_arg = call_args[0][0]
303+
assert config_arg.access_key_id == "ak"
304+
assert config_arg.access_key_secret == "sk"
305+
assert config_arg.region_id == "cn-hangzhou"
306+
307+
@patch("agentrun.utils.control_api.BailianClient")
308+
def test_get_bailian_client_strips_https_prefix(self, mock_client_class):
309+
"""测试获取百炼客户端时去除 https:// 前缀"""
310+
config = Config(
311+
access_key_id="ak",
312+
access_key_secret="sk",
313+
bailian_endpoint="https://bailian.cn-hangzhou.aliyuncs.com",
314+
)
315+
api = ControlAPI(config=config)
316+
317+
mock_client = MagicMock()
318+
mock_client_class.return_value = mock_client
319+
320+
api._get_bailian_client()
321+
322+
call_args = mock_client_class.call_args
323+
config_arg = call_args[0][0]
324+
assert config_arg.endpoint == "bailian.cn-hangzhou.aliyuncs.com"
325+
326+
@patch("agentrun.utils.control_api.BailianClient")
327+
def test_get_bailian_client_strips_http_prefix(self, mock_client_class):
328+
"""测试获取百炼客户端时去除 http:// 前缀"""
329+
config = Config(
330+
access_key_id="ak",
331+
access_key_secret="sk",
332+
bailian_endpoint="http://bailian.custom.com",
333+
)
334+
api = ControlAPI(config=config)
335+
336+
mock_client = MagicMock()
337+
mock_client_class.return_value = mock_client
338+
339+
api._get_bailian_client()
340+
341+
call_args = mock_client_class.call_args
342+
config_arg = call_args[0][0]
343+
assert config_arg.endpoint == "bailian.custom.com"
344+
345+
346+
class TestControlAPIGetGPDBClient:
347+
"""测试 ControlAPI._get_gpdb_client"""
348+
349+
@patch("agentrun.utils.control_api.GPDBClient")
350+
def test_get_gpdb_client_known_region(self, mock_client_class):
351+
"""测试已知 region 使用通用 endpoint"""
352+
config = Config(
353+
access_key_id="ak",
354+
access_key_secret="sk",
355+
region_id="cn-hangzhou",
356+
)
357+
api = ControlAPI(config=config)
358+
359+
mock_client = MagicMock()
360+
mock_client_class.return_value = mock_client
361+
362+
api._get_gpdb_client()
363+
364+
call_args = mock_client_class.call_args
365+
config_arg = call_args[0][0]
366+
assert config_arg.endpoint == "gpdb.aliyuncs.com"
367+
368+
@patch("agentrun.utils.control_api.GPDBClient")
369+
def test_get_gpdb_client_unknown_region(self, mock_client_class):
370+
"""测试未知 region 使用区域级别 endpoint"""
371+
config = Config(
372+
access_key_id="ak",
373+
access_key_secret="sk",
374+
region_id="us-west-1",
375+
)
376+
api = ControlAPI(config=config)
377+
378+
mock_client = MagicMock()
379+
mock_client_class.return_value = mock_client
380+
381+
api._get_gpdb_client()
382+
383+
call_args = mock_client_class.call_args
384+
config_arg = call_args[0][0]
385+
assert config_arg.endpoint == "gpdb.us-west-1.aliyuncs.com"
386+
387+
@patch("agentrun.utils.control_api.GPDBClient")
388+
def test_get_gpdb_client_all_known_regions(self, mock_client_class):
389+
"""测试所有已知 region 使用通用 endpoint"""
390+
known_regions = [
391+
"cn-beijing",
392+
"cn-hangzhou",
393+
"cn-shanghai",
394+
"cn-shenzhen",
395+
"cn-hongkong",
396+
"ap-southeast-1",
397+
]
398+
for region in known_regions:
399+
config = Config(
400+
access_key_id="ak",
401+
access_key_secret="sk",
402+
region_id=region,
403+
)
404+
api = ControlAPI(config=config)
405+
406+
mock_client = MagicMock()
407+
mock_client_class.return_value = mock_client
408+
409+
api._get_gpdb_client()
410+
411+
call_args = mock_client_class.call_args
412+
config_arg = call_args[0][0]
413+
assert (
414+
config_arg.endpoint == "gpdb.aliyuncs.com"
415+
), f"Region {region} should use gpdb.aliyuncs.com"

0 commit comments

Comments
 (0)