-
Notifications
You must be signed in to change notification settings - Fork 261
Expand file tree
/
Copy pathchat_service.py
More file actions
308 lines (256 loc) · 13.9 KB
/
chat_service.py
File metadata and controls
308 lines (256 loc) · 13.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
"""
Provides the ChatService class and related utilities for handling chat interactions,
streaming responses, RAG (Retrieval-Augmented Generation) processing, and chart data
generation for visualization in a call center knowledge mining solution.
Includes thread management, caching, and integration with Azure OpenAI and FastAPI.
"""
import asyncio
import json
import logging
import os
import random
import re
from typing import AsyncGenerator
from common.logging.event_utils import track_event_if_configured
from helpers.azure_credential_utils import get_azure_credential_async
from common.database.sqldb_service import SQLTool, get_db_connection as get_sqldb_connection
from fastapi import HTTPException, status
from azure.ai.projects.aio import AIProjectClient
from agent_framework.azure import AzureAIProjectAgentProvider
from cachetools import TTLCache
from common.config.config import Config
# Constants
HOST_NAME = "CKM"
HOST_INSTRUCTIONS = "Answer questions about call center operations"
logger = logging.getLogger(__name__)
# Suppress informational warnings from agent_framework about runtime
# tool/structured_output overrides not being supported by AzureAIClient.
# This can be made configurable via env var if needed for debugging.
agent_log_level = os.getenv("AGENT_FRAMEWORK_LOG_LEVEL", "ERROR").upper()
logging.getLogger("agent_framework.azure").setLevel(getattr(logging, agent_log_level, logging.ERROR))
class ExpCache(TTLCache):
"""Extended TTLCache that deletes Azure AI agent threads when items expire."""
def __init__(self, *args, **kwargs):
"""Initialize cache without creating persistent client connections."""
super().__init__(*args, **kwargs)
def expire(self, time=None):
"""Remove expired items and delete associated Azure AI threads."""
items = super().expire(time)
for key, thread_conversation_id in items:
try:
# Create task for async deletion with proper session management
asyncio.create_task(self._delete_thread_async(thread_conversation_id))
logger.info("Scheduled thread deletion: %s", thread_conversation_id)
except Exception as e:
logger.exception("Failed to schedule thread deletion for key %s: %s", key, e)
return items
def popitem(self):
"""Remove item using LRU eviction and delete associated Azure AI thread."""
key, thread_conversation_id = super().popitem()
try:
# Create task for async deletion with proper session management
asyncio.create_task(self._delete_thread_async(thread_conversation_id))
logger.info("Scheduled thread deletion (LRU evict): %s", thread_conversation_id)
except Exception as e:
logger.exception("Failed to schedule thread deletion for key %s (LRU evict): %s", key, e)
return key, thread_conversation_id
async def _delete_thread_async(self, thread_conversation_id: str):
"""Asynchronously delete a thread using a properly managed Azure AI Project Client."""
credential = None
config = Config()
try:
if thread_conversation_id:
# Get credential and use async context managers to ensure proper cleanup
credential = await get_azure_credential_async(client_id=config.azure_client_id)
async with AIProjectClient(
endpoint=config.ai_project_endpoint,
credential=credential
) as project_client:
openai_client = project_client.get_openai_client()
await openai_client.conversations.delete(conversation_id=thread_conversation_id)
logger.info("Thread deleted successfully: %s", thread_conversation_id)
except Exception as e:
logger.exception("Failed to delete thread %s: %s", thread_conversation_id, e)
finally:
# Close credential to prevent unclosed client session warnings
if credential is not None:
await credential.close()
thread_cache = None
class ChatService:
"""
Service for handling chat interactions, including streaming responses,
processing RAG responses, and generating chart data for visualization.
"""
def __init__(self):
config = Config()
self.orchestrator_agent_name = config.orchestrator_agent_name
self.azure_client_id = config.azure_client_id
self.ai_project_endpoint = config.ai_project_endpoint
def get_thread_cache(self):
"""Get or create the global thread cache."""
global thread_cache
if thread_cache is None:
thread_cache = ExpCache(maxsize=1000, ttl=3600.0)
return thread_cache
async def stream_openai_text(self, conversation_id: str, query: str, user_id: str = "") -> AsyncGenerator[tuple[str, str], None]:
"""
Get a streaming text response from OpenAI.
Yields:
tuple[str, str]: (role, content) tuples where role is "assistant" or "tool"
"""
logger.info("stream_openai_text called: conversation_id=%s, query_length=%d",
conversation_id, len(query) if query else 0)
async with (
await get_azure_credential_async(client_id=self.azure_client_id) as credential,
AIProjectClient(endpoint=self.ai_project_endpoint, credential=credential) as project_client,
):
complete_response = ""
db_conn = None
had_error = False
try:
if not query:
query = "Please provide a query."
# Create provider for agent management
provider = AzureAIProjectAgentProvider(project_client=project_client)
db_conn = await get_sqldb_connection()
custom_tool = SQLTool(conn=db_conn)
thread_conversation_id = None
cache = self.get_thread_cache()
thread_conversation_id = cache.get(conversation_id, None)
if thread_conversation_id:
logger.info("Reusing existing thread %s for conversation %s",
thread_conversation_id, conversation_id)
# Get agent with tools using provider
logger.info("Retrieving orchestrator agent: '%s'", self.orchestrator_agent_name)
agent = await provider.get_agent(
name=self.orchestrator_agent_name,
tools=custom_tool.get_sql_response
)
logger.info("Orchestrator agent retrieved successfully: '%s'", self.orchestrator_agent_name)
citations = []
citation_marker_map = {} # Maps original markers to sequential numbers
citation_counter = 0
if not thread_conversation_id:
# Create a conversation using OpenAI client for conversation continuity
logger.info("No existing thread found, creating new thread for conversation %s", conversation_id)
openai_client = project_client.get_openai_client()
conversation = await openai_client.conversations.create()
thread_conversation_id = conversation.id
logger.info("New thread created: %s for conversation %s", thread_conversation_id, conversation_id)
def replace_citation_marker(match):
nonlocal citation_counter
marker = match.group(0)
if marker not in citation_marker_map:
citation_counter += 1
citation_marker_map[marker] = citation_counter
return f"[{citation_marker_map[marker]}]"
logger.info("Starting agent.run stream for conversation %s, thread %s",
conversation_id, thread_conversation_id)
async for chunk in agent.run(query, stream=True, conversation_id=thread_conversation_id):
# Collect citations from Azure AI Search responses
for content in getattr(chunk, "contents", []):
annotations = getattr(content, "annotations", [])
if annotations:
citations.extend(annotations)
chunk_text = str(chunk.text) if chunk.text else ""
# Replace complete citation markers like 【4:0†source】 or 【4:0 source】 with [1], [2], etc.
chunk_text = re.sub(r'【\d+:\d+†?[^】]*】', replace_citation_marker, chunk_text)
if chunk_text:
complete_response += chunk_text
yield ("assistant", chunk_text)
logger.info("Streaming complete for conversation %s: response_length=%d, citation_count=%d",
conversation_id, len(complete_response), len(citations))
track_event_if_configured("ChatResponseCompleted", {
"conversation_id": conversation_id,
"user_id": user_id,
"response_length": len(complete_response),
"citation_count": len(citations),
"response_content": complete_response[:8192] if len(complete_response) > 8192 else complete_response
})
cache[conversation_id] = thread_conversation_id
citation_json = "[]"
if citations:
citation_list = []
seen_doc_ids = set() # Track unique document IDs to avoid duplicates
for citation in citations:
get_url = (citation.get("additional_properties") or {}).get("get_url")
url = get_url if get_url else 'N/A'
title = citation.get('title', 'N/A')
# Extract document ID from the get_url to use as a more meaningful title
doc_id = None
if get_url and title.startswith('doc_'):
# URL format: .../indexes/{index_name}/docs/{document_id}?api-version=...
match = re.search(r'/docs/([^?]+)', get_url)
if match:
doc_id = match.group(1)
title = doc_id
# Skip duplicate citations based on document ID
if doc_id and doc_id in seen_doc_ids:
continue
if doc_id:
seen_doc_ids.add(doc_id)
citation_list.append({"url": url, "title": title})
citation_json = json.dumps(citation_list)
except Exception as e:
had_error = True
logger.exception("Error in stream_openai_text: %s", e)
cache = self.get_thread_cache()
thread_conversation_id = cache.pop(conversation_id, None)
if thread_conversation_id is not None:
corrupt_key = f"{conversation_id}_corrupt_{random.randint(1000, 9999)}"
cache[corrupt_key] = thread_conversation_id
# Provide user-friendly error messages
error_message = str(e).lower()
if "too many requests" in error_message or "429" in error_message:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="The service is currently experiencing high demand. Please try again in a few moments."
) from e
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred while processing the request."
) from e
finally:
# Close the DB connection to prevent connection leaks
if db_conn is not None:
try:
db_conn.close()
except Exception: # Best-effort connection cleanup
pass
# Only emit fallback and tool citations if no error occurred
if not had_error:
if complete_response == "":
logger.info("No response received from OpenAI.")
yield ("assistant", "I cannot answer this question with the current data. Please rephrase or add more details.")
yield ("tool", citation_json)
async def stream_chat_request(self, conversation_id, query, user_id: str = ""):
"""
Handles streaming chat requests.
"""
logger.info("stream_chat_request called: conversation_id=%s", conversation_id)
async def generate():
try:
async for role, content in self.stream_openai_text(conversation_id, query, user_id=user_id):
if content:
response = {
"choices": [
{
"delta": {
"role": role,
"content": content
}
}
]
}
yield json.dumps(response) + "\n\n"
except Exception as e:
logger.exception("Unexpected error: %s", e)
# Extract user-friendly message from HTTPException if available
if isinstance(e, HTTPException):
error_message = e.detail
else:
error_message = "An error occurred while processing the request."
error_response = {"error": error_message}
yield json.dumps(error_response) + "\n\n"
return generate()