Skip to content

Commit 8bd3fcb

Browse files
authored
Implement SEP-1577: Sampling With Tools (#628)
* feat: implement SEP-1577 sampling with tools support * feat: add TryFrom<Content> for backward-compatible migration
1 parent be23334 commit 8bd3fcb

12 files changed

Lines changed: 1850 additions & 141 deletions

crates/rmcp/src/model.rs

Lines changed: 286 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,6 +1209,152 @@ pub enum Role {
12091209
Assistant,
12101210
}
12111211

1212+
/// Tool selection mode (SEP-1577).
1213+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1214+
#[serde(rename_all = "lowercase")]
1215+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
1216+
pub enum ToolChoiceMode {
1217+
/// Model decides whether to use tools
1218+
Auto,
1219+
/// Model must use at least one tool
1220+
Required,
1221+
/// Model must not use tools
1222+
None,
1223+
}
1224+
1225+
impl Default for ToolChoiceMode {
1226+
fn default() -> Self {
1227+
Self::Auto
1228+
}
1229+
}
1230+
1231+
/// Tool choice configuration (SEP-1577).
1232+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
1233+
#[serde(rename_all = "camelCase")]
1234+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
1235+
pub struct ToolChoice {
1236+
#[serde(skip_serializing_if = "Option::is_none")]
1237+
pub mode: Option<ToolChoiceMode>,
1238+
}
1239+
1240+
impl ToolChoice {
1241+
pub fn auto() -> Self {
1242+
Self {
1243+
mode: Some(ToolChoiceMode::Auto),
1244+
}
1245+
}
1246+
1247+
pub fn required() -> Self {
1248+
Self {
1249+
mode: Some(ToolChoiceMode::Required),
1250+
}
1251+
}
1252+
1253+
pub fn none() -> Self {
1254+
Self {
1255+
mode: Some(ToolChoiceMode::None),
1256+
}
1257+
}
1258+
}
1259+
1260+
/// Single or array content wrapper (SEP-1577).
1261+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1262+
#[serde(untagged)]
1263+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
1264+
pub enum SamplingContent<T> {
1265+
Single(T),
1266+
Multiple(Vec<T>),
1267+
}
1268+
1269+
impl<T> SamplingContent<T> {
1270+
/// Convert to a Vec regardless of whether it's single or multiple
1271+
pub fn into_vec(self) -> Vec<T> {
1272+
match self {
1273+
SamplingContent::Single(item) => vec![item],
1274+
SamplingContent::Multiple(items) => items,
1275+
}
1276+
}
1277+
1278+
/// Check if the content is empty
1279+
pub fn is_empty(&self) -> bool {
1280+
match self {
1281+
SamplingContent::Single(_) => false,
1282+
SamplingContent::Multiple(items) => items.is_empty(),
1283+
}
1284+
}
1285+
1286+
/// Get the number of content items
1287+
pub fn len(&self) -> usize {
1288+
match self {
1289+
SamplingContent::Single(_) => 1,
1290+
SamplingContent::Multiple(items) => items.len(),
1291+
}
1292+
}
1293+
}
1294+
1295+
impl<T> Default for SamplingContent<T> {
1296+
fn default() -> Self {
1297+
SamplingContent::Multiple(Vec::new())
1298+
}
1299+
}
1300+
1301+
impl<T> SamplingContent<T> {
1302+
/// Get the first item if present
1303+
pub fn first(&self) -> Option<&T> {
1304+
match self {
1305+
SamplingContent::Single(item) => Some(item),
1306+
SamplingContent::Multiple(items) => items.first(),
1307+
}
1308+
}
1309+
1310+
/// Iterate over all content items
1311+
pub fn iter(&self) -> impl Iterator<Item = &T> {
1312+
let items: Vec<&T> = match self {
1313+
SamplingContent::Single(item) => vec![item],
1314+
SamplingContent::Multiple(items) => items.iter().collect(),
1315+
};
1316+
items.into_iter()
1317+
}
1318+
}
1319+
1320+
impl SamplingMessageContent {
1321+
/// Get the text content if this is a Text variant
1322+
pub fn as_text(&self) -> Option<&RawTextContent> {
1323+
match self {
1324+
SamplingMessageContent::Text(text) => Some(text),
1325+
_ => None,
1326+
}
1327+
}
1328+
1329+
/// Get the tool use content if this is a ToolUse variant
1330+
pub fn as_tool_use(&self) -> Option<&ToolUseContent> {
1331+
match self {
1332+
SamplingMessageContent::ToolUse(tool_use) => Some(tool_use),
1333+
_ => None,
1334+
}
1335+
}
1336+
1337+
/// Get the tool result content if this is a ToolResult variant
1338+
pub fn as_tool_result(&self) -> Option<&ToolResultContent> {
1339+
match self {
1340+
SamplingMessageContent::ToolResult(tool_result) => Some(tool_result),
1341+
_ => None,
1342+
}
1343+
}
1344+
}
1345+
1346+
impl<T> From<T> for SamplingContent<T> {
1347+
fn from(item: T) -> Self {
1348+
SamplingContent::Single(item)
1349+
}
1350+
}
1351+
1352+
impl<T> From<Vec<T>> for SamplingContent<T> {
1353+
fn from(items: Vec<T>) -> Self {
1354+
SamplingContent::Multiple(items)
1355+
}
1356+
}
1357+
12121358
/// A message in a sampling conversation, containing a role and content.
12131359
///
12141360
/// This represents a single message in a conversation flow, used primarily
@@ -1219,8 +1365,135 @@ pub enum Role {
12191365
pub struct SamplingMessage {
12201366
/// The role of the message sender (User or Assistant)
12211367
pub role: Role,
1222-
/// The actual content of the message (text, image, etc.)
1223-
pub content: Content,
1368+
/// The actual content of the message (text, image, audio, tool use, or tool result)
1369+
pub content: SamplingContent<SamplingMessageContent>,
1370+
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
1371+
pub meta: Option<Meta>,
1372+
}
1373+
1374+
/// Content types for sampling messages (SEP-1577).
1375+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1376+
#[serde(tag = "type", rename_all = "snake_case")]
1377+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
1378+
pub enum SamplingMessageContent {
1379+
Text(RawTextContent),
1380+
Image(RawImageContent),
1381+
Audio(RawAudioContent),
1382+
/// Assistant only
1383+
ToolUse(ToolUseContent),
1384+
/// User only
1385+
ToolResult(ToolResultContent),
1386+
}
1387+
1388+
impl SamplingMessageContent {
1389+
/// Create a text content
1390+
pub fn text(text: impl Into<String>) -> Self {
1391+
Self::Text(RawTextContent {
1392+
text: text.into(),
1393+
meta: None,
1394+
})
1395+
}
1396+
1397+
pub fn tool_use(id: impl Into<String>, name: impl Into<String>, input: JsonObject) -> Self {
1398+
Self::ToolUse(ToolUseContent::new(id, name, input))
1399+
}
1400+
1401+
pub fn tool_result(tool_use_id: impl Into<String>, content: Vec<Content>) -> Self {
1402+
Self::ToolResult(ToolResultContent::new(tool_use_id, content))
1403+
}
1404+
}
1405+
1406+
impl SamplingMessage {
1407+
pub fn new(role: Role, content: impl Into<SamplingMessageContent>) -> Self {
1408+
Self {
1409+
role,
1410+
content: SamplingContent::Single(content.into()),
1411+
meta: None,
1412+
}
1413+
}
1414+
1415+
pub fn new_multiple(role: Role, contents: Vec<SamplingMessageContent>) -> Self {
1416+
Self {
1417+
role,
1418+
content: SamplingContent::Multiple(contents),
1419+
meta: None,
1420+
}
1421+
}
1422+
1423+
pub fn user_text(text: impl Into<String>) -> Self {
1424+
Self::new(Role::User, SamplingMessageContent::text(text))
1425+
}
1426+
1427+
pub fn assistant_text(text: impl Into<String>) -> Self {
1428+
Self::new(Role::Assistant, SamplingMessageContent::text(text))
1429+
}
1430+
1431+
pub fn user_tool_result(tool_use_id: impl Into<String>, content: Vec<Content>) -> Self {
1432+
Self::new(
1433+
Role::User,
1434+
SamplingMessageContent::tool_result(tool_use_id, content),
1435+
)
1436+
}
1437+
1438+
pub fn assistant_tool_use(
1439+
id: impl Into<String>,
1440+
name: impl Into<String>,
1441+
input: JsonObject,
1442+
) -> Self {
1443+
Self::new(
1444+
Role::Assistant,
1445+
SamplingMessageContent::tool_use(id, name, input),
1446+
)
1447+
}
1448+
}
1449+
1450+
// Conversion from RawTextContent to SamplingMessageContent
1451+
impl From<RawTextContent> for SamplingMessageContent {
1452+
fn from(text: RawTextContent) -> Self {
1453+
SamplingMessageContent::Text(text)
1454+
}
1455+
}
1456+
1457+
// Conversion from String to SamplingMessageContent (as text)
1458+
impl From<String> for SamplingMessageContent {
1459+
fn from(text: String) -> Self {
1460+
SamplingMessageContent::text(text)
1461+
}
1462+
}
1463+
1464+
impl From<&str> for SamplingMessageContent {
1465+
fn from(text: &str) -> Self {
1466+
SamplingMessageContent::text(text)
1467+
}
1468+
}
1469+
1470+
// Backward compatibility: Convert Content to SamplingMessageContent
1471+
// Note: Resource and ResourceLink variants are not supported in sampling messages
1472+
impl TryFrom<Content> for SamplingMessageContent {
1473+
type Error = &'static str;
1474+
1475+
fn try_from(content: Content) -> Result<Self, Self::Error> {
1476+
match content.raw {
1477+
RawContent::Text(text) => Ok(SamplingMessageContent::Text(text)),
1478+
RawContent::Image(image) => Ok(SamplingMessageContent::Image(image)),
1479+
RawContent::Audio(audio) => Ok(SamplingMessageContent::Audio(audio)),
1480+
RawContent::Resource(_) => {
1481+
Err("Resource content is not supported in sampling messages")
1482+
}
1483+
RawContent::ResourceLink(_) => {
1484+
Err("ResourceLink content is not supported in sampling messages")
1485+
}
1486+
}
1487+
}
1488+
}
1489+
1490+
// Backward compatibility: Convert Content to SamplingContent<SamplingMessageContent>
1491+
impl TryFrom<Content> for SamplingContent<SamplingMessageContent> {
1492+
type Error = &'static str;
1493+
1494+
fn try_from(content: Content) -> Result<Self, Self::Error> {
1495+
Ok(SamplingContent::Single(content.try_into()?))
1496+
}
12241497
}
12251498

12261499
/// Specifies how much context should be included in sampling requests.
@@ -1281,6 +1554,12 @@ pub struct CreateMessageRequestParams {
12811554
/// Additional metadata for the request
12821555
#[serde(skip_serializing_if = "Option::is_none")]
12831556
pub metadata: Option<Value>,
1557+
/// Tools available for the model to call (SEP-1577)
1558+
#[serde(skip_serializing_if = "Option::is_none")]
1559+
pub tools: Option<Vec<Tool>>,
1560+
/// Tool selection behavior (SEP-1577)
1561+
#[serde(skip_serializing_if = "Option::is_none")]
1562+
pub tool_choice: Option<ToolChoice>,
12841563
}
12851564

12861565
impl RequestParamsMeta for CreateMessageRequestParams {
@@ -1926,6 +2205,7 @@ pub type CallToolRequestParam = CallToolRequestParams;
19262205
/// Request to call a specific tool
19272206
pub type CallToolRequest = Request<CallToolRequestMethod, CallToolRequestParams>;
19282207

2208+
/// Result of sampling/createMessage (SEP-1577).
19292209
/// The result of a sampling/createMessage request containing the generated response.
19302210
///
19312211
/// This structure contains the generated message along with metadata about
@@ -1948,6 +2228,7 @@ impl CreateMessageResult {
19482228
pub const STOP_REASON_END_TURN: &str = "endTurn";
19492229
pub const STOP_REASON_END_SEQUENCE: &str = "stopSequence";
19502230
pub const STOP_REASON_END_MAX_TOKEN: &str = "maxTokens";
2231+
pub const STOP_REASON_TOOL_USE: &str = "toolUse";
19512232
}
19522233

19532234
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
@@ -2477,7 +2758,9 @@ mod tests {
24772758
..
24782759
}) => {
24792760
assert_eq!(capabilities.roots.unwrap().list_changed, Some(true));
2480-
assert_eq!(capabilities.sampling.unwrap().len(), 0);
2761+
let sampling = capabilities.sampling.unwrap();
2762+
assert_eq!(sampling.tools, None);
2763+
assert_eq!(sampling.context, None);
24812764
assert_eq!(client_info.name, "ExampleClient");
24822765
assert_eq!(client_info.version, "1.0.0");
24832766
}

crates/rmcp/src/model/capabilities.rs

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,19 @@ pub struct ElicitationCapability {
194194
pub schema_validation: Option<bool>,
195195
}
196196

197+
/// Sampling capability with optional sub-capabilities (SEP-1577).
198+
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
199+
#[serde(rename_all = "camelCase")]
200+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
201+
pub struct SamplingCapability {
202+
/// Support for `tools` and `toolChoice` parameters
203+
#[serde(skip_serializing_if = "Option::is_none")]
204+
pub tools: Option<JsonObject>,
205+
/// Support for `includeContext` (soft-deprecated)
206+
#[serde(skip_serializing_if = "Option::is_none")]
207+
pub context: Option<JsonObject>,
208+
}
209+
197210
///
198211
/// # Builder
199212
/// ```rust
@@ -217,8 +230,9 @@ pub struct ClientCapabilities {
217230
pub extensions: Option<ExtensionCapabilities>,
218231
#[serde(skip_serializing_if = "Option::is_none")]
219232
pub roots: Option<RootsCapabilities>,
233+
/// Capability for LLM sampling requests (SEP-1577)
220234
#[serde(skip_serializing_if = "Option::is_none")]
221-
pub sampling: Option<JsonObject>,
235+
pub sampling: Option<SamplingCapability>,
222236
/// Capability to handle elicitation requests from servers for interactive user input
223237
#[serde(skip_serializing_if = "Option::is_none")]
224238
pub elicitation: Option<ElicitationCapability>,
@@ -449,7 +463,7 @@ builder! {
449463
experimental: ExperimentalCapabilities,
450464
extensions: ExtensionCapabilities,
451465
roots: RootsCapabilities,
452-
sampling: JsonObject,
466+
sampling: SamplingCapability,
453467
elicitation: ElicitationCapability,
454468
tasks: TasksCapability,
455469
}
@@ -466,6 +480,26 @@ impl<const E: bool, const EXT: bool, const S: bool, const EL: bool, const TASKS:
466480
}
467481
}
468482

483+
impl<const E: bool, const EXT: bool, const R: bool, const EL: bool, const TASKS: bool>
484+
ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<E, EXT, R, true, EL, TASKS>>
485+
{
486+
/// Enable tool calling in sampling requests
487+
pub fn enable_sampling_tools(mut self) -> Self {
488+
if let Some(c) = self.sampling.as_mut() {
489+
c.tools = Some(JsonObject::default());
490+
}
491+
self
492+
}
493+
494+
/// Enable context inclusion in sampling (soft-deprecated)
495+
pub fn enable_sampling_context(mut self) -> Self {
496+
if let Some(c) = self.sampling.as_mut() {
497+
c.context = Some(JsonObject::default());
498+
}
499+
self
500+
}
501+
}
502+
469503
#[cfg(feature = "elicitation")]
470504
impl<const E: bool, const EXT: bool, const R: bool, const S: bool, const TASKS: bool>
471505
ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<E, EXT, R, S, true, TASKS>>

0 commit comments

Comments
 (0)