Skip to content

Commit 81e303a

Browse files
authored
Merge pull request #1129 from asimurka/rebase_conversation_history
LCORE-1166: Rebased conversation history changes
2 parents 1eb6f32 + 23a433d commit 81e303a

6 files changed

Lines changed: 379 additions & 139 deletions

File tree

src/app/endpoints/query.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,13 +190,16 @@ async def query_endpoint_handler(
190190
quota_limiters=configuration.quota_limiters, user_id=user_id
191191
)
192192

193+
completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
193194
conversation_id = normalize_conversation_id(responses_params.conversation)
195+
194196
logger.info("Storing query results")
195197
store_query_results(
196198
user_id=user_id,
197199
conversation_id=conversation_id,
198-
model_id=responses_params.model,
200+
model=responses_params.model,
199201
started_at=started_at,
202+
completed_at=completed_at,
200203
summary=turn_summary,
201204
query_request=query_request,
202205
configuration=configuration,

src/app/endpoints/streaming_query.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,13 +366,15 @@ async def generate_response(
366366
turn_summary.referenced_documents,
367367
context.query_request.media_type or MEDIA_TYPE_JSON,
368368
)
369+
completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
369370

370371
# Store query results (transcript, conversation details, cache)
371372
logger.info("Storing query results")
372373
store_query_results(
373374
user_id=context.user_id,
374375
conversation_id=context.conversation_id,
375-
model_id=responses_params.model,
376+
model=responses_params.model,
377+
completed_at=completed_at,
376378
started_at=context.started_at,
377379
summary=turn_summary,
378380
query_request=context.query_request,

src/utils/conversations.py

Lines changed: 95 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -297,85 +297,128 @@ def _create_turn_from_db_metadata(
297297
)
298298

299299

300-
def build_conversation_turns_from_items(
300+
def _group_items_into_turns(
301301
items: list[ItemListResponse],
302-
turns_metadata: list[UserTurn],
303-
conversation_start_time: datetime,
304-
) -> list[ConversationTurn]:
305-
"""Build conversation turns from Conversations API items and turns metadata.
302+
) -> list[list[ItemListResponse]]:
303+
"""Group conversation items into turns.
304+
305+
Each turn starts with a user message. All subsequent messages and tool items
306+
belong to that turn until the next user message.
306307
307308
Args:
308309
items: Conversation items list from Conversations API, oldest first
309-
turns_metadata: List of UserTurn database objects ordered by turn_number.
310-
Can be empty for legacy conversations without stored metadata.
311-
conversation_start_time: Timestamp to use for dummy metadata in legacy conversations.
312-
Typically the conversation's created_at timestamp.
313310
314311
Returns:
315-
List of ConversationTurn objects, oldest first
312+
List of turns, where each turn is a list of items belonging to that turn
316313
"""
317-
chat_history: list[ConversationTurn] = []
318-
current_messages: list[Message] = []
319-
current_tool_calls: list[ToolCallSummary] = []
320-
current_tool_results: list[ToolResultSummary] = []
321-
current_turn_index = 0
314+
turns: list[list[ItemListResponse]] = []
315+
current_turn_items: list[ItemListResponse] = []
322316

323317
for item in items:
324318
item_type = getattr(item, "type", None)
325319

326-
# Parse message items
320+
# User message marks the beginning of a new turn
327321
if item_type == "message":
328322
message_item = cast(MessageOutput, item)
329-
message = _parse_message_item(message_item)
330-
331-
# User message marks the beginning of a new turn
332-
if message.type == "user":
323+
if message_item.role == "user":
333324
# If we have accumulated items, finish the previous turn
334-
if current_messages or current_tool_calls or current_tool_results:
335-
turn_metadata = (
336-
turns_metadata[current_turn_index]
337-
if current_turn_index < len(turns_metadata)
338-
else _create_dummy_turn_metadata(conversation_start_time)
339-
)
340-
chat_history.append(
341-
_create_turn_from_db_metadata(
342-
turn_metadata,
343-
current_messages,
344-
current_tool_calls,
345-
current_tool_results,
346-
)
347-
)
348-
current_turn_index += 1
325+
if current_turn_items:
326+
turns.append(current_turn_items)
327+
current_turn_items = []
349328

350329
# Start new turn with this user message
351-
current_messages = [message]
352-
current_tool_calls = []
353-
current_tool_results = []
330+
current_turn_items = [item]
354331
else:
355332
# Add non-user message to current turn
356-
current_messages.append(message)
333+
current_turn_items.append(item)
334+
else:
335+
# Add tool-related items to current turn
336+
current_turn_items.append(item)
357337

358-
# Parse tool-related items
338+
# Add final turn if there are items
339+
if current_turn_items:
340+
turns.append(current_turn_items)
341+
342+
return turns
343+
344+
345+
def _process_turn_items(
346+
turn_items: list[ItemListResponse],
347+
) -> tuple[list[Message], list[ToolCallSummary], list[ToolResultSummary]]:
348+
"""Process items from a single turn into messages, tool calls, and tool results.
349+
350+
Args:
351+
turn_items: List of items belonging to a single turn
352+
353+
Returns:
354+
Tuple of (messages, tool_calls, tool_results)
355+
"""
356+
messages: list[Message] = []
357+
tool_calls: list[ToolCallSummary] = []
358+
tool_results: list[ToolResultSummary] = []
359+
360+
for item in turn_items:
361+
item_type = getattr(item, "type", None)
362+
363+
if item_type == "message":
364+
message_item = cast(MessageOutput, item)
365+
message = _parse_message_item(message_item)
366+
messages.append(message)
359367
else:
360368
tool_call, tool_result = _build_tool_call_summary_from_item(item)
361369
if tool_call is not None:
362-
current_tool_calls.append(tool_call)
370+
tool_calls.append(tool_call)
363371
if tool_result is not None:
364-
current_tool_results.append(tool_result)
372+
tool_results.append(tool_result)
365373

366-
# Add final turn if there are items
367-
if current_messages or current_tool_calls or current_tool_results:
368-
turn_metadata = (
369-
turns_metadata[current_turn_index]
370-
if current_turn_index < len(turns_metadata)
371-
else _create_dummy_turn_metadata(conversation_start_time)
372-
)
374+
return messages, tool_calls, tool_results
375+
376+
377+
def build_conversation_turns_from_items(
378+
items: list[ItemListResponse],
379+
turns_metadata: list[UserTurn],
380+
conversation_start_time: datetime,
381+
) -> list[ConversationTurn]:
382+
"""Build conversation turns from Conversations API items and turns metadata.
383+
384+
Args:
385+
items: Conversation items list from Conversations API, oldest first
386+
turns_metadata: List of UserTurn database objects ordered by turn_number.
387+
Can be empty for legacy conversations without stored metadata.
388+
For extended legacy conversations, only the newer turns have metadata.
389+
conversation_start_time: Timestamp to use for dummy metadata in legacy conversations.
390+
Typically the conversation's created_at timestamp.
391+
392+
Returns:
393+
List of ConversationTurn objects, oldest first
394+
"""
395+
# Group items into turns first
396+
turn_items_list = _group_items_into_turns(items)
397+
398+
# Calculate how many legacy turns don't have metadata
399+
total_turns = len(turn_items_list)
400+
legacy_turns_count = total_turns - len(turns_metadata)
401+
402+
# Process each turn with its corresponding metadata
403+
chat_history: list[ConversationTurn] = []
404+
for turn_index, turn_items in enumerate(turn_items_list):
405+
# Process items into messages, tool calls, and tool results
406+
messages, tool_calls, tool_results = _process_turn_items(turn_items)
407+
408+
# Select appropriate metadata for this turn
409+
if turn_index < legacy_turns_count:
410+
turn_metadata = _create_dummy_turn_metadata(conversation_start_time)
411+
else:
412+
metadata_index = turn_index - legacy_turns_count
413+
turn_metadata = turns_metadata[metadata_index]
414+
415+
# Create ConversationTurn from metadata and processed items
373416
chat_history.append(
374417
_create_turn_from_db_metadata(
375418
turn_metadata,
376-
current_messages,
377-
current_tool_calls,
378-
current_tool_results,
419+
messages,
420+
tool_calls,
421+
tool_results,
379422
)
380423
)
381424

src/utils/query.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
from llama_stack_client.types import ModelListResponse, Shield
1414

1515
from fastapi import HTTPException
16+
from sqlalchemy import func
1617
from configuration import AppConfig, configuration
1718
from models.cache_entry import CacheEntry
1819
from models.config import Action
19-
from models.database.conversations import UserConversation
20+
from models.database.conversations import UserConversation, UserTurn
2021
import constants
2122
from models.requests import Attachment, QueryRequest
2223
from models.responses import (
@@ -330,8 +331,9 @@ def prepare_input(query_request: QueryRequest) -> str:
330331
def store_query_results( # pylint: disable=too-many-arguments,too-many-locals
331332
user_id: str,
332333
conversation_id: str,
333-
model_id: str,
334+
model: str,
334335
started_at: str,
336+
completed_at: str,
335337
summary: TurnSummary,
336338
query_request: QueryRequest,
337339
configuration: AppConfig,
@@ -349,8 +351,9 @@ def store_query_results( # pylint: disable=too-many-arguments,too-many-locals
349351
Args:
350352
user_id: The authenticated user ID
351353
conversation_id: The conversation ID
352-
model_id: The model identifier
354+
model: The model identifier
353355
started_at: ISO formatted timestamp when the request started
356+
completed_at: ISO formatted timestamp when the request completed
354357
summary: Summary of the turn including LLM response and tool calls
355358
query_request: The original query request
356359
configuration: Application configuration
@@ -360,7 +363,7 @@ def store_query_results( # pylint: disable=too-many-arguments,too-many-locals
360363
Raises:
361364
HTTPException: On any database, cache, or IO errors during processing
362365
"""
363-
provider, model = extract_provider_and_model_from_model_id(model_id)
366+
provider_id, model_id = extract_provider_and_model_from_model_id(model)
364367
# Store transcript if enabled
365368
if is_transcripts_enabled():
366369
try:
@@ -370,8 +373,8 @@ def store_query_results( # pylint: disable=too-many-arguments,too-many-locals
370373
store_transcript(
371374
user_id=user_id,
372375
conversation_id=conversation_id,
373-
model_id=model,
374-
provider_id=provider,
376+
model_id=model_id,
377+
provider_id=provider_id,
375378
query_is_valid=True, # TODO(lucasagomes): implement as part of query validation
376379
query=query_request.query,
377380
query_request=query_request,
@@ -394,8 +397,10 @@ def store_query_results( # pylint: disable=too-many-arguments,too-many-locals
394397
persist_user_conversation_details(
395398
user_id=user_id,
396399
conversation_id=conversation_id,
397-
model=model,
398-
provider_id=provider,
400+
started_at=started_at,
401+
completed_at=completed_at,
402+
model_id=model_id,
403+
provider_id=provider_id,
399404
topic_summary=topic_summary,
400405
)
401406
except SQLAlchemyError as e:
@@ -405,12 +410,11 @@ def store_query_results( # pylint: disable=too-many-arguments,too-many-locals
405410

406411
# Store conversation in cache
407412
try:
408-
completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
409413
cache_entry = CacheEntry(
410414
query=query_request.query,
411415
response=summary.llm_response,
412-
provider=provider,
413-
model=model,
416+
provider=provider_id,
417+
model=model_id,
414418
started_at=started_at,
415419
completed_at=completed_at,
416420
referenced_documents=summary.referenced_documents,
@@ -484,7 +488,9 @@ def is_transcripts_enabled() -> bool:
484488
def persist_user_conversation_details(
485489
user_id: str,
486490
conversation_id: str,
487-
model: str,
491+
started_at: str,
492+
completed_at: str,
493+
model_id: str,
488494
provider_id: str,
489495
topic_summary: Optional[str],
490496
) -> None:
@@ -493,7 +499,9 @@ def persist_user_conversation_details(
493499
Args:
494500
user_id: The authenticated user ID
495501
conversation_id: The conversation ID
496-
model: The model identifier
502+
started_at: ISO formatted timestamp when this turn (request) started
503+
completed_at: ISO formatted timestamp when this turn (request) completed
504+
model_id: The model identifier
497505
provider_id: The provider identifier
498506
topic_summary: Optional topic summary for the conversation
499507
"""
@@ -515,7 +523,7 @@ def persist_user_conversation_details(
515523
conversation = UserConversation(
516524
id=normalized_id,
517525
user_id=user_id,
518-
last_used_model=model,
526+
last_used_model=model_id,
519527
last_used_provider=provider_id,
520528
topic_summary=topic_summary or "",
521529
message_count=1,
@@ -525,7 +533,7 @@ def persist_user_conversation_details(
525533
"Associated conversation %s to user %s", normalized_id, user_id
526534
)
527535
else:
528-
existing_conversation.last_used_model = model
536+
existing_conversation.last_used_model = model_id
529537
existing_conversation.last_used_provider = provider_id
530538
existing_conversation.last_message_at = datetime.now(UTC)
531539
existing_conversation.message_count += 1
@@ -536,6 +544,27 @@ def persist_user_conversation_details(
536544
existing_conversation.message_count,
537545
)
538546

547+
max_turn_number = (
548+
session.query(func.max(UserTurn.turn_number))
549+
.filter_by(conversation_id=normalized_id)
550+
.scalar()
551+
)
552+
turn_number = (max_turn_number or 0) + 1
553+
turn = UserTurn(
554+
conversation_id=normalized_id,
555+
turn_number=turn_number,
556+
started_at=datetime.fromisoformat(started_at),
557+
completed_at=datetime.fromisoformat(completed_at),
558+
provider=provider_id,
559+
model=model_id,
560+
)
561+
session.add(turn)
562+
logger.debug(
563+
"Created conversation turn - Conversation: %s, Turn: %d",
564+
normalized_id,
565+
turn_number,
566+
)
567+
539568
session.commit()
540569
logger.debug(
541570
"Successfully committed conversation %s to database", normalized_id

0 commit comments

Comments
 (0)