Skip to content

Commit 1bb53c6

Browse files
committed
fix (chat); voice mode backend logic update
1 parent fc8329e commit 1bb53c6

2 files changed

Lines changed: 108 additions & 83 deletions

File tree

src/server/main/chat/utils.py

Lines changed: 93 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -394,93 +394,109 @@ async def process_voice_command(
394394
user_timezone = ZoneInfo("UTC")
395395
current_user_time = datetime.datetime.now(user_timezone).strftime('%Y-%m-%d %H:%M:%S %Z')
396396

397-
# 3. Full tool selection logic
397+
# 3. Full tool selection logic (Stage 1)
398398
await send_status_update({"type": "status", "message": "choosing_tools"})
399399

400400
user_integrations = user_data.get("integrations", {})
401401
connected_tools, disconnected_tools = _get_tool_lists(user_integrations)
402402

403-
# Call Stage 1
404-
stage1_output = await _get_stage1_response(qwen_formatted_history, connected_tools, disconnected_tools, user_id)
403+
stage1_result = await _get_stage1_response(qwen_formatted_history, connected_tools, disconnected_tools, user_id)
405404

406-
final_text_response = "I'm sorry, I couldn't process that." # Default error message
405+
topic_changed = stage1_result.get("topic_changed", False)
406+
relevant_tool_names = stage1_result.get("connected_tools", [])
407+
disconnected_requested_tools = stage1_result.get("disconnected_tools", [])
407408

408-
if isinstance(stage1_output, str):
409-
# Stage 1 provided a direct response
410-
logger.info(f"Stage 1 provided a direct response for voice command for user {user_id}.")
411-
final_text_response = _extract_answer_from_llm_response(stage1_output)
409+
# 4. Stage 2 setup
410+
mandatory_tools = {"memory", "history", "tasks"}
411+
final_tool_names = set(relevant_tool_names) | mandatory_tools
412+
413+
filtered_mcp_servers = {}
414+
for tool_name in final_tool_names:
415+
config = INTEGRATIONS_CONFIG.get(tool_name, {})
416+
if config:
417+
mcp_config = config.get("mcp_server_config", {})
418+
if mcp_config and mcp_config.get("url") and mcp_config.get("name"):
419+
server_name = mcp_config["name"]
420+
filtered_mcp_servers[server_name] = {
421+
"url": mcp_config["url"],
422+
"headers": {"X-User-ID": user_id},
423+
"transport": "sse"
424+
}
425+
tools = [{"mcpServers": filtered_mcp_servers}]
426+
logger.info(f"Voice Command Tools (Stage 2): {list(filtered_mcp_servers.keys())}")
427+
428+
# History truncation logic
429+
stage_2_expanded_messages = []
430+
if topic_changed:
431+
last_user_message = next((msg for msg in reversed(qwen_formatted_history) if msg.get("role") == "user"), None)
432+
if last_user_message:
433+
stage_2_expanded_messages.append({"role": "user", "content": last_user_message.get("content", "")})
412434
else:
413-
# Stage 1 selected tools, proceed to Stage 2
414-
relevant_tool_names = stage1_output
415-
mandatory_tools = {"memory", "history", "tasks"}
416-
final_tool_names = set(relevant_tool_names) | mandatory_tools
417-
418-
# Build the list of tools for the agent, ensuring headers are included.
419-
filtered_mcp_servers = {}
420-
for tool_name in final_tool_names:
421-
config = INTEGRATIONS_CONFIG.get(tool_name, {})
422-
if config:
423-
mcp_config = config.get("mcp_server_config", {})
424-
if mcp_config and mcp_config.get("url") and mcp_config.get("name"):
425-
server_name = mcp_config["name"]
426-
filtered_mcp_servers[server_name] = {
427-
"url": mcp_config["url"],
428-
"headers": {"X-User-ID": user_id},
429-
"transport": "sse"
430-
}
431-
tools = [{"mcpServers": filtered_mcp_servers}]
432-
logger.info(f"Voice Command Tools (Stage 2): {list(filtered_mcp_servers.keys())}")
433-
434-
expanded_messages = []
435435
for msg in qwen_formatted_history:
436-
expanded_messages.append({
437-
"role": msg.get("role", "user"),
438-
"content": msg.get("content", "")
439-
})
440-
441-
system_prompt = STAGE_2_SYSTEM_PROMPT.format(
442-
username=username,
443-
location=location,
444-
current_user_time=current_user_time
436+
stage_2_expanded_messages.append({"role": msg.get("role", "user"), "content": msg.get("content", "")})
437+
438+
# Handle disconnected tools note
439+
if disconnected_requested_tools:
440+
disconnected_display_names = [INTEGRATIONS_CONFIG.get(t, {}).get('display_name', t) for t in disconnected_requested_tools]
441+
system_note = (
442+
f"System Note: The user's request mentioned functionality requiring the following tools which are currently disconnected: "
443+
f"{', '.join(disconnected_display_names)}. You MUST inform the user that you cannot complete that part of the request "
444+
f"and suggest they connect the tool(s) in the Integrations page. Then, proceed with the rest of the request using the available tools."
445445
)
446-
447-
await send_status_update({"type": "status", "message": "thinking"})
448-
449-
# --- MODIFICATION: Run blocking agent code in a separate thread ---
450-
loop = asyncio.get_running_loop()
451-
def agent_worker():
452-
final_run_response = None
453-
try:
454-
for response in run_agent_with_fallback(system_message=system_prompt, function_list=tools, messages=expanded_messages):
455-
456-
if isinstance(response, list) and response:
457-
# Schedule the async status update on the main event loop
458-
asyncio.run_coroutine_threadsafe(
459-
send_status_update({"type": "status", "message": f"using_tool_{tool_name}"}),
460-
loop
461-
)
462-
return final_run_response
463-
except Exception as e:
464-
logger.error(f"Error inside agent_worker thread for voice command: {e}", exc_info=True)
465-
return None
466-
467-
final_run_response = await asyncio.to_thread(agent_worker)
468-
# --- END MODIFICATION ---
469-
470-
if final_run_response and isinstance(final_run_response, list):
471-
assistant_content_parts = [
472-
msg.get('content', '')
473-
for msg in final_run_response
474-
if msg.get('role') == 'assistant' and msg.get('content')
475-
]
476-
full_response_str = "".join(assistant_content_parts)
477-
final_text_response = _extract_answer_from_llm_response(full_response_str)
478-
479-
if not final_text_response:
480-
last_message = final_run_response[-1]
481-
if last_message.get('role') == 'function':
482-
# Provide a more generic completion message if there's no explicit text answer
483-
final_text_response = "The action has been completed."
446+
if stage_2_expanded_messages and stage_2_expanded_messages[-1]['role'] == 'user':
447+
stage_2_expanded_messages[-1]['content'] = f"{system_note}\n\nUser's original message: {stage_2_expanded_messages[-1]['content']}"
448+
else:
449+
stage_2_expanded_messages.append({'role': 'system', 'content': system_note})
450+
451+
system_prompt = STAGE_2_SYSTEM_PROMPT.format(
452+
username=username,
453+
location=location,
454+
current_user_time=current_user_time
455+
)
456+
457+
await send_status_update({"type": "status", "message": "thinking"})
458+
459+
# 5. Agent Execution in a thread
460+
loop = asyncio.get_running_loop()
461+
def agent_worker():
462+
final_run_response = None
463+
try:
464+
for response in run_agent_with_fallback(system_message=system_prompt, function_list=tools, messages=stage_2_expanded_messages):
465+
final_run_response = response
466+
if isinstance(response, list) and response:
467+
last_message = response[-1]
468+
if last_message.get('role') == 'assistant' and last_message.get('function_call'):
469+
tool_name = last_message['function_call']['name']
470+
asyncio.run_coroutine_threadsafe(
471+
send_status_update({"type": "status", "message": f"using_tool_{tool_name}"}),
472+
loop
473+
)
474+
return final_run_response
475+
except Exception as e:
476+
logger.error(f"Error inside agent_worker thread for voice command: {e}", exc_info=True)
477+
return None
478+
479+
final_run_response = await asyncio.to_thread(agent_worker)
480+
481+
# 6. Process result
482+
full_response_str = ""
483+
if final_run_response and isinstance(final_run_response, list):
484+
assistant_content_parts = [
485+
msg.get('content', '')
486+
for msg in final_run_response
487+
if msg.get('role') == 'assistant' and msg.get('content')
488+
]
489+
full_response_str = "".join(assistant_content_parts)
490+
491+
final_text_response = _extract_answer_from_llm_response(full_response_str)
492+
493+
if not final_text_response:
494+
last_message = final_run_response[-1] if final_run_response else {}
495+
if last_message.get('role') == 'function':
496+
final_text_response = "The action has been completed."
497+
else:
498+
final_text_response = "I'm sorry, I couldn't process that."
499+
484500
await db_manager.messages_collection.update_one(
485501
{"message_id": assistant_message_id, "user_id": user_id},
486502
{"$set": {"content": final_text_response}}

src/server/main/voice/routes.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,17 @@ async def initiate_voice_session(
4444
expires_at = now + TOKEN_EXPIRATION_SECONDS
4545
rtc_token_cache[rtc_token] = {"user_id": user_id, "expires_at": expires_at}
4646

47-
# Get TURN credentials to send to the client
48-
ice_servers_config = get_credentials()
47+
if ENVIRONMENT in ["dev-local", "selfhost"]:
48+
logger.info(f"Initiated voice session for user {user_id} in dev-local mode with token {rtc_token}")
49+
return {"rtc_token": rtc_token, "ice_servers": []} # No TURN server in dev-local or selfhost mode
50+
51+
else:
52+
logger.info(f"Initiated voice session for user {user_id} with token {rtc_token} using TURN server")
53+
# Get TURN credentials to send to the client
54+
ice_servers_config = await get_credentials()
4955

50-
logger.info(f"Initiated voice session for user {user_id} with token {rtc_token}")
51-
return {"rtc_token": rtc_token, "ice_servers": ice_servers_config}
56+
logger.info(f"Initiated voice session for user {user_id} with token {rtc_token}")
57+
return {"rtc_token": rtc_token, "ice_servers": ice_servers_config}
5258

5359

5460
class MyVoiceChatHandler(ReplyOnPause):
@@ -119,9 +125,12 @@ async def process_audio_chunk(self, audio: tuple[int, np.ndarray]):
119125
await self.send_message(json.dumps({"type": "stt_result", "text": transcription}))
120126

121127
# 2. FULL AGENTIC LLM PROCESSING
122-
async def send_status_update(status_dict: Dict[str, Any]):
123-
await self.send_message(json.dumps(status_dict))
128+
# Define the callback function that process_voice_command will use to send status updates
129+
async def send_status_update(status_update: Dict[str, Any]):
130+
"""Sends a status update message to the client."""
131+
await self.send_message(json.dumps(status_update))
124132

133+
# Call the updated, fully-featured voice command processor
125134
full_response_buffer, assistant_message_id = await process_voice_command(
126135
user_id=user_id,
127136
transcribed_text=transcription,

0 commit comments

Comments
 (0)