Skip to content

Commit 1730b65

Browse files
author
Radovan Fuchs
committed
add the fix to streaming query as well along with UTs
1 parent eae3d67 commit 1730b65

2 files changed

Lines changed: 111 additions & 5 deletions

File tree

src/utils/endpoints.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -671,11 +671,19 @@ async def cleanup_after_streaming(
671671
session.query(UserConversation).filter_by(id=conversation_id).first()
672672
)
673673
if not existing_conversation:
674-
topic_summary = await get_topic_summary_func(
675-
query_request.query,
676-
client,
677-
llama_stack_model_id,
678-
)
674+
# Check if topic summary should be generated (default: True)
675+
should_generate = query_request.generate_topic_summary
676+
677+
if should_generate:
678+
logger.debug("Generating topic summary for new conversation")
679+
topic_summary = await get_topic_summary_func(
680+
query_request.query, client, llama_stack_model_id
681+
)
682+
else:
683+
logger.debug(
684+
"Topic summary generation disabled by request parameter"
685+
)
686+
topic_summary = None
679687

680688
completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
681689

tests/unit/utils/test_endpoints.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,3 +1036,101 @@ def test_create_referenced_documents_invalid_urls(self) -> None:
10361036
assert result[0].doc_title == "not-a-valid-url"
10371037
assert result[1].doc_url == AnyUrl("https://example.com/doc1")
10381038
assert result[1].doc_title == "doc1"
1039+
1040+
1041+
@pytest.mark.asyncio
1042+
async def test_cleanup_after_streaming_generate_topic_summary_default_true(
1043+
mocker: MockerFixture,
1044+
) -> None:
1045+
"""Test that topic summary is generated by default for new conversations."""
1046+
mock_is_transcripts_enabled = mocker.Mock(return_value=False)
1047+
mock_get_topic_summary = mocker.AsyncMock(return_value="Generated topic")
1048+
mock_store_transcript = mocker.Mock()
1049+
mock_persist_conversation = mocker.Mock()
1050+
mock_client = mocker.AsyncMock()
1051+
mock_config = mocker.Mock()
1052+
1053+
mock_session = mocker.Mock()
1054+
mock_session.query.return_value.filter_by.return_value.first.return_value = None
1055+
mock_session.__enter__ = mocker.Mock(return_value=mock_session)
1056+
mock_session.__exit__ = mocker.Mock(return_value=None)
1057+
mocker.patch("utils.endpoints.get_session", return_value=mock_session)
1058+
1059+
mocker.patch("utils.endpoints.create_referenced_documents_with_metadata", return_value=[])
1060+
mocker.patch("utils.endpoints.store_conversation_into_cache")
1061+
1062+
query_request = QueryRequest(query="test query")
1063+
1064+
await endpoints.cleanup_after_streaming(
1065+
user_id="test_user",
1066+
conversation_id="test_conv_id",
1067+
model_id="test_model",
1068+
provider_id="test_provider",
1069+
llama_stack_model_id="test_llama_model",
1070+
query_request=query_request,
1071+
summary=mocker.Mock(llm_response="test response", tool_calls=[]),
1072+
metadata_map={},
1073+
started_at="2024-01-01T00:00:00Z",
1074+
client=mock_client,
1075+
config=mock_config,
1076+
skip_userid_check=False,
1077+
get_topic_summary_func=mock_get_topic_summary,
1078+
is_transcripts_enabled_func=mock_is_transcripts_enabled,
1079+
store_transcript_func=mock_store_transcript,
1080+
persist_user_conversation_details_func=mock_persist_conversation,
1081+
)
1082+
1083+
mock_get_topic_summary.assert_called_once_with(
1084+
"test query", mock_client, "test_llama_model"
1085+
)
1086+
1087+
mock_persist_conversation.assert_called_once()
1088+
assert mock_persist_conversation.call_args[1]["topic_summary"] == "Generated topic"
1089+
1090+
1091+
@pytest.mark.asyncio
1092+
async def test_cleanup_after_streaming_generate_topic_summary_explicit_false(
1093+
mocker: MockerFixture,
1094+
) -> None:
1095+
"""Test that topic summary is NOT generated when explicitly set to False."""
1096+
mock_is_transcripts_enabled = mocker.Mock(return_value=False)
1097+
mock_get_topic_summary = mocker.AsyncMock(return_value="Generated topic")
1098+
mock_store_transcript = mocker.Mock()
1099+
mock_persist_conversation = mocker.Mock()
1100+
mock_client = mocker.AsyncMock()
1101+
mock_config = mocker.Mock()
1102+
1103+
mock_session = mocker.Mock()
1104+
mock_session.query.return_value.filter_by.return_value.first.return_value = None
1105+
mock_session.__enter__ = mocker.Mock(return_value=mock_session)
1106+
mock_session.__exit__ = mocker.Mock(return_value=None)
1107+
mocker.patch("utils.endpoints.get_session", return_value=mock_session)
1108+
1109+
mocker.patch("utils.endpoints.create_referenced_documents_with_metadata", return_value=[])
1110+
mocker.patch("utils.endpoints.store_conversation_into_cache")
1111+
1112+
query_request = QueryRequest(query="test query", generate_topic_summary=False)
1113+
1114+
await endpoints.cleanup_after_streaming(
1115+
user_id="test_user",
1116+
conversation_id="test_conv_id",
1117+
model_id="test_model",
1118+
provider_id="test_provider",
1119+
llama_stack_model_id="test_llama_model",
1120+
query_request=query_request,
1121+
summary=mocker.Mock(llm_response="test response", tool_calls=[]),
1122+
metadata_map={},
1123+
started_at="2024-01-01T00:00:00Z",
1124+
client=mock_client,
1125+
config=mock_config,
1126+
skip_userid_check=False,
1127+
get_topic_summary_func=mock_get_topic_summary,
1128+
is_transcripts_enabled_func=mock_is_transcripts_enabled,
1129+
store_transcript_func=mock_store_transcript,
1130+
persist_user_conversation_details_func=mock_persist_conversation,
1131+
)
1132+
1133+
mock_get_topic_summary.assert_not_called()
1134+
1135+
mock_persist_conversation.assert_called_once()
1136+
assert mock_persist_conversation.call_args[1]["topic_summary"] is None

0 commit comments

Comments
 (0)