Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from models.common.responses.responses_api_params import ResponsesApiParams
from models.common.turn_summary import TurnSummary
from models.config import Action
from utils.agents.query import retrieve_agent_response
from utils.conversations import append_turn_items_to_conversation
from utils.endpoints import (
check_configuration_loaded,
Expand Down Expand Up @@ -206,8 +207,13 @@ async def query_endpoint_handler(
client = await AsyncLlamaStackClientHolder().update_azure_token()

# Retrieve response using Responses API
turn_summary = await retrieve_response(
client, responses_params, moderation_result, endpoint_path
turn_summary = await retrieve_agent_response(
client=client,
responses_params=responses_params,
moderation_result=moderation_result,
endpoint_path=endpoint_path,
vector_store_ids=query_request.vector_store_ids or [],
rag_id_mapping=configuration.rag_id_mapping or {},
)

if moderation_result.decision == "passed":
Expand Down
89 changes: 12 additions & 77 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import asyncio
import datetime
import json
from collections.abc import AsyncIterator
from typing import Annotated, Any, Optional, cast

Expand Down Expand Up @@ -78,6 +77,10 @@
from models.common.responses.responses_api_params import ResponsesApiParams
from models.common.turn_summary import ReferencedDocument, TurnSummary
from models.config import Action
from utils.agents.streaming import (
generate_agent_response,
retrieve_agent_response_generator,
)
from utils.conversations import append_turn_items_to_conversation
from utils.endpoints import (
check_configuration_loaded,
Expand Down Expand Up @@ -115,6 +118,11 @@
validate_shield_ids_override,
)
from utils.stream_interrupts import get_stream_interrupt_registry
from utils.streaming_sse import (
format_stream_data,
shield_violation_generator,
stream_event,
)
from utils.suid import get_suid, normalize_conversation_id
from utils.token_counter import TokenCounter
from utils.vector_search import build_rag_context
Expand Down Expand Up @@ -287,7 +295,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
)
recording.record_llm_call(provider_id, model_id, endpoint_path)

generator, turn_summary = await retrieve_response_generator(
generator, turn_summary = await retrieve_agent_response_generator(
responses_params=responses_params,
context=context,
endpoint_path=endpoint_path,
Expand All @@ -306,7 +314,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
)

return StreamingResponse(
generate_response(
generate_agent_response(
generator=generator,
context=context,
responses_params=responses_params,
Expand Down Expand Up @@ -738,6 +746,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
# Store MCP call item info for later lookup when arguments.done event occurs
elif event_type == "response.output_item.added":
item_added_chunk = cast(OutputItemAddedChunk, chunk)

if item_added_chunk.item.type == "mcp_call":
mcp_call_item = cast(MCPCall, item_added_chunk.item)
mcp_calls[item_added_chunk.output_index] = (
Expand Down Expand Up @@ -923,22 +932,6 @@ def stream_http_error_event(
)


def format_stream_data(d: dict) -> str:
"""
Create a response generator function for Responses API streaming.

Parameters:
----------
d (dict): The data to be formatted as an SSE event.

Returns:
-------
str: The formatted SSE data string.
"""
data = json.dumps(d)
return f"data: {data}\n\n"


def stream_start_event(conversation_id: str, request_id: str) -> str:
"""Format an SSE start event for a streaming response.

Expand Down Expand Up @@ -1038,61 +1031,3 @@ def stream_end_event(
"available_quotas": available_quotas,
}
)


def stream_event(data: dict, event_type: str, media_type: str) -> str:
"""Build an item to yield based on media type.

Args:
data: Dictionary containing the event data
event_type: Type of event (token, tool call, etc.)
media_type: The media type for the response format

Returns:
SSE-formatted string representing the event
"""
if media_type == MEDIA_TYPE_TEXT:
if event_type == LLM_TOKEN_EVENT:
return data.get("token", "")
if event_type == LLM_TOOL_CALL_EVENT:
return f"[Tool Call: {data.get('function_name', 'unknown')}]\n"
if event_type == LLM_TOOL_RESULT_EVENT:
return "[Tool Result]\n"
if event_type == LLM_TURN_COMPLETE_EVENT:
return ""
return ""

return format_stream_data(
{
"event": event_type,
"data": data,
}
)


async def shield_violation_generator(
violation_message: str,
media_type: str = MEDIA_TYPE_TEXT,
) -> AsyncIterator[str]:
"""
Create an SSE stream for shield violation responses.

Yields start, token, and end events immediately for shield violations.
This function creates a minimal streaming response without going through
the Llama Stack response format.

Args:
violation_message: The violation message to display.
media_type: The media type for the response format.

Yields:
str: SSE-formatted strings for start, token, and end events.
"""
yield stream_event(
{
"id": 0,
"token": violation_message,
},
LLM_TOKEN_EVENT,
media_type,
)
3 changes: 2 additions & 1 deletion src/pydantic_ai_lightspeed/llamastack/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Pydantic AI provider for Llama Stack."""

from pydantic_ai_lightspeed.llamastack._model import LlamaStackResponsesModel
from pydantic_ai_lightspeed.llamastack._provider import LlamaStackProvider

__all__ = ["LlamaStackProvider"]
__all__ = ["LlamaStackProvider", "LlamaStackResponsesModel"]
Loading
Loading