Skip to content

Commit f7ffb37

Browse files
author
黑曜
committed
feat(openapi): 支持固定请求头参数
用户要求:从 OpenAPI protocolSpec 中解析字符串类型的 header 固定值,让 remote OpenAPI 工具调用时自动携带这些 header。 实现思路: - 仅将 in: header 且 schema.const 为字符串的参数解析为 fixed_headers - 仅将 in: header 且 schema.enum 只有一个字符串值的参数解析为 fixed_headers - 调用工具时合并 fixed_headers,并过滤同名调用参数,避免误传到 query 或 body - 明确不支持 parameter $ref 和非字符串固定值,并补充对应测试 Signed-off-by: 黑曜 <haotian.qht@alibaba-inc.com>
1 parent f35f01f commit f7ffb37

2 files changed

Lines changed: 442 additions & 9 deletions

File tree

agentrun/tool/api/openapi.py

Lines changed: 117 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
extracts operations as ToolInfo list, and makes HTTP calls via Server URL.
77
"""
88

9+
from copy import deepcopy
910
import json
10-
from typing import Any, Dict, Generator, List, Optional
11+
from typing import Any, Dict, Generator, List, Optional, Tuple
1112
from urllib.parse import urlparse, urlunparse
1213

1314
import httpx
@@ -204,6 +205,85 @@ def _resolve_schema(
204205

205206
return result
206207

208+
@staticmethod
209+
def _collect_parameters(
210+
path_item: Dict[str, Any], operation: Dict[str, Any]
211+
) -> List[Dict[str, Any]]:
212+
"""合并 path 和 operation 参数 / Merge path and operation parameters."""
213+
merged: Dict[Tuple[str, str], Dict[str, Any]] = {}
214+
for source in (
215+
path_item.get("parameters", []),
216+
operation.get("parameters", []),
217+
):
218+
if not isinstance(source, list):
219+
continue
220+
for param in source:
221+
if not isinstance(param, dict):
222+
continue
223+
key = (str(param.get("name", "")), str(param.get("in", "")))
224+
merged[key] = param
225+
return list(merged.values())
226+
227+
@staticmethod
228+
def _fixed_header_value_from_schema(
229+
schema: Optional[Dict[str, Any]],
230+
) -> Optional[str]:
231+
"""从 schema.const 或单值 enum 提取固定 header 值。"""
232+
if not schema or not isinstance(schema, dict):
233+
return None
234+
235+
if "const" in schema:
236+
value = schema.get("const")
237+
return value if isinstance(value, str) else None
238+
239+
enum_values = schema.get("enum")
240+
if isinstance(enum_values, list) and len(enum_values) == 1:
241+
value = enum_values[0]
242+
return value if isinstance(value, str) else None
243+
244+
return None
245+
246+
@staticmethod
247+
def _merge_headers(
248+
base_headers: Dict[str, str], fixed_headers: Dict[str, str]
249+
) -> Dict[str, str]:
250+
"""合并请求头,固定 header 按大小写不敏感规则覆盖已有值。"""
251+
merged = dict(base_headers)
252+
for fixed_key, fixed_value in fixed_headers.items():
253+
fixed_key_lower = fixed_key.lower()
254+
for existing_key in list(merged.keys()):
255+
if existing_key.lower() == fixed_key_lower:
256+
del merged[existing_key]
257+
merged[fixed_key] = fixed_value
258+
return merged
259+
260+
@staticmethod
261+
def _remove_fixed_header_arguments(
262+
arguments: Optional[Dict[str, Any]], fixed_headers: Dict[str, str]
263+
) -> Dict[str, Any]:
264+
"""从调用参数中移除固定 header,避免误传到 query 或 body。"""
265+
if not arguments:
266+
return {}
267+
268+
cleaned = dict(arguments)
269+
fixed_names = {name.lower() for name in fixed_headers}
270+
for name in list(cleaned.keys()):
271+
if name.lower() in fixed_names:
272+
del cleaned[name]
273+
return cleaned
274+
275+
def _prepare_request_inputs(
276+
self,
277+
arguments: Optional[Dict[str, Any]],
278+
fixed_headers: Dict[str, str],
279+
) -> Tuple[Dict[str, str], Dict[str, Any]]:
280+
"""准备请求头和参数 / Prepare request headers and arguments."""
281+
request_headers = self._merge_headers(self.headers, fixed_headers)
282+
request_arguments = self._remove_fixed_header_arguments(
283+
arguments, fixed_headers
284+
)
285+
return request_headers, request_arguments
286+
207287
def _parse_operations(self) -> List[Dict[str, Any]]:
208288
"""解析 OpenAPI Schema 中的所有 operations / Parse all operations from OpenAPI Schema"""
209289
if self._operations is not None:
@@ -235,15 +315,34 @@ def _parse_operations(self) -> List[Dict[str, Any]]:
235315
request_body_schema = self._resolve_schema(raw_schema)
236316

237317
parameters_schema = None
238-
parameters = operation.get("parameters", [])
318+
parameters = self._collect_parameters(path_item, operation)
319+
fixed_headers: Dict[str, str] = {}
239320
if parameters and isinstance(parameters, list):
240321
props = {}
241322
required_params = []
242323
for param in parameters:
243324
if not isinstance(param, dict):
244325
continue
245326
param_name = param.get("name", "")
246-
param_schema = param.get("schema", {"type": "string"})
327+
if not param_name:
328+
continue
329+
330+
raw_param_schema = param.get(
331+
"schema", {"type": "string"}
332+
)
333+
param_schema = self._resolve_schema(raw_param_schema)
334+
if not isinstance(param_schema, dict):
335+
param_schema = {"type": "string"}
336+
337+
if param.get("in") == "header":
338+
fixed_value = self._fixed_header_value_from_schema(
339+
param_schema
340+
)
341+
if fixed_value is not None:
342+
fixed_headers[str(param_name)] = fixed_value
343+
continue
344+
345+
param_schema = deepcopy(param_schema)
247346
param_schema["description"] = param.get(
248347
"description", ""
249348
)
@@ -267,6 +366,7 @@ def _parse_operations(self) -> List[Dict[str, Any]]:
267366
"method": method.upper(),
268367
"path": path,
269368
"input_schema": input_schema,
369+
"fixed_headers": fixed_headers,
270370
})
271371

272372
return self._operations
@@ -337,6 +437,10 @@ def call_tool(
337437

338438
url = f"{base_url.rstrip('/')}{target_operation['path']}"
339439
method = target_operation["method"]
440+
fixed_headers = target_operation.get("fixed_headers", {})
441+
request_headers, request_arguments = self._prepare_request_inputs(
442+
arguments, fixed_headers
443+
)
340444

341445
# 应用 RAM 签名
342446
url, auth = self._build_ram_auth(url)
@@ -346,12 +450,12 @@ def call_tool(
346450
)
347451

348452
with httpx.Client(
349-
headers=self.headers, timeout=30.0, auth=auth
453+
headers=request_headers, timeout=30.0, auth=auth
350454
) as client:
351455
if method in ("POST", "PUT", "PATCH"):
352-
response = client.request(method, url, json=arguments or {})
456+
response = client.request(method, url, json=request_arguments)
353457
else:
354-
response = client.request(method, url, params=arguments or {})
458+
response = client.request(method, url, params=request_arguments)
355459

356460
response.raise_for_status()
357461

@@ -396,6 +500,10 @@ async def call_tool_async(
396500

397501
url = f"{base_url.rstrip('/')}{target_operation['path']}"
398502
method = target_operation["method"]
503+
fixed_headers = target_operation.get("fixed_headers", {})
504+
request_headers, request_arguments = self._prepare_request_inputs(
505+
arguments, fixed_headers
506+
)
399507

400508
# 应用 RAM 签名
401509
url, auth = self._build_ram_auth(url)
@@ -406,15 +514,15 @@ async def call_tool_async(
406514
)
407515

408516
async with httpx.AsyncClient(
409-
headers=self.headers, timeout=30.0, auth=auth
517+
headers=request_headers, timeout=30.0, auth=auth
410518
) as client:
411519
if method in ("POST", "PUT", "PATCH"):
412520
response = await client.request(
413-
method, url, json=arguments or {}
521+
method, url, json=request_arguments
414522
)
415523
else:
416524
response = await client.request(
417-
method, url, params=arguments or {}
525+
method, url, params=request_arguments
418526
)
419527

420528
response.raise_for_status()

0 commit comments

Comments
 (0)