Skip to content

Commit 4a737ca

Browse files
committed
do cohere specific mapping of roles
1 parent e34424e commit 4a737ca

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

sentry_sdk/ai/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class GEN_AI_ALLOWED_MESSAGE_ROLES:
3030
GEN_AI_MESSAGE_ROLE_REVERSE_MAPPING = {
3131
GEN_AI_ALLOWED_MESSAGE_ROLES.SYSTEM: ["system"],
3232
GEN_AI_ALLOWED_MESSAGE_ROLES.USER: ["user", "human"],
33-
GEN_AI_ALLOWED_MESSAGE_ROLES.ASSISTANT: ["assistant", "ai", "chatbot"],
33+
GEN_AI_ALLOWED_MESSAGE_ROLES.ASSISTANT: ["assistant", "ai"],
3434
GEN_AI_ALLOWED_MESSAGE_ROLES.TOOL: ["tool", "tool_call"],
3535
}
3636

@@ -503,7 +503,7 @@ def normalize_message_role(role: str) -> str:
503503
Normalize a message role to one of the 4 allowed gen_ai role values.
504504
Maps "ai" -> "assistant" and keeps other standard roles unchanged.
505505
"""
506-
return GEN_AI_MESSAGE_ROLE_MAPPING.get(role.lower(), role)
506+
return GEN_AI_MESSAGE_ROLE_MAPPING.get(role, role)
507507

508508

509509
def normalize_message_roles(messages: "list[dict[str, Any]]") -> "list[dict[str, Any]]":

sentry_sdk/integrations/cohere.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from sentry_sdk import consts
55
from sentry_sdk.ai.monitoring import record_token_usage
66
from sentry_sdk.consts import OP, SPANDATA
7-
from sentry_sdk.ai.utils import set_data_normalized, normalize_message_roles
7+
from sentry_sdk.ai.utils import set_data_normalized
88

99
from typing import TYPE_CHECKING
1010

@@ -39,6 +39,14 @@
3939
from cohere import StreamedChatResponse_StreamEnd as StreamEndStreamedChatResponse
4040

4141

42+
COHERE_ROLE_MAPPING = {
43+
"SYSTEM": "system",
44+
"USER": "user",
45+
"CHATBOT": "assistant",
46+
"TOOL": "tool",
47+
}
48+
49+
4250
COLLECTED_CHAT_PARAMS = {
4351
"model": SPANDATA.GEN_AI_REQUEST_MODEL,
4452
"k": SPANDATA.GEN_AI_REQUEST_TOP_K,
@@ -157,14 +165,14 @@ def new_chat(*args: "Any", **kwargs: "Any") -> "Any":
157165
if should_send_default_pii() and integration.include_prompts:
158166
messages = []
159167
for x in kwargs.get("chat_history", []):
168+
role = getattr(x, "role", "")
160169
messages.append(
161170
{
162-
"role": getattr(x, "role", ""),
171+
"role": COHERE_ROLE_MAPPING.get(role, role),
163172
"content": getattr(x, "message", ""),
164173
}
165174
)
166175
messages.append({"role": "user", "content": message})
167-
messages = normalize_message_roles(messages)
168176
set_data_normalized(
169177
span,
170178
SPANDATA.GEN_AI_REQUEST_MESSAGES,

0 commit comments

Comments
 (0)