Skip to content

Commit a0186b5

Browse files
committed
fix: preserve persona on new conversation creation
1 parent 55ed028 commit a0186b5

7 files changed

Lines changed: 148 additions & 12 deletions

File tree

astrbot/builtin_stars/builtin_commands/commands/conversation.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
DEERFLOW_PROVIDER_TYPE,
77
DEERFLOW_THREAD_ID_KEY,
88
)
9+
from astrbot.core.persona_utils import (
10+
is_persona_none_marker,
11+
normalize_persona_id,
12+
)
913
from astrbot.core.platform.astr_message_event import MessageSession
1014
from astrbot.core.platform.message_type import MessageType
1115
from astrbot.core.utils.active_event_registry import active_event_registry
@@ -37,7 +41,7 @@ async def _get_current_persona_id(self, session_id):
3741
)
3842
if not conv:
3943
return None
40-
return conv.persona_id
44+
return normalize_persona_id(conv.persona_id)
4145

4246
async def reset(self, message: AstrMessageEvent) -> None:
4347
"""重置 LLM 会话"""
@@ -225,7 +229,7 @@ async def convs(self, message: AstrMessageEvent, page: int = 1) -> None:
225229
platform_name=platform_name,
226230
provider_settings=provider_settings,
227231
)
228-
if persona_id == "[%None]":
232+
if is_persona_none_marker(persona_id):
229233
persona_name = "无"
230234
elif persona_id:
231235
persona_name = persona_id

astrbot/builtin_stars/builtin_commands/commands/persona.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33

44
from astrbot.api import star
55
from astrbot.api.event import AstrMessageEvent, MessageEventResult
6+
from astrbot.core.persona_utils import (
7+
PERSONA_NONE_MARKER,
8+
is_persona_none_marker,
9+
)
610

711
if TYPE_CHECKING:
812
from astrbot.core.db.po import Persona
@@ -92,7 +96,7 @@ async def persona(self, message: AstrMessageEvent) -> None:
9296
provider_settings=provider_settings,
9397
)
9498

95-
if persona_id == "[%None]":
99+
if is_persona_none_marker(persona_id):
96100
curr_persona_name = "无"
97101
elif persona_id:
98102
curr_persona_name = persona_id
@@ -174,7 +178,7 @@ async def persona(self, message: AstrMessageEvent) -> None:
174178
return
175179
await self.context.conversation_manager.update_conversation_persona_id(
176180
message.unified_msg_origin,
177-
"[%None]",
181+
PERSONA_NONE_MARKER,
178182
)
179183
message.set_result(MessageEventResult().message("取消人格成功。"))
180184
else:

astrbot/core/conversation_mgr.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from astrbot.core.agent.message import AssistantMessageSegment, UserMessageSegment
1212
from astrbot.core.db import BaseDatabase
1313
from astrbot.core.db.po import Conversation, ConversationV2
14+
from astrbot.core.persona_utils import normalize_persona_id
1415
from astrbot.core.utils.datetime_utils import to_utc_timestamp
1516

1617

@@ -98,6 +99,12 @@ async def new_conversation(
9899
platform_id = parts[0]
99100
if not platform_id:
100101
platform_id = "unknown"
102+
if persona_id is None:
103+
curr_cid = await self.get_curr_conversation_id(unified_msg_origin)
104+
if curr_cid:
105+
curr_conv = await self.db.get_conversation_by_id(cid=curr_cid)
106+
if curr_conv:
107+
persona_id = normalize_persona_id(curr_conv.persona_id)
101108
conv = await self.db.create_conversation(
102109
user_id=unified_msg_origin,
103110
platform_id=platform_id,

astrbot/core/persona_mgr.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
44
from astrbot.core.db import BaseDatabase
55
from astrbot.core.db.po import Persona, PersonaFolder, Personality
6+
from astrbot.core.persona_utils import is_persona_none_marker
67
from astrbot.core.platform.message_session import MessageSession
78
from astrbot.core.sentinels import NOT_GIVEN
89

@@ -104,7 +105,7 @@ async def resolve_selected_persona(
104105

105106
if not persona_id:
106107
persona_id = conversation_persona_id
107-
if persona_id == "[%None]":
108+
if is_persona_none_marker(persona_id):
108109
pass
109110
elif persona_id is None:
110111
persona_id = (provider_settings or {}).get("default_personality")
@@ -115,7 +116,11 @@ async def resolve_selected_persona(
115116
)
116117

117118
use_webchat_special_default = False
118-
if not persona and platform_name == "webchat" and persona_id != "[%None]":
119+
if (
120+
not persona
121+
and platform_name == "webchat"
122+
and not is_persona_none_marker(persona_id)
123+
):
119124
persona_id = "_chatui_default_"
120125
use_webchat_special_default = True
121126

astrbot/core/persona_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""Helpers for persona marker handling."""
2+
3+
PERSONA_NONE_MARKER = "[%None]"
4+
5+
6+
def is_persona_none_marker(persona_id: str | None) -> bool:
7+
"""Return whether the persona id is the explicit no-persona marker."""
8+
return persona_id == PERSONA_NONE_MARKER
9+
10+
11+
def normalize_persona_id(persona_id: str | None) -> str | None:
12+
"""Normalize the explicit no-persona marker to None."""
13+
if is_persona_none_marker(persona_id):
14+
return None
15+
return persona_id

tests/unit/test_astr_main_agent.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from astrbot.core.agent.tool import FunctionTool, ToolSet
1111
from astrbot.core.conversation_mgr import Conversation
1212
from astrbot.core.message.components import File, Image, Plain, Reply
13+
from astrbot.core.persona_utils import PERSONA_NONE_MARKER
1314
from astrbot.core.platform.astr_message_event import AstrMessageEvent
1415
from astrbot.core.platform.platform_metadata import PlatformMetadata
1516
from astrbot.core.provider import Provider
@@ -504,14 +505,14 @@ async def test_ensure_persona_from_conversation(self, mock_event, mock_context):
504505

505506
@pytest.mark.asyncio
506507
async def test_ensure_persona_none_explicit(self, mock_event, mock_context):
507-
"""Test that [%None] persona is explicitly set to no persona."""
508+
"""Test that the explicit no-persona marker is treated as no persona."""
508509
module = ama
509510
mock_context.persona_manager.personas_v3 = []
510511
mock_context.persona_manager.resolve_selected_persona = AsyncMock(
511-
return_value=("[%None]", None, None, False)
512+
return_value=(PERSONA_NONE_MARKER, None, None, False)
512513
)
513514
req = ProviderRequest()
514-
req.conversation = MagicMock(persona_id="[%None]")
515+
req.conversation = MagicMock(persona_id=PERSONA_NONE_MARKER)
515516

516517
await module._ensure_persona_and_skills(req, {}, mock_context, mock_event)
517518

@@ -565,9 +566,10 @@ async def test_subagent_dedupe_uses_default_persona_tools(
565566
tmgr = mock_context.get_llm_tool_manager.return_value
566567
tmgr.func_list = [tool_a, tool_b]
567568
tmgr.get_full_tool_set.return_value = ToolSet([tool_a, tool_b])
568-
tmgr.get_func.side_effect = lambda name: {"tool_a": tool_a, "tool_b": tool_b}.get(
569-
name
570-
)
569+
tmgr.get_func.side_effect = lambda name: {
570+
"tool_a": tool_a,
571+
"tool_b": tool_b,
572+
}.get(name)
571573

572574
handoff = MagicMock()
573575
handoff.name = "transfer_to_planner"
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""Tests for conversation persona inheritance behavior."""
2+
3+
from types import SimpleNamespace
4+
from unittest.mock import AsyncMock, MagicMock, patch
5+
6+
import pytest
7+
8+
from astrbot.builtin_stars.builtin_commands.commands.conversation import (
9+
ConversationCommands,
10+
)
11+
from astrbot.core.conversation_mgr import ConversationManager
12+
from astrbot.core.persona_utils import PERSONA_NONE_MARKER
13+
14+
15+
@pytest.mark.asyncio
16+
async def test_new_conversation_inherits_current_persona_when_not_provided():
17+
db = MagicMock()
18+
db.get_conversation_by_id = AsyncMock(
19+
return_value=SimpleNamespace(persona_id="psychologist")
20+
)
21+
db.create_conversation = AsyncMock(
22+
return_value=SimpleNamespace(conversation_id="new-cid")
23+
)
24+
25+
manager = ConversationManager(db)
26+
manager.session_conversations["test:private:u1"] = "old-cid"
27+
28+
with patch(
29+
"astrbot.core.conversation_mgr.sp.session_put",
30+
new=AsyncMock(return_value=None),
31+
):
32+
await manager.new_conversation("test:private:u1", platform_id="test")
33+
34+
assert db.create_conversation.await_args.kwargs["persona_id"] == "psychologist"
35+
36+
37+
@pytest.mark.asyncio
38+
async def test_new_conversation_does_not_inherit_persona_none_marker():
39+
db = MagicMock()
40+
db.get_conversation_by_id = AsyncMock(
41+
return_value=SimpleNamespace(persona_id=PERSONA_NONE_MARKER)
42+
)
43+
db.create_conversation = AsyncMock(
44+
return_value=SimpleNamespace(conversation_id="new-cid")
45+
)
46+
47+
manager = ConversationManager(db)
48+
manager.session_conversations["test:private:u1"] = "old-cid"
49+
50+
with patch(
51+
"astrbot.core.conversation_mgr.sp.session_put",
52+
new=AsyncMock(return_value=None),
53+
):
54+
await manager.new_conversation("test:private:u1", platform_id="test")
55+
56+
assert db.create_conversation.await_args.kwargs["persona_id"] is None
57+
58+
59+
@pytest.mark.asyncio
60+
async def test_new_conversation_keeps_explicit_persona_id():
61+
db = MagicMock()
62+
db.get_conversation_by_id = AsyncMock(
63+
return_value=SimpleNamespace(persona_id="psychologist")
64+
)
65+
db.create_conversation = AsyncMock(
66+
return_value=SimpleNamespace(conversation_id="new-cid")
67+
)
68+
69+
manager = ConversationManager(db)
70+
manager.session_conversations["test:private:u1"] = "old-cid"
71+
72+
with patch(
73+
"astrbot.core.conversation_mgr.sp.session_put",
74+
new=AsyncMock(return_value=None),
75+
):
76+
await manager.new_conversation(
77+
"test:private:u1",
78+
platform_id="test",
79+
persona_id="teacher",
80+
)
81+
82+
assert db.create_conversation.await_args.kwargs["persona_id"] == "teacher"
83+
84+
85+
@pytest.mark.asyncio
86+
async def test_get_current_persona_id_returns_none_for_none_marker():
87+
context = MagicMock()
88+
context.conversation_manager.get_curr_conversation_id = AsyncMock(
89+
return_value="old-cid"
90+
)
91+
context.conversation_manager.get_conversation = AsyncMock(
92+
return_value=MagicMock(persona_id=PERSONA_NONE_MARKER)
93+
)
94+
95+
command = ConversationCommands(context)
96+
97+
result = await command._get_current_persona_id("test:private:u1")
98+
99+
assert result is None

0 commit comments

Comments
 (0)