Skip to content

Commit ec2f65d

Browse files
asimurka authored and tisnik committed
Utilities to use more generic types refactor
1 parent 302bfb3 commit ec2f65d

18 files changed

Lines changed: 1032 additions & 1130 deletions

src/app/endpoints/a2a.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@
4646
from models.requests import QueryRequest
4747
from utils.mcp_headers import mcp_headers_dependency, McpHeaders
4848
from utils.responses import (
49-
extract_text_from_response_output_item,
49+
extract_text_from_output_item,
5050
prepare_responses_params,
5151
)
5252
from utils.suid import normalize_conversation_id
@@ -107,7 +107,7 @@ def _convert_responses_content_to_a2a_parts(output: list[Any]) -> list[Part]:
107107
parts: list[Part] = []
108108

109109
for output_item in output:
110-
text = extract_text_from_response_output_item(output_item)
110+
text = extract_text_from_output_item(output_item)
111111
if text:
112112
parts.append(Part(root=TextPart(text=text)))
113113

src/app/endpoints/query.py

Lines changed: 14 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
"""Handler for REST API call to provide answer to query using Response API."""
44

55
import datetime
6-
from typing import Annotated, Any, Optional, cast
6+
from typing import Annotated, Any, cast
77

88
from fastapi import APIRouter, Depends, HTTPException, Request
99
from llama_stack_api.openai_responses import OpenAIResponseObject
@@ -52,12 +52,10 @@
5252
)
5353
from utils.quota import check_tokens_available, get_available_quotas
5454
from utils.responses import (
55-
build_tool_call_summary,
56-
extract_text_from_response_output_item,
57-
extract_token_usage,
55+
build_turn_summary,
56+
deduplicate_referenced_documents,
5857
extract_vector_store_ids_from_tools,
5958
get_topic_summary,
60-
parse_referenced_documents,
6159
prepare_responses_params,
6260
)
6361
from utils.shields import (
@@ -66,6 +64,7 @@
6664
)
6765
from utils.suid import normalize_conversation_id
6866
from utils.types import (
67+
RAGChunk,
6968
ResponsesApiParams,
7069
TurnSummary,
7170
)
@@ -130,7 +129,9 @@ async def query_endpoint_handler(
130129
check_tokens_available(configuration.quota_limiters, user_id)
131130

132131
# Enforce RBAC: optionally disallow overriding model/provider in requests
133-
validate_model_provider_override(query_request, request.state.authorized_actions)
132+
validate_model_provider_override(
133+
query_request.model, query_request.provider, request.state.authorized_actions
134+
)
134135

135136
# Validate attachments if provided
136137
if query_request.attachments:
@@ -153,7 +154,7 @@ async def query_endpoint_handler(
153154
client = AsyncLlamaStackClientHolder().get_client()
154155

155156
doc_ids_from_chunks: list[ReferencedDocument] = []
156-
pre_rag_chunks: list[Any] = [] # use your RAGChunk type (or the upstream one)
157+
pre_rag_chunks: list[RAGChunk] = []
157158

158159
_, _, doc_ids_from_chunks, pre_rag_chunks = await perform_vector_search(
159160
client, query_request, configuration
@@ -198,7 +199,7 @@ async def query_endpoint_handler(
198199
turn_summary.rag_chunks = pre_rag_chunks + (turn_summary.rag_chunks or [])
199200

200201
if doc_ids_from_chunks:
201-
turn_summary.referenced_documents = parse_referenced_docs(
202+
turn_summary.referenced_documents = deduplicate_referenced_documents(
202203
doc_ids_from_chunks + (turn_summary.referenced_documents or [])
203204
)
204205

@@ -216,7 +217,6 @@ async def query_endpoint_handler(
216217
user_id=user_id,
217218
model_id=responses_params.model,
218219
token_usage=turn_summary.token_usage,
219-
configuration=configuration,
220220
)
221221

222222
logger.info("Getting available quotas")
@@ -238,7 +238,6 @@ async def query_endpoint_handler(
238238
completed_at=completed_at,
239239
summary=turn_summary,
240240
query_request=query_request,
241-
configuration=configuration,
242241
skip_userid_check=_skip_userid_check,
243242
topic_summary=topic_summary,
244243
)
@@ -258,26 +257,11 @@ async def query_endpoint_handler(
258257
)
259258

260259

261-
def parse_referenced_docs(
262-
docs: list[ReferencedDocument],
263-
) -> list[ReferencedDocument]:
264-
"""Remove duplicate referenced documents based on URL and title."""
265-
seen: set[tuple[str | None, str | None]] = set()
266-
out: list[ReferencedDocument] = []
267-
for d in docs:
268-
key = (str(d.doc_url) if d.doc_url else None, d.doc_title)
269-
if key in seen:
270-
continue
271-
seen.add(key)
272-
out.append(d)
273-
return out
274-
275-
276260
async def retrieve_response( # pylint: disable=too-many-locals
277261
client: AsyncLlamaStackClient,
278262
responses_params: ResponsesApiParams,
279-
vector_store_ids: Optional[list[str]] = None,
280-
rag_id_mapping: Optional[dict[str, str]] = None,
263+
vector_store_ids: list[str] | None = None,
264+
rag_id_mapping: dict[str, str] | None = None,
281265
) -> TurnSummary:
282266
"""
283267
Retrieve response from LLMs and agents.
@@ -294,8 +278,6 @@ async def retrieve_response( # pylint: disable=too-many-locals
294278
Returns:
295279
TurnSummary: Summary of the LLM response content
296280
"""
297-
summary = TurnSummary()
298-
299281
try:
300282
moderation_result = await run_shield_moderation(client, responses_params.input)
301283
if moderation_result.blocked:
@@ -307,8 +289,7 @@ async def retrieve_response( # pylint: disable=too-many-locals
307289
responses_params.input,
308290
violation_message,
309291
)
310-
summary.llm_response = violation_message
311-
return summary
292+
return TurnSummary(llm_response=violation_message)
312293
response = await client.responses.create(**responses_params.model_dump())
313294
response = cast(OpenAIResponseObject, response)
314295

@@ -327,30 +308,6 @@ async def retrieve_response( # pylint: disable=too-many-locals
327308
error_response = handle_known_apistatus_errors(e, responses_params.model)
328309
raise HTTPException(**error_response.model_dump()) from e
329310

330-
# Process OpenAI response format
331-
for output_item in response.output:
332-
message_text = extract_text_from_response_output_item(output_item)
333-
if message_text:
334-
summary.llm_response += message_text
335-
336-
tool_call, tool_result = build_tool_call_summary(
337-
output_item, summary.rag_chunks, vector_store_ids, rag_id_mapping
338-
)
339-
if tool_call:
340-
summary.tool_calls.append(tool_call)
341-
if tool_result:
342-
summary.tool_results.append(tool_result)
343-
344-
logger.info(
345-
"Response processing complete - Tool calls: %d, Response length: %d chars",
346-
len(summary.tool_calls),
347-
len(summary.llm_response),
348-
)
349-
350-
# Extract referenced documents and token usage from Responses API response
351-
summary.referenced_documents = parse_referenced_documents(
352-
response, vector_store_ids, rag_id_mapping
311+
return build_turn_summary(
312+
response, responses_params.model, vector_store_ids, rag_id_mapping
353313
)
354-
summary.token_usage = extract_token_usage(response, responses_params.model)
355-
356-
return summary

src/app/endpoints/rlsapi_v1.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,10 @@
3434
from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
3535
from observability import InferenceEventData, build_inference_event, send_splunk_event
3636
from utils.query import handle_known_apistatus_errors
37-
from utils.responses import extract_text_from_response_output_item, get_mcp_tools
37+
from utils.responses import (
38+
extract_text_from_output_items,
39+
get_mcp_tools,
40+
)
3841
from utils.suid import get_suid
3942
from log import get_logger
4043

@@ -189,10 +192,7 @@ async def retrieve_simple_response(
189192
)
190193
response = cast(OpenAIResponseObject, response)
191194

192-
return "".join(
193-
extract_text_from_response_output_item(output_item)
194-
for output_item in response.output
195-
)
195+
return extract_text_from_output_items(response.output)
196196

197197

198198
def _get_cla_version(request: Request) -> str:
@@ -307,7 +307,7 @@ async def infer_endpoint(
307307
input_source = infer_request.get_input_source()
308308
instructions = _build_instructions(infer_request.context.systeminfo)
309309
model_id = _get_default_model_id()
310-
mcp_tools = await get_mcp_tools(configuration.mcp_servers)
310+
mcp_tools = await get_mcp_tools()
311311
logger.debug(
312312
"Request %s: Combined input source length: %d", request_id, len(input_source)
313313
)

src/app/endpoints/streaming_query.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -153,7 +153,9 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
153153
check_tokens_available(configuration.quota_limiters, user_id)
154154

155155
# Enforce RBAC: optionally disallow overriding model/provider in requests
156-
validate_model_provider_override(query_request, request.state.authorized_actions)
156+
validate_model_provider_override(
157+
query_request.model, query_request.provider, request.state.authorized_actions
158+
)
157159

158160
# Validate attachments if provided
159161
if query_request.attachments:
@@ -379,7 +381,6 @@ async def generate_response(
379381
user_id=context.user_id,
380382
model_id=responses_params.model,
381383
token_usage=turn_summary.token_usage,
382-
configuration=configuration,
383384
)
384385
# Get available quotas
385386
logger.info("Getting available quotas")
@@ -405,7 +406,6 @@ async def generate_response(
405406
started_at=context.started_at,
406407
summary=turn_summary,
407408
query_request=context.query_request,
408-
configuration=configuration,
409409
skip_userid_check=context.skip_userid_check,
410410
topic_summary=topic_summary,
411411
)
@@ -591,8 +591,11 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
591591
)
592592

593593
# Extract token usage and referenced documents from the final response object
594+
if not latest_response_object:
595+
return
596+
594597
turn_summary.token_usage = extract_token_usage(
595-
latest_response_object, context.model_id
598+
latest_response_object.usage, context.model_id
596599
)
597600
tool_based_documents = parse_referenced_documents(
598601
latest_response_object,

src/models/responses.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1753,16 +1753,20 @@ class NotFoundResponse(AbstractErrorResponse):
17531753
}
17541754
}
17551755

1756-
def __init__(self, *, resource: str, resource_id: str):
1756+
def __init__(self, *, resource: str, resource_id: str | None = None):
17571757
"""
17581758
Create a NotFoundResponse for a missing resource and set the HTTP status to 404.
17591759
17601760
Parameters:
17611761
resource (str): Resource type that was not found (e.g., "conversation", "model").
1762-
resource_id (str): Identifier of the missing resource.
1762+
resource_id (str | None): Identifier of the missing resource. If None, indicates
1763+
the resource type is not configured (e.g., no model selected).
17631764
"""
17641765
response = f"{resource.title()} not found"
1765-
cause = f"{resource.title()} with ID {resource_id} does not exist"
1766+
if resource_id is None:
1767+
cause = f"No {resource.title()} is configured"
1768+
else:
1769+
cause = f"{resource.title()} with ID {resource_id} does not exist"
17661770
super().__init__(
17671771
response=response, cause=cause, status_code=status.HTTP_404_NOT_FOUND
17681772
)

src/utils/prompts.py

Lines changed: 34 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -3,41 +3,40 @@
33
from fastapi import HTTPException
44

55
import constants
6-
from configuration import AppConfig
7-
from models.requests import QueryRequest
6+
from configuration import configuration
87
from models.responses import UnprocessableEntityResponse
98

109

11-
def get_system_prompt(query_request: QueryRequest, config: AppConfig) -> str:
10+
def get_system_prompt(system_prompt: str | None) -> str:
1211
"""
1312
Resolve which system prompt to use for a query.
1413
15-
Precedence (highest to lowest):
16-
1. Per-request `system_prompt` from `query_request.system_prompt`.
17-
2. The `custom_profile`'s "default" prompt (when present), accessed via
18-
`config.customization.custom_profile.get_prompts().get("default")`.
19-
3. `config.customization.system_prompt` from application configuration.
14+
get_system_prompt resolves the system prompt with the following precedence
15+
(highest to lowest):
16+
1. Per-request system prompt from the `system_prompt` argument (when allowed).
17+
2. The custom profile's "default" prompt (when present), from application
18+
configuration.
19+
3. The application configuration system prompt.
2020
4. The module default `constants.DEFAULT_SYSTEM_PROMPT` (lowest precedence).
2121
22-
If configuration disables per-request system prompts
23-
(config.customization.disable_query_system_prompt) and the incoming
24-
`query_request` contains a `system_prompt`, an HTTP 422 Unprocessable
25-
Entity is raised instructing the client to remove the field.
26-
2722
Parameters:
28-
query_request (QueryRequest): The incoming query payload; may contain a
29-
per-request `system_prompt`.
30-
config (AppConfig): Application configuration which may include
31-
customization flags, a custom profile, and a default `system_prompt`.
23+
system_prompt: Optional per-request system prompt from the query; may be
24+
None.
3225
3326
Returns:
34-
str: The resolved system prompt to apply to the request.
27+
The resolved system prompt string to apply to the request.
28+
29+
Raises:
30+
HTTPException: 422 Unprocessable Entity when per-request system prompts
31+
are disabled (disable_query_system_prompt) and a non-None
32+
`system_prompt` is provided; the response instructs the client to
33+
remove the system_prompt field from the request.
3534
"""
3635
system_prompt_disabled = (
37-
config.customization is not None
38-
and config.customization.disable_query_system_prompt
36+
configuration.customization is not None
37+
and configuration.customization.disable_query_system_prompt
3938
)
40-
if system_prompt_disabled and query_request.system_prompt:
39+
if system_prompt_disabled and system_prompt:
4140
response = UnprocessableEntityResponse(
4241
response="System prompt customization is disabled",
4342
cause=(
@@ -48,49 +47,47 @@ def get_system_prompt(query_request: QueryRequest, config: AppConfig) -> str:
4847
)
4948
raise HTTPException(**response.model_dump())
5049

51-
if query_request.system_prompt:
50+
if system_prompt:
5251
# Query taking precedence over configuration is the only behavior that
5352
# makes sense here - if the configuration wants precedence, it can
5453
# disable query system prompt altogether with disable_query_system_prompt.
55-
return query_request.system_prompt
54+
return system_prompt
5655

5756
# profile takes precedence for setting prompt
5857
if (
59-
config.customization is not None
60-
and config.customization.custom_profile is not None
58+
configuration.customization is not None
59+
and configuration.customization.custom_profile is not None
6160
):
62-
prompt = config.customization.custom_profile.get_prompts().get("default")
61+
prompt = configuration.customization.custom_profile.get_prompts().get("default")
6362
if prompt:
6463
return prompt
6564

6665
if (
67-
config.customization is not None
68-
and config.customization.system_prompt is not None
66+
configuration.customization is not None
67+
and configuration.customization.system_prompt is not None
6968
):
70-
return config.customization.system_prompt
69+
return configuration.customization.system_prompt
7170

7271
# default system prompt has the lowest precedence
7372
return constants.DEFAULT_SYSTEM_PROMPT
7473

7574

76-
def get_topic_summary_system_prompt(config: AppConfig) -> str:
75+
def get_topic_summary_system_prompt() -> str:
7776
"""
7877
Get the topic summary system prompt.
7978
80-
Parameters:
81-
config (AppConfig): Application configuration from which to read
82-
customization/profile settings.
83-
8479
Returns:
8580
str: The topic summary system prompt from the active custom profile if
8681
set, otherwise the default prompt.
8782
"""
8883
# profile takes precedence for setting prompt
8984
if (
90-
config.customization is not None
91-
and config.customization.custom_profile is not None
85+
configuration.customization is not None
86+
and configuration.customization.custom_profile is not None
9287
):
93-
prompt = config.customization.custom_profile.get_prompts().get("topic_summary")
88+
prompt = configuration.customization.custom_profile.get_prompts().get(
89+
"topic_summary"
90+
)
9491
if prompt:
9592
return prompt
9693

0 commit comments

Comments
 (0)