forked from lightspeed-core/lightspeed-stack
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtypes.py
More file actions
186 lines (142 loc) · 6.31 KB
/
types.py
File metadata and controls
186 lines (142 loc) · 6.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
"""Common types for the project."""
from typing import Any, Optional
from llama_stack_client.lib.agents.tool_parser import ToolParser
from llama_stack_client.lib.agents.types import (
CompletionMessage as AgentCompletionMessage,
)
from llama_stack_client.lib.agents.types import (
ToolCall as AgentToolCall,
)
from llama_stack_client.types.shared.interleaved_content_item import (
ImageContentItem,
TextContentItem,
)
from pydantic import AnyUrl, BaseModel, Field
from utils.token_counter import TokenCounter
def content_to_str(content: Any) -> str:
"""Convert content (str, TextContentItem, ImageContentItem, or list) to string.
Parameters:
content: Value to normalize into a string (may be None,
str, content item, list, or any other object).
Returns:
str: The normalized string representation of the content.
"""
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, TextContentItem):
return content.text
if isinstance(content, ImageContentItem):
return "<image>"
if isinstance(content, list):
return " ".join(content_to_str(item) for item in content)
return str(content)
class Singleton(type):
"""Metaclass for Singleton support."""
_instances = {} # type: ignore
def __call__(cls, *args, **kwargs): # type: ignore
"""
Return the single cached instance of the class, creating and caching it on first call.
Returns:
object: The singleton instance for this class.
"""
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
# See https://github.com/meta-llama/llama-stack-client-python/issues/206
class GraniteToolParser(ToolParser):
"""Workaround for 'tool_calls' with granite models."""
def get_tool_calls(
self, output_message: AgentCompletionMessage
) -> list[AgentToolCall]:
"""
Return the `tool_calls` list from a CompletionMessage, or an empty list if none are present.
Parameters:
output_message (Optional[AgentCompletionMessage]): Completion
message potentially containing `tool_calls`.
Returns:
list[AgentToolCall]: The list of tool call entries
extracted from `output_message`, or an empty list.
"""
if output_message and output_message.tool_calls:
return output_message.tool_calls
return []
@staticmethod
def get_parser(model_id: str) -> Optional[ToolParser]:
"""
Return a GraniteToolParser when the model identifier denotes a Granite model.
Returns None otherwise.
Parameters:
model_id (str): Model identifier string checked case-insensitively.
If it starts with "granite", a GraniteToolParser instance is
returned.
Returns:
Optional[ToolParser]: GraniteToolParser for Granite models, or None
if `model_id` is falsy or does not start with "granite".
"""
if model_id and model_id.lower().startswith("granite"):
return GraniteToolParser()
return None
class ShieldModerationResult(BaseModel):
"""Result of shield moderation check."""
blocked: bool
message: Optional[str] = None
shield_model: Optional[str] = None
class ResponsesApiParams(BaseModel):
"""Parameters for a Llama Stack Responses API request."""
input: str = Field(description="The input text with attachments appended")
model: str = Field(description='The full model ID in format "provider/model"')
instructions: Optional[str] = Field(
default=None, description="The resolved system prompt"
)
tools: Optional[list[dict[str, Any]]] = Field(
default=None, description="Prepared tool groups for Responses API"
)
conversation: str = Field(description="The conversation ID in llama-stack format")
stream: bool = Field(description="Whether to stream the response")
store: bool = Field(description="Whether to store the response")
class ToolCallSummary(BaseModel):
"""Model representing a tool call made during response generation (for tool_calls list)."""
id: str = Field(description="ID of the tool call")
name: str = Field(description="Name of the tool called")
args: dict[str, Any] = Field(
default_factory=dict, description="Arguments passed to the tool"
)
type: str = Field("tool_call", description="Type indicator for tool call")
class ToolResultSummary(BaseModel):
"""Model representing a result from a tool call (for tool_results list)."""
id: str = Field(
description="ID of the tool call/result, matches the corresponding tool call 'id'"
)
status: str = Field(
..., description="Status of the tool execution (e.g., 'success')"
)
content: str = Field(..., description="Content/result returned from the tool")
type: str = Field("tool_result", description="Type indicator for tool result")
round: int = Field(..., description="Round number or step of tool execution")
class RAGChunk(BaseModel):
"""Model representing a RAG chunk used in the response."""
content: str = Field(description="The content of the chunk")
source: Optional[str] = Field(default=None, description="Source document or URL")
score: Optional[float] = Field(default=None, description="Relevance score")
class ReferencedDocument(BaseModel):
"""Model representing a document referenced in generating a response.
Attributes:
doc_url: Url to the referenced doc.
doc_title: Title of the referenced doc.
"""
doc_url: Optional[AnyUrl] = Field(
None, description="URL of the referenced document"
)
doc_title: Optional[str] = Field(
None, description="Title of the referenced document"
)
class TurnSummary(BaseModel):
"""Summary of a turn in llama stack."""
llm_response: str = ""
tool_calls: list[ToolCallSummary] = Field(default_factory=list)
tool_results: list[ToolResultSummary] = Field(default_factory=list)
rag_chunks: list[RAGChunk] = Field(default_factory=list)
referenced_documents: list[ReferencedDocument] = Field(default_factory=list)
token_usage: TokenCounter = Field(default_factory=TokenCounter)