-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy pathbase.py
More file actions
190 lines (149 loc) · 7.13 KB
/
base.py
File metadata and controls
190 lines (149 loc) · 7.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
"""Base classes for MCPServer prompts."""
from __future__ import annotations
import inspect
from collections.abc import Awaitable, Callable, Sequence
from typing import TYPE_CHECKING, Any, Literal
import pydantic_core
from pydantic import BaseModel, Field, TypeAdapter, validate_call
from mcp.server.mcpserver.exceptions import PromptError
from mcp.server.mcpserver.utilities.context_injection import find_context_parameter, inject_context
from mcp.server.mcpserver.utilities.func_metadata import func_metadata
from mcp.shared.exceptions import MCPError
from mcp.types import ContentBlock, Icon, TextContent
if TYPE_CHECKING:
from mcp.server.context import LifespanContextT, RequestT
from mcp.server.mcpserver.context import Context
class Message(BaseModel):
    """Base class for all prompt messages.

    A message carries a ``role`` (``"user"`` or ``"assistant"``) and exactly
    one content block. For convenience, a plain ``str`` may be passed as
    ``content``; it is wrapped in a ``TextContent`` block before validation.
    """

    role: Literal["user", "assistant"]
    content: ContentBlock

    def __init__(self, content: str | ContentBlock, **kwargs: Any):
        # Promote bare strings to a text content block; pass anything else through.
        block = TextContent(type="text", text=content) if isinstance(content, str) else content
        super().__init__(content=block, **kwargs)
class UserMessage(Message):
    """A prompt message authored by the user; ``role`` defaults to ``"user"``."""

    role: Literal["user", "assistant"] = "user"

    def __init__(self, content: str | ContentBlock, **kwargs: Any):
        # String-to-TextContent promotion is handled by Message.__init__.
        super().__init__(content, **kwargs)
class AssistantMessage(Message):
    """A prompt message authored by the assistant; ``role`` defaults to ``"assistant"``."""

    role: Literal["user", "assistant"] = "assistant"

    def __init__(self, content: str | ContentBlock, **kwargs: Any):
        # String-to-TextContent promotion is handled by Message.__init__.
        super().__init__(content, **kwargs)
# Synchronous shapes a prompt function may return: a single str, Message,
# or dict, or a sequence of any of those.
SyncPromptResult = str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]

# Prompt functions may also be async, returning an awaitable of the above.
PromptResult = SyncPromptResult | Awaitable[SyncPromptResult]

# Validates a dict returned by a prompt function into either a UserMessage
# or an AssistantMessage.
message_validator = TypeAdapter[UserMessage | AssistantMessage](UserMessage | AssistantMessage)
class PromptArgument(BaseModel):
    """Describes one named argument that a prompt accepts."""

    name: str = Field(description="Name of the argument")
    description: str | None = Field(default=None, description="Description of what the argument does")
    required: bool = Field(default=False, description="Whether the argument is required")
class Prompt(BaseModel):
    """A prompt template that can be rendered with parameters.

    Wraps a user-supplied function. The function's parameters are exposed as
    prompt ``arguments``, and its return value is normalized into a list of
    :class:`Message` objects by :meth:`render`.
    """

    name: str = Field(description="Name of the prompt")
    title: str | None = Field(None, description="Human-readable title of the prompt")
    description: str | None = Field(None, description="Description of what the prompt does")
    arguments: list[PromptArgument] | None = Field(None, description="Arguments that can be passed to the prompt")
    # The wrapped prompt function; excluded from serialization.
    fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)
    icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this prompt")
    # Parameter name of ``fn`` that receives the request context, if any; excluded from serialization.
    context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context", exclude=True)

    @classmethod
    def from_function(
        cls,
        fn: Callable[..., PromptResult | Awaitable[PromptResult]],
        name: str | None = None,
        title: str | None = None,
        description: str | None = None,
        icons: list[Icon] | None = None,
        context_kwarg: str | None = None,
    ) -> Prompt:
        """Create a Prompt from a function.

        The function can return:
        - A string (converted to a message)
        - A Message object
        - A dict (converted to a message)
        - A sequence of any of the above

        Args:
            fn: The prompt function (sync or async).
            name: Prompt name; defaults to ``fn.__name__``. Required for lambdas.
            title: Optional human-readable title.
            description: Optional description; falls back to ``fn.__doc__``, then ``""``.
            icons: Optional list of icons for the prompt.
            context_kwarg: Name of the parameter that receives the request
                context. When omitted, it is looked up with ``find_context_parameter``.

        Returns:
            A new Prompt whose ``fn`` is ``fn`` wrapped with pydantic ``validate_call``.

        Raises:
            ValueError: If ``fn`` is a lambda and no ``name`` was given.
        """
        func_name = name or fn.__name__

        if func_name == "<lambda>":  # pragma: no cover
            raise ValueError("You must provide a name for lambda functions")

        # Find context parameter if it exists
        if context_kwarg is None:  # pragma: no branch
            context_kwarg = find_context_parameter(fn)

        # Get schema from func_metadata, excluding the context parameter so it
        # is not advertised as a prompt argument.
        func_arg_metadata = func_metadata(
            fn,
            skip_names=[context_kwarg] if context_kwarg is not None else [],
        )
        parameters = func_arg_metadata.arg_model.model_json_schema()

        # Convert JSON-schema properties to PromptArguments.
        required_names = set(parameters.get("required", []))
        arguments: list[PromptArgument] = [
            PromptArgument(
                name=param_name,
                description=param.get("description"),
                required=param_name in required_names,
            )
            for param_name, param in parameters.get("properties", {}).items()
        ]

        # Read the docstring from the undecorated function so the fallback
        # description does not depend on the wrapper copying ``__doc__``.
        original_doc = fn.__doc__

        # ensure the arguments are properly cast
        fn = validate_call(fn)

        return cls(
            name=func_name,
            title=title,
            description=description or original_doc or "",
            arguments=arguments,
            fn=fn,
            icons=icons,
            context_kwarg=context_kwarg,
        )

    async def render(
        self,
        arguments: dict[str, Any] | None,
        context: Context[LifespanContextT, RequestT],
    ) -> list[Message]:
        """Render the prompt with arguments.

        Args:
            arguments: Keyword arguments for the prompt function, or ``None``.
            context: Passed to ``inject_context`` together with ``self.context_kwarg``.

        Returns:
            The function's result normalized to a list of messages: a non-list or
            non-tuple result is treated as a single item. Strings become
            ``UserMessage``s, dicts are validated into user/assistant messages,
            ``Message`` instances pass through, and anything else is JSON-encoded
            into a user message.

        Raises:
            PromptError: If required arguments are missing, or if the prompt
                function raises ``PromptError`` itself.
            MCPError: Re-raised unchanged if raised while rendering.
            ValueError: If calling the function or converting its result fails
                for any other reason; the original exception is chained as
                ``__cause__``.
        """
        # Validate required arguments
        if self.arguments:
            required = {arg.name for arg in self.arguments if arg.required}
            provided = set(arguments or {})
            missing = required - provided
            if missing:
                raise PromptError(f"Missing required arguments: {missing}")

        try:
            # Add context to arguments if needed
            call_args = inject_context(self.fn, arguments or {}, context, self.context_kwarg)

            # Call the function. Async functions return a coroutine, but the
            # declared PromptResult type admits any Awaitable, so await those too.
            result = self.fn(**call_args)
            if inspect.isawaitable(result):
                result = await result

            # Treat a single returned item as a one-element sequence.
            if not isinstance(result, list | tuple):
                result = [result]

            # Convert result to messages
            messages: list[Message] = []
            for msg in result:  # type: ignore[reportUnknownVariableType]
                try:
                    if isinstance(msg, Message):
                        messages.append(msg)
                    elif isinstance(msg, dict):
                        messages.append(message_validator.validate_python(msg))
                    elif isinstance(msg, str):
                        content = TextContent(type="text", text=msg)
                        messages.append(UserMessage(content=content))
                    else:  # pragma: no cover
                        content = pydantic_core.to_json(msg, fallback=str, indent=2).decode()
                        messages.append(Message(role="user", content=content))
                except Exception as e:  # pragma: no cover
                    raise ValueError(f"Could not convert prompt result to message: {msg}") from e

            return messages
        except (PromptError, MCPError):  # pragma: no cover
            raise
        except Exception as e:  # pragma: no cover
            raise ValueError(f"Error rendering prompt {self.name}: {e}") from e