Skip to content

Commit a62ae9d

Browse files
committed
Add conversations support for Responses API: modifications from reviews
This is a rebase of the changes that remained after PR-866 was merged.
1 parent 364707e commit a62ae9d

5 files changed

Lines changed: 112 additions & 85 deletions

File tree

src/app/endpoints/query.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
consume_tokens,
6666
get_available_quotas,
6767
)
68+
from utils.suid import normalize_conversation_id
6869
from utils.token_counter import TokenCounter, extract_and_update_token_metrics
6970
from utils.transcripts import store_transcript
7071
from utils.types import TurnSummary, content_to_str
@@ -109,14 +110,23 @@ def persist_user_conversation_details(
109110
topic_summary: Optional[str],
110111
) -> None:
111112
"""Associate conversation to user in the database."""
113+
# Normalize the conversation ID (strip 'conv_' prefix if present)
114+
normalized_id = normalize_conversation_id(conversation_id)
115+
logger.debug(
116+
"persist_user_conversation_details - original conv_id: %s, normalized: %s, user: %s",
117+
conversation_id,
118+
normalized_id,
119+
user_id,
120+
)
121+
112122
with get_session() as session:
113123
existing_conversation = (
114-
session.query(UserConversation).filter_by(id=conversation_id).first()
124+
session.query(UserConversation).filter_by(id=normalized_id).first()
115125
)
116126

117127
if not existing_conversation:
118128
conversation = UserConversation(
119-
id=conversation_id,
129+
id=normalized_id,
120130
user_id=user_id,
121131
last_used_model=model,
122132
last_used_provider=provider_id,
@@ -125,15 +135,24 @@ def persist_user_conversation_details(
125135
)
126136
session.add(conversation)
127137
logger.debug(
128-
"Associated conversation %s to user %s", conversation_id, user_id
138+
"Associated conversation %s to user %s", normalized_id, user_id
129139
)
130140
else:
131141
existing_conversation.last_used_model = model
132142
existing_conversation.last_used_provider = provider_id
133143
existing_conversation.last_message_at = datetime.now(UTC)
134144
existing_conversation.message_count += 1
145+
logger.debug(
146+
"Updating existing conversation in DB - ID: %s, User: %s, Messages: %d",
147+
normalized_id,
148+
user_id,
149+
existing_conversation.message_count,
150+
)
135151

136152
session.commit()
153+
logger.debug(
154+
"Successfully committed conversation %s to database", normalized_id
155+
)
137156

138157

139158
def evaluate_model_hints(
@@ -257,9 +276,13 @@ async def query_endpoint_handler_base( # pylint: disable=R0914
257276
logger.debug(
258277
"Conversation ID specified in query: %s", query_request.conversation_id
259278
)
279+
# Normalize the conversation ID for database lookup (strip conv_ prefix if present)
280+
normalized_conv_id_for_lookup = normalize_conversation_id(
281+
query_request.conversation_id
282+
)
260283
user_conversation = validate_conversation_ownership(
261284
user_id=user_id,
262-
conversation_id=query_request.conversation_id,
285+
conversation_id=normalized_conv_id_for_lookup,
263286
others_allowed=(
264287
Action.QUERY_OTHERS_CONVERSATIONS in request.state.authorized_actions
265288
),

src/app/endpoints/query_v2.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -439,15 +439,13 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
439439
conversation_id,
440440
)
441441

442-
# Normalize conversation ID before returning (remove conv_ prefix for consistency)
443-
normalized_conversation_id = (
444-
normalize_conversation_id(conversation_id)
445-
if conversation_id
446-
else conversation_id
442+
return (
443+
summary,
444+
normalize_conversation_id(conversation_id),
445+
referenced_documents,
446+
token_usage,
447447
)
448448

449-
return (summary, normalized_conversation_id, referenced_documents, token_usage)
450-
451449

452450
def parse_referenced_documents_from_responses_api(
453451
response: OpenAIResponseObject, # pylint: disable=unused-argument

src/app/endpoints/streaming_query_v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,6 @@ async def retrieve_response( # pylint: disable=too-many-locals
456456
response_stream = cast(AsyncIterator[OpenAIResponseObjectStream], response)
457457
# async for chunk in response_stream:
458458
# logger.error("Chunk: %s", chunk.model_dump_json())
459-
# Return the normalized conversation_id (already normalized above)
459+
# Return the normalized conversation_id
460460
# The response_generator will emit it in the start event
461-
return response_stream, conversation_id
461+
return response_stream, normalize_conversation_id(conversation_id)

src/utils/suid.py

Lines changed: 20 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -23,75 +23,31 @@ def check_suid(suid: str) -> bool:
2323
Returns True if the string is a valid UUID or a llama-stack conversation ID.
2424
2525
Parameters:
26-
suid (str | bytes): UUID value to validate — accepts a UUID string,
27-
its byte representation, or a llama-stack conversation ID (conv_xxx),
28-
or a plain hex string (database format).
26+
suid (str): UUID value to validate — accepts a UUID string,
27+
or a llama-stack conversation ID (48-char hex, optionally with conv_ prefix).
2928
3029
Notes:
31-
Validation is performed by:
32-
1. For llama-stack conversation IDs starting with 'conv_':
33-
- Strips the 'conv_' prefix
34-
- Validates at least 32 hex characters follow (may have additional suffix)
35-
- Extracts first 32 hex chars as the UUID part
36-
- Converts to UUID format by inserting hyphens at standard positions
37-
- Validates the resulting UUID structure
38-
2. For plain hex strings (database format, 32+ chars without conv_ prefix):
39-
- Validates it's a valid hex string
40-
- Extracts first 32 chars as UUID part
41-
- Converts to UUID format and validates
42-
3. For standard UUIDs: attempts to construct uuid.UUID(suid)
43-
Invalid formats or types result in False.
30+
Validation accepts:
31+
1. Standard UUID format (e.g., '550e8400-e29b-41d4-a716-446655440000'), or the same 32 hex digits without hyphens
32+
2. 48-character hex string (llama-stack format)
33+
3. 'conv_' prefix + 48-character hex string (53 chars total)
4434
"""
45-
try:
46-
# Accept llama-stack conversation IDs (conv_<hex> format)
47-
if isinstance(suid, str) and suid.startswith("conv_"):
48-
# Extract the hex string after 'conv_'
49-
hex_part = suid[5:] # Remove 'conv_' prefix
50-
51-
# Verify it's a valid hex string
52-
# llama-stack may use 32 hex chars (UUID) or 36 hex chars (UUID + suffix)
53-
if len(hex_part) < 32:
54-
return False
55-
56-
# Verify all characters are valid hex
57-
try:
58-
int(hex_part, 16)
59-
except ValueError:
60-
return False
61-
62-
# Extract the first 32 hex characters (the UUID part)
63-
uuid_hex = hex_part[:32]
64-
65-
# Convert to UUID format with hyphens: 8-4-4-4-12
66-
uuid_str = (
67-
f"{uuid_hex[:8]}-{uuid_hex[8:12]}-{uuid_hex[12:16]}-"
68-
f"{uuid_hex[16:20]}-{uuid_hex[20:]}"
69-
)
70-
71-
# Validate it's a proper UUID
72-
uuid.UUID(uuid_str)
35+
if not isinstance(suid, str):
36+
return False
37+
38+
# Strip 'conv_' prefix if present
39+
hex_part = suid[5:] if suid.startswith("conv_") else suid
40+
41+
# Check for 48-char hex string (llama-stack conversation ID format)
42+
if len(hex_part) == 48:
43+
try:
44+
int(hex_part, 16)
7345
return True
46+
except ValueError:
47+
return False
7448

75-
# Check if it's a plain hex string (database format without conv_ prefix)
76-
if isinstance(suid, str) and len(suid) >= 32:
77-
try:
78-
int(suid, 16)
79-
# Extract the first 32 hex characters (the UUID part)
80-
uuid_hex = suid[:32]
81-
82-
# Convert to UUID format with hyphens: 8-4-4-4-12
83-
uuid_str = (
84-
f"{uuid_hex[:8]}-{uuid_hex[8:12]}-{uuid_hex[12:16]}-"
85-
f"{uuid_hex[16:20]}-{uuid_hex[20:]}"
86-
)
87-
88-
# Validate it's a proper UUID
89-
uuid.UUID(uuid_str)
90-
return True
91-
except ValueError:
92-
pass # Not a valid hex string, try standard UUID validation
93-
94-
# accepts strings and bytes only for UUID validation
49+
# Check for standard UUID format
50+
try:
9551
uuid.UUID(suid)
9652
return True
9753
except (ValueError, TypeError):

tests/unit/utils/test_suid.py

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
"""Unit tests for functions defined in utils.suid module."""
22

3+
from typing import Any
4+
5+
import pytest
6+
37
from utils import suid
48

59

@@ -12,16 +16,62 @@ def test_get_suid(self) -> None:
1216
assert suid.check_suid(suid_value), "Generated SUID is not valid"
1317
assert isinstance(suid_value, str), "SUID should be a string"
1418

15-
def test_check_suid_valid(self) -> None:
19+
def test_check_suid_valid_uuid(self) -> None:
1620
"""Test that check_suid returns True for a valid UUID."""
1721
valid_suid = "123e4567-e89b-12d3-a456-426614174000"
22+
assert suid.check_suid(valid_suid), "check_suid should return True for UUID"
23+
24+
def test_check_suid_valid_48char_hex(self) -> None:
25+
"""Test that check_suid returns True for a 48-char hex string."""
26+
valid_hex = "e6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53e3c"
27+
assert len(valid_hex) == 48
1828
assert suid.check_suid(
19-
valid_suid
20-
), "check_suid should return True for a valid SUID"
29+
valid_hex
30+
), "check_suid should return True for 48-char hex"
31+
32+
def test_check_suid_valid_conv_prefix(self) -> None:
33+
"""Test that check_suid returns True for conv_ + 48-char hex string."""
34+
valid_conv = "conv_e6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53e3c"
35+
assert len(valid_conv) == 53
36+
assert suid.check_suid(
37+
valid_conv
38+
), "check_suid should return True for conv_ prefixed hex"
39+
40+
def test_check_suid_invalid_string(self) -> None:
41+
"""Test that check_suid returns False for an invalid string."""
42+
assert not suid.check_suid("invalid-uuid")
2143

22-
def test_check_suid_invalid(self) -> None:
23-
"""Test that check_suid returns False for an invalid UUID."""
24-
invalid_suid = "invalid-uuid"
44+
def test_check_suid_valid_32char_hex_uuid(self) -> None:
45+
"""Test that check_suid returns True for 32-char hex (valid UUID format)."""
46+
# 32-char hex is a valid UUID format (without hyphens)
47+
assert suid.check_suid("e6afd7aaa97b49ce8f4f96a801b07893")
48+
49+
def test_check_suid_invalid_hex_wrong_length(self) -> None:
50+
"""Test that check_suid returns False for hex string with wrong length."""
51+
# 47 chars (not 48, not valid UUID)
52+
assert not suid.check_suid("e6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53e3")
53+
# 49 chars (not 48, not valid UUID)
54+
assert not suid.check_suid("e6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53e3c1")
55+
56+
def test_check_suid_invalid_conv_prefix_wrong_length(self) -> None:
57+
"""Test that check_suid returns False for conv_ with wrong hex length."""
58+
# conv_ + 47 chars (not 48)
59+
assert not suid.check_suid(
60+
"conv_e6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53e3"
61+
)
62+
# conv_ + 49 chars (not 48)
2563
assert not suid.check_suid(
26-
invalid_suid
27-
), "check_suid should return False for an invalid SUID"
64+
"conv_e6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53e3c1"
65+
)
66+
67+
def test_check_suid_invalid_non_hex_chars(self) -> None:
68+
"""Test that check_suid returns False for strings with non-hex characters."""
69+
# 48 chars but contains 'g' and 'z'
70+
invalid_hex = "g6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53ezz"
71+
assert len(invalid_hex) == 48
72+
assert not suid.check_suid(invalid_hex)
73+
74+
@pytest.mark.parametrize("invalid_type", [None, 123, [], {}])
75+
def test_check_suid_invalid_type(self, invalid_type: Any) -> None:
76+
"""Test that check_suid returns False for non-string types."""
77+
assert not suid.check_suid(invalid_type)

0 commit comments

Comments
 (0)