Skip to content

Commit effd5da

Browse files
authored
Merge pull request lightspeed-core#1612 from asimurka/refactor_responses_endpoint
LCORE-1262: Use context data class in responses
2 parents ca125c4 + f7389dd commit effd5da

13 files changed

Lines changed: 670 additions & 529 deletions

File tree

src/app/endpoints/query.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from client import AsyncLlamaStackClientHolder
2424
from configuration import configuration
2525
from log import get_logger
26+
from models.common.responses.responses_api_params import ResponsesApiParams
2627
from models.config import Action
2728
from models.requests import QueryRequest
2829
from models.responses import (
@@ -65,7 +66,6 @@
6566
from utils.shields import run_shield_moderation, validate_shield_ids_override
6667
from utils.suid import normalize_conversation_id
6768
from utils.types import (
68-
ResponsesApiParams,
6969
ShieldModerationResult,
7070
TurnSummary,
7171
)

src/app/endpoints/responses.py

Lines changed: 180 additions & 278 deletions
Large diffs are not rendered by default.

src/app/endpoints/streaming_query.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
)
6060
from log import get_logger
6161
from metrics import recording
62+
from models.common.responses.responses_api_params import ResponsesApiParams
6263
from models.config import Action
6364
from models.context import ResponseGeneratorContext
6465
from models.requests import QueryRequest
@@ -115,7 +116,7 @@
115116
from utils.stream_interrupts import get_stream_interrupt_registry
116117
from utils.suid import get_suid, normalize_conversation_id
117118
from utils.token_counter import TokenCounter
118-
from utils.types import ReferencedDocument, ResponsesApiParams, TurnSummary
119+
from utils.types import ReferencedDocument, TurnSummary
119120
from utils.vector_search import build_rag_context
120121

121122
logger = get_logger(__name__)
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
"""Request parameter model for Llama Stack responses API calls."""
2+
3+
from collections.abc import Mapping
4+
from typing import Any, Final, Optional
5+
6+
from llama_stack_api.openai_responses import (
7+
OpenAIResponseInputTool as InputTool,
8+
)
9+
from llama_stack_api.openai_responses import (
10+
OpenAIResponseInputToolChoice as ToolChoice,
11+
)
12+
from llama_stack_api.openai_responses import (
13+
OpenAIResponsePrompt as Prompt,
14+
)
15+
from llama_stack_api.openai_responses import (
16+
OpenAIResponseReasoning as Reasoning,
17+
)
18+
from llama_stack_api.openai_responses import (
19+
OpenAIResponseText as Text,
20+
)
21+
from llama_stack_api.openai_responses import (
22+
OpenAIResponseToolMCP as OutputToolMCP,
23+
)
24+
from pydantic import BaseModel, Field
25+
26+
from utils.tool_formatter import translate_vector_store_ids_to_user_facing
27+
from utils.types import IncludeParameter, ResponseInput
28+
29+
# Attribute names that are echoed back in the response.
30+
_ECHOED_FIELDS: Final[set[str]] = set(
31+
{
32+
"instructions",
33+
"max_tool_calls",
34+
"max_output_tokens",
35+
"metadata",
36+
"model",
37+
"parallel_tool_calls",
38+
"previous_response_id",
39+
"prompt",
40+
"reasoning",
41+
"safety_identifier",
42+
"temperature",
43+
"top_p",
44+
"truncation",
45+
"text",
46+
"tool_choice",
47+
"store",
48+
}
49+
)
50+
51+
52+
class ResponsesApiParams(BaseModel):
53+
"""Parameters for a Llama Stack Responses API request.
54+
55+
All fields accepted by the Llama Stack client responses.create() body are
56+
included so that dumped model can be passed directly to response create.
57+
"""
58+
59+
input: ResponseInput = Field(description="The input text or structured input items")
60+
model: str = Field(description='The full model ID in format "provider/model"')
61+
conversation: str = Field(description="The conversation ID in llama-stack format")
62+
include: Optional[list[IncludeParameter]] = Field(
63+
default=None,
64+
description="Output item types to include in the response",
65+
)
66+
instructions: Optional[str] = Field(
67+
default=None, description="The resolved system prompt"
68+
)
69+
max_infer_iters: Optional[int] = Field(
70+
default=None,
71+
description="Maximum number of inference iterations",
72+
)
73+
max_output_tokens: Optional[int] = Field(
74+
default=None,
75+
description="Maximum number of tokens allowed in the response",
76+
)
77+
max_tool_calls: Optional[int] = Field(
78+
default=None,
79+
description="Maximum tool calls allowed in a single response",
80+
)
81+
metadata: Optional[dict[str, str]] = Field(
82+
default=None,
83+
description="Custom metadata for tracking or logging",
84+
)
85+
parallel_tool_calls: Optional[bool] = Field(
86+
default=None,
87+
description="Whether the model can make multiple tool calls in parallel",
88+
)
89+
previous_response_id: Optional[str] = Field(
90+
default=None,
91+
description="Identifier of the previous response in a multi-turn conversation",
92+
)
93+
prompt: Optional[Prompt] = Field(
94+
default=None,
95+
description="Prompt template with variables for dynamic substitution",
96+
)
97+
reasoning: Optional[Reasoning] = Field(
98+
default=None,
99+
description="Reasoning configuration for the response",
100+
)
101+
safety_identifier: Optional[str] = Field(
102+
default=None,
103+
description="Stable identifier for safety monitoring and abuse detection",
104+
)
105+
store: bool = Field(description="Whether to store the response")
106+
stream: bool = Field(description="Whether to stream the response")
107+
temperature: Optional[float] = Field(
108+
default=None,
109+
description="Sampling temperature (e.g. 0.0-2.0)",
110+
)
111+
text: Optional[Text] = Field(
112+
default=None,
113+
description="Text response configuration (format constraints)",
114+
)
115+
tool_choice: Optional[ToolChoice] = Field(
116+
default=None,
117+
description="Tool selection strategy",
118+
)
119+
tools: Optional[list[InputTool]] = Field(
120+
default=None,
121+
description="Prepared tool groups for Responses API (same type as ResponsesRequest.tools)",
122+
)
123+
extra_headers: Optional[dict[str, str]] = Field(
124+
default=None,
125+
description="Extra HTTP headers to send with the request (e.g. x-llamastack-provider-data)",
126+
)
127+
128+
def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
129+
"""Serialize params, re-injecting MCP authorization stripped by exclude=True.
130+
131+
llama-stack-api marks ``InputToolMCP.authorization`` with
132+
``Field(exclude=True)`` to prevent token leakage in API responses.
133+
The base ``model_dump()`` therefore strips the field, but we need it
134+
in the request payload so llama-stack server can authenticate with
135+
MCP servers. See LCORE-1414 / GitHub issue #1269.
136+
"""
137+
result = super().model_dump(*args, **kwargs)
138+
# Only one context option is allowed, previous_response_id has priority
139+
# Turn is added to conversation manually if previous_response_id is used
140+
if self.previous_response_id:
141+
result.pop("conversation", None)
142+
dumped_tools = result.get("tools")
143+
if not self.tools or not isinstance(dumped_tools, list):
144+
return result
145+
if len(dumped_tools) != len(self.tools):
146+
return result
147+
for tool, dumped_tool in zip(self.tools, dumped_tools):
148+
authorization = getattr(tool, "authorization", None)
149+
if authorization is not None and isinstance(dumped_tool, dict):
150+
dumped_tool["authorization"] = authorization
151+
return result
152+
153+
def echoed_params(self, rag_id_mapping: Mapping[str, str]) -> dict[str, Any]:
154+
"""Build kwargs echoed into synthetic OpenAI-style responses (e.g. moderation blocks).
155+
156+
Parameters:
157+
rag_id_mapping: Llama Stack vector_db_id to user-facing RAG id (from app config).
158+
Returns:
159+
dict[str, Any]: Field names and values to merge into the response object.
160+
"""
161+
data = self.model_dump(include=_ECHOED_FIELDS)
162+
if self.tools is not None:
163+
tool_dicts: list[dict[str, Any]] = []
164+
for t in self.tools:
165+
if t.type == "mcp":
166+
validated = OutputToolMCP.model_validate(t.model_dump())
167+
tool_dicts.append(validated.model_dump())
168+
else:
169+
tool_dicts.append(t.model_dump())
170+
171+
data["tools"] = translate_vector_store_ids_to_user_facing(
172+
tool_dicts, rag_id_mapping
173+
)
174+
175+
return data
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""Request-scoped context model for the responses endpoint pipeline."""
2+
3+
from datetime import datetime
4+
from typing import Optional
5+
6+
from fastapi import BackgroundTasks
7+
from llama_stack_client import AsyncLlamaStackClient
8+
from pydantic import BaseModel, ConfigDict, Field
9+
10+
from utils.types import RAGContext, ShieldModerationResult
11+
12+
13+
class ResponsesContext(BaseModel):
14+
"""Shared request-scoped context for the /responses endpoint pipeline."""
15+
16+
model_config = ConfigDict(arbitrary_types_allowed=True)
17+
18+
client: AsyncLlamaStackClient = Field(description="The Llama Stack client")
19+
auth: tuple[str, str, bool, str] = Field(
20+
description="Authentication tuple (user_id, username, skip_userid_check, token)",
21+
)
22+
input_text: str = Field(description="Extracted user input text for the turn")
23+
started_at: datetime = Field(description="UTC timestamp when the request started")
24+
moderation_result: ShieldModerationResult = Field(
25+
description="Shield moderation outcome",
26+
)
27+
inline_rag_context: RAGContext = Field(
28+
description="Inline RAG context for the turn"
29+
)
30+
filter_server_tools: bool = Field(
31+
default=False,
32+
description="Whether to filter server-deployed MCP tool events from output",
33+
)
34+
background_tasks: Optional[BackgroundTasks] = Field(
35+
default=None,
36+
description="Background tasks for telemetry, if enabled",
37+
)
38+
rh_identity_context: tuple[str, str] = Field(
39+
default=("", ""),
40+
description="RH identity (org_id, system_id) for Splunk events",
41+
)
42+
user_agent: Optional[str] = Field(
43+
default=None,
44+
description="User-Agent string from request headers",
45+
)
46+
endpoint_path: str = Field(
47+
...,
48+
description="API endpoint path used for metric labeling",
49+
)
50+
generate_topic_summary: bool = Field(
51+
default=False,
52+
description="Whether to generate a topic summary for new conversations",
53+
)

src/models/requests.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,8 @@
2121
from llama_stack_api.openai_responses import (
2222
OpenAIResponseText as Text,
2323
)
24-
from llama_stack_api.openai_responses import (
25-
OpenAIResponseToolMCP as OutputToolMCP,
26-
)
2724
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
2825

29-
from configuration import configuration
3026
from constants import (
3127
MCP_AUTH_CLIENT,
3228
MCP_AUTH_KUBERNETES,
@@ -38,7 +34,6 @@
3834
)
3935
from log import get_logger
4036
from utils import suid
41-
from utils.tool_formatter import translate_vector_store_ids_to_user_facing
4237
from utils.types import IncludeParameter, ResponseInput
4338

4439
logger = get_logger(__name__)
@@ -867,28 +862,6 @@ def check_previous_response_id(cls, value: Optional[str]) -> Optional[str]:
867862
raise ValueError("You cannot provide context by moderation response.")
868863
return value
869864

870-
def echoed_params(self) -> dict[str, Any]:
871-
"""Build kwargs echoed into synthetic OpenAI-style responses (e.g. moderation blocks).
872-
873-
Returns:
874-
dict[str, Any]: Field names and values to merge into the response object.
875-
"""
876-
data = self.model_dump(include=_ECHOED_FIELDS)
877-
if self.tools is not None:
878-
tool_dicts: list[dict[str, Any]] = [
879-
(
880-
OutputToolMCP.model_validate(t.model_dump()).model_dump()
881-
if t.type == "mcp"
882-
else t.model_dump()
883-
)
884-
for t in self.tools
885-
]
886-
data["tools"] = translate_vector_store_ids_to_user_facing(
887-
tool_dicts, configuration.rag_id_mapping
888-
)
889-
890-
return data
891-
892865

893866
class MCPServerRegistrationRequest(BaseModel):
894867
"""Request model for dynamically registering an MCP server.

src/utils/responses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
from constants import DEFAULT_RAG_TOOL
9292
from log import get_logger
9393
from metrics import recording
94+
from models.common.responses.responses_api_params import ResponsesApiParams
9495
from models.config import ByokRag
9596
from models.database.conversations import UserConversation
9697
from models.requests import QueryRequest
@@ -118,7 +119,6 @@
118119
ReferencedDocument,
119120
ResponseInput,
120121
ResponseItem,
121-
ResponsesApiParams,
122122
ToolCallSummary,
123123
ToolResultSummary,
124124
TurnSummary,

0 commit comments

Comments
 (0)