Skip to content

Commit ef8a04a

Browse files
committed
Fix unittest for conversations API addition
1 parent 4f1d914 commit ef8a04a

10 files changed

Lines changed: 293 additions & 127 deletions

File tree

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@ exclude = [
6161
# service/ols/src/auth/k8s.py and currently has 58 Pyright issues. It
6262
# might need to be rewritten down the line.
6363
"src/authentication/k8s.py",
64+
# Agent API v1 endpoints - deprecated API but still supported
65+
# Type errors due to llama-stack-client not exposing Agent API types
66+
"src/app/endpoints/conversations.py",
67+
"src/app/endpoints/query.py",
68+
"src/app/endpoints/streaming_query.py",
69+
"src/utils/endpoints.py",
6470
]
6571
extraPaths = ["./src"]
6672

src/app/endpoints/conversations_v3.py

Lines changed: 42 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Any
55

66
from fastapi import APIRouter, Depends, HTTPException, Request, status
7-
from llama_stack_client import APIConnectionError, NotFoundError
7+
from llama_stack_client import APIConnectionError, NOT_GIVEN, NotFoundError
88

99
from app.database import get_session
1010
from authentication import get_auth_dependency
@@ -15,13 +15,13 @@
1515
from models.database.conversations import UserConversation
1616
from models.requests import ConversationUpdateRequest
1717
from models.responses import (
18-
AccessDeniedResponse,
1918
BadRequestResponse,
2019
ConversationDeleteResponse,
2120
ConversationDetails,
2221
ConversationResponse,
2322
ConversationsListResponse,
2423
ConversationUpdateResponse,
24+
ForbiddenResponse,
2525
NotFoundResponse,
2626
ServiceUnavailableResponse,
2727
UnauthorizedResponse,
@@ -55,7 +55,7 @@
5555
"description": "Unauthorized: Invalid or missing Bearer token",
5656
},
5757
403: {
58-
"model": AccessDeniedResponse,
58+
"model": ForbiddenResponse,
5959
"description": "Client does not have permission to access conversation",
6060
},
6161
404: {
@@ -82,7 +82,7 @@
8282
"description": "Unauthorized: Invalid or missing Bearer token",
8383
},
8484
403: {
85-
"model": AccessDeniedResponse,
85+
"model": ForbiddenResponse,
8686
"description": "Client does not have permission to access conversation",
8787
},
8888
404: {
@@ -124,7 +124,7 @@
124124
"description": "Unauthorized: Invalid or missing Bearer token",
125125
},
126126
403: {
127-
"model": AccessDeniedResponse,
127+
"model": ForbiddenResponse,
128128
"description": "Client does not have permission to access conversation",
129129
},
130130
404: {
@@ -283,7 +283,7 @@ async def get_conversation_endpoint_handler(
283283
status_code=status.HTTP_400_BAD_REQUEST,
284284
detail=BadRequestResponse(
285285
resource="conversation", resource_id=conversation_id
286-
).dump_detail(),
286+
).model_dump(),
287287
)
288288

289289
# Normalize the conversation ID for database operations (strip conv_ prefix if present)
@@ -309,12 +309,11 @@ async def get_conversation_endpoint_handler(
309309
)
310310
raise HTTPException(
311311
status_code=status.HTTP_403_FORBIDDEN,
312-
detail=AccessDeniedResponse(
313-
user_id=user_id,
314-
resource="conversation",
315-
resource_id=normalized_conv_id,
312+
detail=ForbiddenResponse.conversation(
316313
action="read",
317-
).dump_detail(),
314+
resource_id=normalized_conv_id,
315+
user_id=user_id,
316+
).model_dump(),
318317
)
319318

320319
# If reached this, user is authorized to retrieve this conversation
@@ -342,8 +341,6 @@ async def get_conversation_endpoint_handler(
342341
)
343342

344343
# Use Conversations API to retrieve conversation items
345-
from llama_stack_client import NOT_GIVEN
346-
347344
conversation_items_response = await client.conversations.items.list(
348345
conversation_id=llama_stack_conv_id,
349346
after=NOT_GIVEN, # No pagination cursor
@@ -384,7 +381,7 @@ async def get_conversation_endpoint_handler(
384381
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
385382
detail=ServiceUnavailableResponse(
386383
backend_name="Llama Stack", cause=str(e)
387-
).dump_detail(),
384+
).model_dump(),
388385
) from e
389386

390387
except NotFoundError as e:
@@ -393,7 +390,7 @@ async def get_conversation_endpoint_handler(
393390
status_code=status.HTTP_404_NOT_FOUND,
394391
detail=NotFoundResponse(
395392
resource="conversation", resource_id=normalized_conv_id
396-
).dump_detail(),
393+
).model_dump(),
397394
) from e
398395

399396
except HTTPException:
@@ -402,11 +399,14 @@ async def get_conversation_endpoint_handler(
402399
except Exception as e:
403400
# Handle case where conversation doesn't exist or other errors
404401
logger.exception("Error retrieving conversation %s: %s", normalized_conv_id, e)
402+
error_msg = (
403+
f"Unknown error while getting conversation {normalized_conv_id} : {e}"
404+
)
405405
raise HTTPException(
406406
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
407407
detail={
408408
"response": "Unknown error",
409-
"cause": f"Unknown error while getting conversation {normalized_conv_id} : {str(e)}",
409+
"cause": error_msg,
410410
},
411411
) from e
412412

@@ -444,7 +444,7 @@ async def delete_conversation_endpoint_handler(
444444
status_code=status.HTTP_400_BAD_REQUEST,
445445
detail=BadRequestResponse(
446446
resource="conversation", resource_id=conversation_id
447-
).dump_detail(),
447+
).model_dump(),
448448
)
449449

450450
# Normalize the conversation ID for database operations (strip conv_ prefix if present)
@@ -465,12 +465,11 @@ async def delete_conversation_endpoint_handler(
465465
)
466466
raise HTTPException(
467467
status_code=status.HTTP_403_FORBIDDEN,
468-
detail=AccessDeniedResponse(
469-
user_id=user_id,
470-
resource="conversation",
471-
resource_id=normalized_conv_id,
468+
detail=ForbiddenResponse.conversation(
472469
action="delete",
473-
).dump_detail(),
470+
resource_id=normalized_conv_id,
471+
user_id=user_id,
472+
).model_dump(),
474473
)
475474

476475
# If reached this, user is authorized to delete this conversation
@@ -480,7 +479,7 @@ async def delete_conversation_endpoint_handler(
480479
status_code=status.HTTP_404_NOT_FOUND,
481480
detail=NotFoundResponse(
482481
resource="conversation", resource_id=normalized_conv_id
483-
).dump_detail(),
482+
).model_dump(),
484483
)
485484

486485
logger.info("Deleting conversation %s using Conversations API", normalized_conv_id)
@@ -502,16 +501,15 @@ async def delete_conversation_endpoint_handler(
502501

503502
return ConversationDeleteResponse(
504503
conversation_id=normalized_conv_id,
505-
success=True,
506-
response="Conversation deleted successfully",
504+
deleted=True,
507505
)
508506

509507
except APIConnectionError as e:
510508
raise HTTPException(
511509
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
512510
detail=ServiceUnavailableResponse(
513511
backend_name="Llama Stack", cause=str(e)
514-
).dump_detail(),
512+
).model_dump(),
515513
) from e
516514

517515
except NotFoundError:
@@ -524,8 +522,7 @@ async def delete_conversation_endpoint_handler(
524522

525523
return ConversationDeleteResponse(
526524
conversation_id=normalized_conv_id,
527-
success=True,
528-
response="Conversation deleted successfully",
525+
deleted=True,
529526
)
530527

531528
except HTTPException:
@@ -534,11 +531,14 @@ async def delete_conversation_endpoint_handler(
534531
except Exception as e:
535532
# Handle case where conversation doesn't exist or other errors
536533
logger.exception("Error deleting conversation %s: %s", normalized_conv_id, e)
534+
error_msg = (
535+
f"Unknown error while deleting conversation {normalized_conv_id} : {e}"
536+
)
537537
raise HTTPException(
538538
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
539539
detail={
540540
"response": "Unknown error",
541-
"cause": f"Unknown error while deleting conversation {normalized_conv_id} : {str(e)}",
541+
"cause": error_msg,
542542
},
543543
) from e
544544

@@ -574,7 +574,7 @@ async def update_conversation_endpoint_handler(
574574
status_code=status.HTTP_400_BAD_REQUEST,
575575
detail=BadRequestResponse(
576576
resource="conversation", resource_id=conversation_id
577-
).dump_detail(),
577+
).model_dump(),
578578
)
579579

580580
# Normalize the conversation ID for database operations (strip conv_ prefix if present)
@@ -595,12 +595,11 @@ async def update_conversation_endpoint_handler(
595595
)
596596
raise HTTPException(
597597
status_code=status.HTTP_403_FORBIDDEN,
598-
detail=AccessDeniedResponse(
599-
user_id=user_id,
600-
resource="conversation",
601-
resource_id=normalized_conv_id,
598+
detail=ForbiddenResponse.conversation(
602599
action="update",
603-
).dump_detail(),
600+
resource_id=normalized_conv_id,
601+
user_id=user_id,
602+
).model_dump(),
604603
)
605604

606605
# If reached this, user is authorized to update this conversation
@@ -610,7 +609,7 @@ async def update_conversation_endpoint_handler(
610609
status_code=status.HTTP_404_NOT_FOUND,
611610
detail=NotFoundResponse(
612611
resource="conversation", resource_id=normalized_conv_id
613-
).dump_detail(),
612+
).model_dump(),
614613
)
615614

616615
logger.info(
@@ -629,7 +628,7 @@ async def update_conversation_endpoint_handler(
629628
metadata = {"topic_summary": update_request.topic_summary}
630629

631630
# Use Conversations API to update the conversation metadata
632-
await client.conversations.update_conversation(
631+
await client.conversations.update(
633632
conversation_id=llama_stack_conv_id,
634633
metadata=metadata,
635634
)
@@ -663,15 +662,15 @@ async def update_conversation_endpoint_handler(
663662
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
664663
detail=ServiceUnavailableResponse(
665664
backend_name="Llama Stack", cause=str(e)
666-
).dump_detail(),
665+
).model_dump(),
667666
) from e
668667

669668
except NotFoundError as e:
670669
raise HTTPException(
671670
status_code=status.HTTP_404_NOT_FOUND,
672671
detail=NotFoundResponse(
673672
resource="conversation", resource_id=normalized_conv_id
674-
).dump_detail(),
673+
).model_dump(),
675674
) from e
676675

677676
except HTTPException:
@@ -680,10 +679,13 @@ async def update_conversation_endpoint_handler(
680679
except Exception as e:
681680
# Handle case where conversation doesn't exist or other errors
682681
logger.exception("Error updating conversation %s: %s", normalized_conv_id, e)
682+
error_msg = (
683+
f"Unknown error while updating conversation {normalized_conv_id} : {e}"
684+
)
683685
raise HTTPException(
684686
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
685687
detail={
686688
"response": "Unknown error",
687-
"cause": f"Unknown error while updating conversation {normalized_conv_id} : {str(e)}",
689+
"cause": error_msg,
688690
},
689691
) from e

src/utils/endpoints.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -306,16 +306,19 @@ async def get_agent(
306306
existing_agent_id = agent_response.agent_id
307307

308308
logger.debug("Creating new agent")
309+
# pylint: disable=unexpected-keyword-arg,no-member
309310
agent = AsyncAgent(
310311
client, # type: ignore[arg-type]
311312
model=model_id,
312313
instructions=system_prompt,
314+
# type: ignore[call-arg]
313315
input_shields=available_input_shields if available_input_shields else [],
316+
# type: ignore[call-arg]
314317
output_shields=available_output_shields if available_output_shields else [],
315318
tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id),
316-
enable_session_persistence=True,
319+
enable_session_persistence=True, # type: ignore[call-arg]
317320
)
318-
await agent.initialize()
321+
await agent.initialize() # type: ignore[attr-defined]
319322

320323
if existing_agent_id and conversation_id:
321324
logger.debug("Existing conversation ID: %s", conversation_id)
@@ -335,11 +338,12 @@ async def get_agent(
335338
raise HTTPException(**response.model_dump()) from e
336339
else:
337340
conversation_id = agent.agent_id
341+
# pylint: enable=unexpected-keyword-arg,no-member
338342
logger.debug("New conversation ID: %s", conversation_id)
339343
session_id = await agent.create_session(get_suid())
340344
logger.debug("New session ID: %s", session_id)
341345

342-
return agent, conversation_id, session_id
346+
return agent, conversation_id, session_id # type: ignore[return-value]
343347

344348

345349
async def get_temp_agent(
@@ -360,16 +364,19 @@ async def get_temp_agent(
360364
tuple[AsyncAgent, str]: A tuple containing the agent and session_id.
361365
"""
362366
logger.debug("Creating temporary agent")
367+
# pylint: disable=unexpected-keyword-arg,no-member
363368
agent = AsyncAgent(
364369
client, # type: ignore[arg-type]
365370
model=model_id,
366371
instructions=system_prompt,
367-
enable_session_persistence=False, # Temporary agent doesn't need persistence
372+
# type: ignore[call-arg] # Temporary agent doesn't need persistence
373+
enable_session_persistence=False,
368374
)
369-
await agent.initialize()
375+
await agent.initialize() # type: ignore[attr-defined]
370376

371377
# Generate new IDs for the temporary agent
372378
conversation_id = agent.agent_id
379+
# pylint: enable=unexpected-keyword-arg,no-member
373380
session_id = await agent.create_session(get_suid())
374381

375382
return agent, session_id, conversation_id

src/utils/types.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
from typing import Any, Optional
44
import json
55
from llama_stack_client.lib.agents.tool_parser import ToolParser
6-
from llama_stack_client.types.shared.completion_message import CompletionMessage
7-
from llama_stack_client.types.shared.tool_call import ToolCall
6+
from llama_stack_client.lib.agents.types import (
7+
CompletionMessage as AgentCompletionMessage,
8+
ToolCall as AgentToolCall,
9+
)
810
from llama_stack_client.types.shared.interleaved_content_item import (
911
TextContentItem,
1012
ImageContentItem,
@@ -58,16 +60,18 @@ def __call__(cls, *args, **kwargs): # type: ignore
5860
class GraniteToolParser(ToolParser):
5961
"""Workaround for 'tool_calls' with granite models."""
6062

61-
def get_tool_calls(self, output_message: CompletionMessage) -> list[ToolCall]:
63+
def get_tool_calls(
64+
self, output_message: AgentCompletionMessage
65+
) -> list[AgentToolCall]:
6266
"""
6367
Return the `tool_calls` list from a CompletionMessage, or an empty list if none are present.
6468
6569
Parameters:
66-
output_message (CompletionMessage | None): Completion
70+
output_message (AgentCompletionMessage | None): Completion
6771
message potentially containing `tool_calls`.
6872
6973
Returns:
70-
list[ToolCall]: The list of tool call entries
74+
list[AgentToolCall]: The list of tool call entries
7175
extracted from `output_message`, or an empty list.
7276
"""
7377
if output_message and output_message.tool_calls:

0 commit comments

Comments
 (0)