Skip to content

Commit de443fd

Browse files
authored
feat(bot): Bot session compact by using openviking, ov client add function and param (volcengine#2284)
* bot session compact by using openviking
* fix bug
* fix bug
* default values (默认值)
* default values (默认值)
* fix pr bug
* fix merge bug
1 parent 676654b commit de443fd

21 files changed

Lines changed: 3077 additions & 148 deletions

bot/docs/vikingbot-openviking-context-plan.md

Lines changed: 464 additions & 0 deletions
Large diffs are not rendered by default.

bot/tests/test_agent_loop_outcome.py

Lines changed: 622 additions & 0 deletions
Large diffs are not rendered by default.

bot/tests/test_openviking_api_key_type.py

Lines changed: 747 additions & 8 deletions
Large diffs are not rendered by default.

bot/vikingbot/agent/loop.py

Lines changed: 314 additions & 21 deletions
Large diffs are not rendered by default.

bot/vikingbot/config/schema.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,10 @@ class AgentsConfig(BaseModel):
433433
model: str = "openai/doubao-seed-2-0-pro-260215"
434434
max_tool_iterations: int = 50
435435
memory_window: int = 50
436+
session_context_enabled: bool = False
437+
session_context_token_budget: int = 3000
438+
commit_token_threshold: int = 20000
439+
commit_keep_recent_count: int = 5
436440
gen_image_model: str = "openai/doubao-seedream-4-5-251128"
437441
provider: str = ""
438442
api_key: str = ""

bot/vikingbot/heartbeat/service.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,6 @@ async def _tick(self) -> None:
175175
active_workspaces = 0
176176

177177
for workspace_path, session_key_list in workspaces.items():
178-
logger.debug(f"Heartbeat: checking workspace {workspace_path}...")
179-
180178
content = _read_heartbeat_file(workspace_path)
181179

182180
# Skip if HEARTBEAT.md is empty or doesn't exist
@@ -211,9 +209,6 @@ async def _tick(self) -> None:
211209
except Exception as e:
212210
logger.exception(f"Heartbeat execution failed for {workspace_path}: {e}")
213211

214-
if active_workspaces == 0:
215-
logger.debug("Heartbeat: no tasks in any workspace")
216-
217212
async def trigger_now(self, session_key: SessionKey | None = None) -> str | None:
218213
"""Manually trigger a heartbeat."""
219214
if self.on_heartbeat:

bot/vikingbot/hooks/builtins/openviking_hooks.py

Lines changed: 274 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
11
import asyncio
22
import re
33
from collections import defaultdict
4+
from datetime import datetime
45
from typing import Any
56

67
from loguru import logger
78

89
from vikingbot.config.loader import load_config
10+
from vikingbot.openviking_mount.session_state import (
11+
get_openviking_session_id,
12+
get_openviking_state,
13+
get_unsynced_messages,
14+
get_unsynced_messages_for_sender,
15+
parse_local_index,
16+
set_sender_synced_local_index,
17+
)
18+
from vikingbot.utils.helpers import cal_str_tokens
919

1020
from ...session import Session
1121
from ..base import Hook, HookContext
@@ -40,15 +50,273 @@ class OpenVikingCompactHook(Hook):
4050
async def _get_client(self, workspace_id: str) -> VikingClient:
    """Return the client obtained by awaiting ``get_global_client(workspace_id)``."""
    viking_client = await get_global_client(workspace_id)
    return viking_client
4252

53+
async def _execute_session_context_commit(
    self,
    context: HookContext,
    session: Session,
    client: VikingClient,
    agents_config: Any,
    admin_user_id: str,
    *,
    force_commit: bool,
    keep_recent_count: int,
    commit_message_threshold: int | None,
) -> dict[str, Any]:
    """Append unsynced messages to OpenViking and commit when a threshold is hit.

    The method works in three stages:

    1. Fan out: for every non-admin sender, append that sender's unsynced
       messages to the per-sender session ``f"{session_id}_{user_id}"``.
    2. Append all unsynced messages to the admin session ``session_id`` and
       refresh ``pending_tokens`` from ``client.get_session``.
    3. If ``force_commit`` is set, the pending token count reaches
       ``agents_config.commit_token_threshold``, or the number of messages
       since the last commit reaches ``commit_message_threshold``, commit
       every per-sender session and then the admin session.

    Bookkeeping (``last_pending_tokens``, ``last_commit_performed``,
    ``last_sync_status``, ``last_sync_error``, ``last_synced_local_index``,
    ``last_commit_at``, ``last_commit_local_index``) is written into the dict
    returned by ``get_openviking_state(session)``.

    Args:
        context: Hook context; ``context.session_key.safe_name()`` is the
            default OpenViking session id.
        session: Local bot session whose ``messages`` are synced.
        client: OpenViking client used for append / get / commit calls.
        agents_config: Object read via ``getattr`` for
            ``commit_token_threshold``.
        admin_user_id: User id that owns the admin session; messages from
            this sender are excluded from the per-sender fan-out.
        force_commit: Commit regardless of thresholds.
        keep_recent_count: Forwarded to ``client.commit_session``.
        commit_message_threshold: Optional message-count trigger; ignored
            when ``None`` or not positive.

    Returns:
        A dict with ``success``, ``session_id``, ``admin_result``,
        ``user_results``, ``users_count`` and ``pending_tokens``; on a
        fan-out failure it also carries ``error`` and ``success`` is False.
        Exceptions raised by the admin-session calls are not caught here.
    """
    state = get_openviking_state(session)
    session_id = get_openviking_session_id(
        session,
        default_session_id=context.session_key.safe_name(),
    )

    # NOTE(review): the 6000 fallback differs from the AgentsConfig schema
    # default for commit_token_threshold (20000) — confirm which is intended.
    commit_token_threshold = int(getattr(agents_config, "commit_token_threshold", 6000) or 6000)
    pending_tokens = int(state.get("last_pending_tokens", 0) or 0)
    messages_to_sync = get_unsynced_messages(session)
    # Snapshot the newest local index together with ``messages_to_sync``.
    # ``session.messages`` may grow while we await network calls below, and
    # messages that arrive in the meantime must not be recorded as synced.
    latest_local_index = len(session.messages) - 1
    last_commit_local_index = parse_local_index(state.get("last_commit_local_index", -1))
    messages_since_commit = latest_local_index - last_commit_local_index
    reached_message_threshold = bool(
        commit_message_threshold is not None
        and commit_message_threshold > 0
        and messages_since_commit >= commit_message_threshold
    )

    admin_append_result = None
    admin_commit_result = None
    user_results: list[Any] = []
    pending_tokens_before_sync = pending_tokens

    unsynced_tokens = sum(
        cal_str_tokens(str(msg.get("content") or ""))
        for msg in messages_to_sync
        if msg.get("content") is not None
    )

    # Single pass over the messages: remember the last index seen for each
    # non-admin sender. Keys are stringified so they match the ids used for
    # the per-sender sessions even when ``sender_id`` is stored as a non-str.
    latest_index_by_sender: dict[str, int] = {}
    for index, msg in enumerate(session.messages):
        raw_sender_id = msg.get("sender_id")
        if raw_sender_id and raw_sender_id != admin_user_id:
            latest_index_by_sender[str(raw_sender_id)] = index
    all_sender_ids = sorted(latest_index_by_sender)

    unsynced_messages_by_sender = {}
    sender_latest_indexes: dict[str, int] = {}
    for sender_id in all_sender_ids:
        user_messages = get_unsynced_messages_for_sender(
            session,
            sender_id,
            admin_user_id=admin_user_id,
        )
        if user_messages:
            unsynced_messages_by_sender[sender_id] = user_messages
            sender_latest_indexes[sender_id] = latest_index_by_sender[sender_id]

    # Preliminary decision, based on the locally estimated token count; it
    # only selects which senders take part in the append fan-out.
    should_commit = bool(
        force_commit
        or pending_tokens + unsynced_tokens >= commit_token_threshold
        or reached_message_threshold
    )
    sender_ids_to_sync = (
        all_sender_ids if should_commit else sorted(unsynced_messages_by_sender)
    )
    user_results_by_id: dict[str, dict[str, Any]] = {}

    if sender_ids_to_sync:
        semaphore = asyncio.Semaphore(5)

        async def sync_sender(user_id: str):
            user_messages = unsynced_messages_by_sender.get(user_id, [])
            async with semaphore:
                sender_session_id = f"{session_id}_{user_id}"
                append_result = None
                if user_messages:
                    append_result = await client.append_messages(
                        session_id=sender_session_id,
                        messages=user_messages,
                        default_user_role_id=user_id,
                        session_user_id=user_id,
                    )
                return {
                    "session_id": sender_session_id,
                    "user_id": user_id,
                    "append": append_result,
                }

        user_results = await asyncio.gather(
            *(sync_sender(user_id) for user_id in sender_ids_to_sync),
            return_exceptions=True,
        )
        for result in user_results:
            if isinstance(result, dict):
                user_id = result["user_id"]
                user_results_by_id[user_id] = result
                if result.get("append") is not None and user_id in sender_latest_indexes:
                    set_sender_synced_local_index(
                        session, user_id, sender_latest_indexes[user_id]
                    )

    # gather(return_exceptions=True) may hand back BaseException instances
    # (e.g. CancelledError); those must count as failures too.
    fanout_errors = [result for result in user_results if isinstance(result, BaseException)]
    if fanout_errors:
        error_message = "; ".join(str(error) for error in fanout_errors)
        state["last_pending_tokens"] = pending_tokens_before_sync
        state["last_commit_performed"] = False
        state["last_sync_status"] = "error"
        state["last_sync_error"] = error_message
        return {
            "success": False,
            "session_id": session_id,
            "admin_result": {
                "append": admin_append_result,
                "commit": admin_commit_result,
                "committed": False,
            },
            "user_results": user_results,
            "users_count": len(sender_ids_to_sync),
            "pending_tokens": pending_tokens_before_sync,
            "error": error_message,
        }

    if messages_to_sync:
        admin_append_result = await client.append_messages(
            session_id=session_id,
            messages=messages_to_sync,
            default_user_role_id=admin_user_id,
            session_user_id=admin_user_id,
        )
        state["last_synced_local_index"] = latest_local_index
        admin_session_state = await client.get_session(session_id, user_id=admin_user_id)
        pending_tokens = int(admin_session_state.get("pending_tokens", 0) or 0)
    elif force_commit:
        admin_session_state = await client.get_session(session_id, user_id=admin_user_id)
        pending_tokens = int(admin_session_state.get("pending_tokens", 0) or 0)

    # Final decision, now using the pending token count reported by the server.
    should_commit = (
        force_commit or pending_tokens >= commit_token_threshold or reached_message_threshold
    )
    if should_commit:
        sender_ids_to_commit = all_sender_ids
        for user_id in sender_ids_to_commit:
            user_results_by_id.setdefault(
                user_id,
                {
                    "session_id": f"{session_id}_{user_id}",
                    "user_id": user_id,
                    "append": None,
                },
            )

        if sender_ids_to_commit:
            semaphore = asyncio.Semaphore(5)

            async def commit_sender(user_id: str):
                async with semaphore:
                    sender_session_id = f"{session_id}_{user_id}"
                    commit_result = await client.commit_session(
                        session_id=sender_session_id,
                        keep_recent_count=keep_recent_count,
                        user_id=user_id,
                    )
                    # Log only after the commit call has returned, so a
                    # failed commit is never reported as "Committed".
                    logger.info(
                        f"[HOOK] Committed session {sender_session_id} for user {user_id}"
                    )
                    return commit_result

            commit_results = await asyncio.gather(
                *(commit_sender(user_id) for user_id in sender_ids_to_commit),
                return_exceptions=True,
            )
            for user_id, commit_result in zip(
                sender_ids_to_commit, commit_results, strict=True
            ):
                user_results_by_id[user_id]["commit"] = commit_result
            fanout_errors = [
                result for result in commit_results if isinstance(result, BaseException)
            ]
            if fanout_errors:
                error_message = "; ".join(str(error) for error in fanout_errors)
                state["last_pending_tokens"] = pending_tokens
                state["last_commit_performed"] = False
                state["last_sync_status"] = "error"
                state["last_sync_error"] = error_message
                return {
                    "success": False,
                    "session_id": session_id,
                    "admin_result": {
                        "append": admin_append_result,
                        "commit": admin_commit_result,
                        "committed": False,
                    },
                    "user_results": list(user_results_by_id.values()),
                    "users_count": len(user_results_by_id),
                    "pending_tokens": pending_tokens,
                    "error": error_message,
                }

        # The admin session is committed only after every sender commit succeeded.
        admin_commit_result = await client.commit_session(
            session_id=session_id,
            keep_recent_count=keep_recent_count,
            user_id=admin_user_id,
        )
        logger.info(f"[HOOK] Committed session {session_id} for user {admin_user_id}")
        admin_session_state = await client.get_session(session_id, user_id=admin_user_id)
        pending_tokens = int(admin_session_state.get("pending_tokens", 0) or 0)

        state["last_commit_at"] = datetime.now().isoformat()
        state["last_commit_local_index"] = latest_local_index

    state["last_pending_tokens"] = pending_tokens
    state["last_commit_performed"] = should_commit
    state["last_sync_status"] = "success"
    state.pop("last_sync_error", None)

    return {
        "success": True,
        "session_id": session_id,
        "admin_result": {
            "append": admin_append_result,
            "commit": admin_commit_result,
            "committed": should_commit,
        },
        "user_results": list(user_results_by_id.values()),
        "users_count": len(user_results_by_id),
        "pending_tokens": pending_tokens,
    }
285+
43286
async def execute(self, context: HookContext, **kwargs) -> Any:
44287
vikingbot_session: Session = kwargs.get("session", {})
45288
session_id = context.session_key.safe_name()
46289
config = load_config()
47-
admin_user_id = config.ov_server.admin_user_id
290+
ov_config = config.ov_server
291+
agents_config = config.agents
292+
admin_user_id = ov_config.admin_user_id
293+
force_commit = bool(kwargs.get("force_commit", False))
294+
keep_recent_count = int(
295+
kwargs.get(
296+
"keep_recent_count",
297+
getattr(agents_config, "commit_keep_recent_count", 10),
298+
)
299+
or 0
300+
)
301+
commit_message_threshold = kwargs.get("commit_message_threshold")
302+
if commit_message_threshold is not None:
303+
commit_message_threshold = int(commit_message_threshold)
48304

49305
try:
50306
client = await self._get_client(context.workspace_id)
51307

308+
if getattr(agents_config, "session_context_enabled", False):
309+
return await self._execute_session_context_commit(
310+
context,
311+
vikingbot_session,
312+
client,
313+
agents_config,
314+
admin_user_id,
315+
force_commit=force_commit,
316+
keep_recent_count=keep_recent_count,
317+
commit_message_threshold=commit_message_threshold,
318+
)
319+
52320
if not client.should_sender_fanout():
53321
single_result = await client.commit(session_id, vikingbot_session.messages, None)
54322
return {
@@ -92,6 +360,11 @@ async def commit_with_semaphore(user_id: str, user_messages: list):
92360
"users_count": len(messages_by_sender),
93361
}
94362
except Exception as e:
363+
state = None
364+
if hasattr(vikingbot_session, "metadata"):
365+
state = get_openviking_state(vikingbot_session)
366+
state["last_sync_status"] = "error"
367+
state["last_sync_error"] = str(e)
95368
logger.exception(f"Failed to add message to OpenViking: {e}")
96369
return {"success": False, "error": str(e)}
97370

bot/vikingbot/hooks/manager.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import importlib
33
from collections import defaultdict
4-
from typing import List, Any, Dict, Type
4+
from typing import Any, Dict, List, Type
55

66
from loguru import logger
77

@@ -47,9 +47,6 @@ async def execute_hooks(self, context: HookContext, **kwargs) -> List[Any]:
4747
async_hooks = [hook for hook in self._hooks[context.event_type] if not hook.is_sync]
4848
sync_hooks = [hook for hook in self._hooks[context.event_type] if hook.is_sync]
4949
if async_hooks:
50-
logger.debug(
51-
f"Executing {len(async_hooks)} async hooks for event '{context.event_type}'"
52-
)
5350
async_results = await asyncio.gather(
5451
*[hook.execute(context, **kwargs) for hook in async_hooks], return_exceptions=True
5552
)

0 commit comments

Comments
 (0)