Skip to content

Commit 544d737

Browse files
feat: check and support prompt caching for all models (#1482)
Co-authored-by: Wendong-Fan <w3ndong.fan@gmail.com>
1 parent c6ba7e0 commit 544d737

14 files changed

Lines changed: 264 additions & 82 deletions

File tree

backend/app/agent/agent_model.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from app.agent.listen_chat_agent import ListenChatAgent, logger
2626
from app.model.chat import AgentModelConfig, Chat
27+
from app.model.model_platform import patch_bedrock_cloud_config
2728
from app.service.task import ActionCreateAgentData, Agents, get_task_lock
2829
from app.utils.event_loop_utils import _schedule_async_task
2930

@@ -80,7 +81,14 @@ def agent_model(
8081
for attr in config_attrs:
8182
effective_config[attr] = getattr(options, attr)
8283
extra_params = options.extra_params or {}
83-
84+
# Cloud mode: inject default Bedrock region and adjust URL for proxy.
85+
if (
86+
effective_config.get("model_platform") == "aws-bedrock-converse"
87+
and options.is_cloud()
88+
):
89+
effective_config["api_url"], extra_params = patch_bedrock_cloud_config(
90+
effective_config["api_url"], extra_params
91+
)
8492
init_param_keys = {
8593
"api_version",
8694
"azure_ad_token",
@@ -90,6 +98,10 @@ def agent_model(
9098
"client",
9199
"async_client",
92100
"azure_deployment_name",
101+
"region_name",
102+
"aws_access_key_id",
103+
"aws_secret_access_key",
104+
"aws_session_token",
93105
}
94106

95107
init_params = {}
@@ -113,6 +125,26 @@ def agent_model(
113125
else:
114126
model_config[k] = v
115127

128+
# Auto-inject prompt caching based on model platform
129+
try:
130+
model_platform_enum = ModelPlatformType(
131+
effective_config["model_platform"].lower()
132+
)
133+
if model_platform_enum in {
134+
ModelPlatformType.ANTHROPIC,
135+
ModelPlatformType.AWS_BEDROCK_CONVERSE,
136+
}:
137+
model_config.setdefault("cache_control", "5m")
138+
elif model_platform_enum == ModelPlatformType.OPENAI:
139+
model_config.setdefault(
140+
"prompt_cache_key", str(options.project_id)
141+
)
142+
except (ValueError, AttributeError):
143+
logging.error(
144+
f"Invalid model platform: {effective_config['model_platform']}",
145+
exc_info=True,
146+
)
147+
116148
if agent_name == Agents.task_agent:
117149
model_config["stream"] = True
118150
if agent_name == Agents.browser_agent:
@@ -137,10 +169,8 @@ def agent_model(
137169
model_platform_enum = None
138170

139171
if effective_config["model_platform"].lower() == "anthropic":
140-
if model_config.get("cache_control") is None:
141-
model_config["cache_control"] = "5m"
142172
if model_config.get("max_tokens") is None:
143-
model_config["max_tokens"] = 64000
173+
model_config["max_tokens"] = 128000
144174

145175
model = ModelFactory.create(
146176
model_platform=effective_config["model_platform"],

backend/app/agent/factory/mcp.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,18 @@
1212
# limitations under the License.
1313
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
1414
import asyncio
15+
import logging
1516
import uuid
1617

1718
from camel.models import ModelFactory
19+
from camel.types import ModelPlatformType
1820

1921
from app.agent.listen_chat_agent import ListenChatAgent, logger
2022
from app.agent.prompt import MCP_SYS_PROMPT
2123
from app.agent.toolkit.mcp_search_toolkit import McpSearchToolkit
2224
from app.agent.tools import get_mcp_tools
2325
from app.model.chat import Chat
26+
from app.model.model_platform import patch_bedrock_cloud_config
2427
from app.service.task import ActionCreateAgentData, Agents, get_task_lock
2528

2629

@@ -73,6 +76,38 @@ async def mcp_agent(options: Chat):
7376
)
7477
)
7578
)
79+
extra_params = {
80+
k: v
81+
for k, v in (options.extra_params or {}).items()
82+
if k not in ["model_platform", "model_type", "api_key", "url"]
83+
}
84+
api_url = options.api_url
85+
if options.model_platform == "aws-bedrock-converse" and options.is_cloud():
86+
api_url, extra_params = patch_bedrock_cloud_config(
87+
api_url, extra_params
88+
)
89+
90+
# Build model_config_dict with prompt caching
91+
model_config_dict = {}
92+
if options.is_cloud():
93+
model_config_dict["user"] = str(options.project_id)
94+
try:
95+
platform_enum = ModelPlatformType(options.model_platform.lower())
96+
if platform_enum in {
97+
ModelPlatformType.ANTHROPIC,
98+
ModelPlatformType.AWS_BEDROCK_CONVERSE,
99+
}:
100+
model_config_dict.setdefault("cache_control", "5m")
101+
elif platform_enum == ModelPlatformType.OPENAI:
102+
model_config_dict.setdefault(
103+
"prompt_cache_key", str(options.project_id)
104+
)
105+
except (ValueError, AttributeError):
106+
logging.error(
107+
f"Invalid model platform: {options.model_platform}",
108+
exc_info=True,
109+
)
110+
76111
return ListenChatAgent(
77112
options.project_id,
78113
Agents.mcp_agent,
@@ -81,20 +116,10 @@ async def mcp_agent(options: Chat):
81116
model_platform=options.model_platform,
82117
model_type=options.model_type,
83118
api_key=options.api_key,
84-
url=options.api_url,
85-
model_config_dict=(
86-
{
87-
"user": str(options.project_id),
88-
}
89-
if options.is_cloud()
90-
else None
91-
),
119+
url=api_url,
120+
model_config_dict=model_config_dict or None,
92121
timeout=600, # 10 minutes
93-
**{
94-
k: v
95-
for k, v in (options.extra_params or {}).items()
96-
if k not in ["model_platform", "model_type", "api_key", "url"]
97-
},
122+
**extra_params,
98123
),
99124
# output_language=options.language,
100125
tools=tools,

backend/app/component/model_validation.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from camel.agents import ChatAgent
2020
from camel.models import ModelFactory, ModelProcessingError
2121

22+
from app.model.model_platform import BEDROCK_CONVERSE_REGION
23+
2224
logger = logging.getLogger("model_validation")
2325

2426
# Expected result from tool execution for validation
@@ -231,6 +233,8 @@ def create_agent(
231233
model_config_dict = dict(model_config_dict or {})
232234
if model_config_dict.get("max_tokens") is None:
233235
model_config_dict["max_tokens"] = 4096
236+
if str(platform).lower() == "aws-bedrock-converse":
237+
kwargs.setdefault("region_name", BEDROCK_CONVERSE_REGION)
234238
model = ModelFactory.create(
235239
model_platform=platform,
236240
model_type=mtype,
@@ -334,6 +338,8 @@ def validate_model_with_details(
334338
model_config_dict = dict(model_config_dict or {})
335339
if model_config_dict.get("max_tokens") is None:
336340
model_config_dict["max_tokens"] = 4096
341+
if str(model_platform).lower() == "aws-bedrock-converse":
342+
kwargs.setdefault("region_name", BEDROCK_CONVERSE_REGION)
337343
model = ModelFactory.create(
338344
model_platform=model_platform,
339345
model_type=model_type,

backend/app/model/chat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def get_uvx_env(self) -> dict[str, str]:
116116
)
117117

118118
def is_cloud(self):
119-
return self.api_url is not None and "44.247.171.124" in self.api_url
119+
return self.api_url is not None and "eigent-proxy" in self.api_url
120120

121121
def file_save_path(self, path: str | None = None):
122122
email = re.sub(r'[\\/*?:"<>|\s]', "_", self.email.split("@")[0]).strip(

backend/app/model/model_platform.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,24 @@
2424
"llama.cpp": "openai-compatible-model",
2525
}
2626

27+
# Bedrock Converse requires a region during model initialization.
28+
BEDROCK_CONVERSE_REGION: Final[str] = "us-west-2"
29+
30+
31+
def patch_bedrock_cloud_config(
32+
api_url: str, extra_params: dict
33+
) -> tuple[str, dict]:
34+
"""Patch API URL and extra_params for Bedrock Converse in cloud mode.
35+
36+
Appends '/bedrock' to the proxy URL and defaults the region.
37+
Returns the updated (api_url, extra_params).
38+
"""
39+
extra_params = dict(extra_params)
40+
extra_params.setdefault("region_name", BEDROCK_CONVERSE_REGION)
41+
if not api_url.rstrip("/").endswith("/bedrock"):
42+
api_url = api_url + "/bedrock"
43+
return api_url, extra_params
44+
2745

2846
def normalize_model_platform(platform: str) -> str:
2947
"""Normalize provider aliases to supported model platform names."""

backend/app/utils/single_agent_worker.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,15 @@
1212
# limitations under the License.
1313
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
1414

15+
import asyncio
1516
import datetime
1617
import logging
18+
from collections.abc import Awaitable, Callable
1719

18-
from camel.agents.chat_agent import AsyncStreamingChatAgentResponse
20+
from camel.agents.chat_agent import (
21+
AsyncStreamingChatAgentResponse,
22+
ChatAgentResponse,
23+
)
1924
from camel.societies.workforce.prompts import PROCESS_TASK_PROMPT
2025
from camel.societies.workforce.single_agent_worker import (
2126
SingleAgentWorker as BaseSingleAgentWorker,
@@ -67,7 +72,13 @@ def __init__(
6772
self.worker = worker # change type hint
6873

6974
async def _process_task(
70-
self, task: Task, dependencies: list[Task], stream_callback=None
75+
self,
76+
task: Task,
77+
dependencies: list[Task],
78+
stream_callback: Callable[
79+
["ChatAgentResponse"], Awaitable[None] | None
80+
]
81+
| None = None,
7182
) -> TaskState:
7283
r"""Processes a task with its dependencies using an efficient agent
7384
management system.
@@ -146,6 +157,10 @@ async def _process_task(
146157
async for chunk in response:
147158
chunk_count += 1
148159
last_chunk = chunk
160+
if stream_callback:
161+
maybe = stream_callback(chunk)
162+
if asyncio.iscoroutine(maybe):
163+
await maybe
149164
if chunk.msg and chunk.msg.content:
150165
accumulated_content += chunk.msg.content
151166
logger.info(
@@ -186,6 +201,10 @@ async def _process_task(
186201
last_chunk = None
187202
async for chunk in response:
188203
last_chunk = chunk
204+
if stream_callback:
205+
maybe = stream_callback(chunk)
206+
if asyncio.iscoroutine(maybe):
207+
await maybe
189208
if chunk.msg:
190209
if chunk.msg.content:
191210
accumulated_content += chunk.msg.content

backend/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ readme = "README.md"
66
requires-python = ">=3.11,<3.12"
77
dependencies = [
88
"pip>=23.0",
9-
"camel-ai[eigent]==0.2.90a6",
9+
"camel-ai[eigent]==0.2.90",
1010
"fastapi>=0.115.12",
1111
"fastapi-babel>=1.0.0",
1212
"uvicorn[standard]>=0.34.2",

backend/tests/app/component/test_model_validation.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,25 @@ def test_create_agent_invalid_model_platform():
215215
create_agent(model_platform=None, model_type="GPT_4O_MINI")
216216

217217

218+
@pytest.mark.unit
219+
@patch("app.component.model_validation.ModelFactory.create")
220+
@patch("app.component.model_validation.ChatAgent")
221+
def test_create_agent_hardcodes_bedrock_converse_region(
222+
mock_chat_agent, mock_model_factory
223+
):
224+
"""Test Bedrock Converse validation always uses the hardcoded region."""
225+
mock_model_factory.return_value = MagicMock()
226+
mock_chat_agent.return_value = MagicMock()
227+
228+
create_agent(
229+
model_platform="aws-bedrock-converse",
230+
model_type="anthropic.claude-3-5-sonnet",
231+
api_key="test_key",
232+
)
233+
234+
assert mock_model_factory.call_args.kwargs["region_name"] == "us-west-2"
235+
236+
218237
@pytest.mark.unit
219238
def test_validation_missing_model_type():
220239
"""Test validation with missing model type."""

backend/uv.lock

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)