Skip to content

Commit 3a31cef

Browse files
QuantGeekDevclaude
andcommitted
feat: add sampling support to tools
Enables tools to request LLM completions via the MCP sampling protocol (sampling/createMessage). The SDK Server instance is injected into tools at startup, allowing tools to call this.samplingRequest() from execute(). Aligned with MCP specification 2025-11-25. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent dde5519 commit 3a31cef

File tree

3 files changed

+199
-13
lines changed

3 files changed

+199
-13
lines changed

src/core/MCPServer.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,7 @@ export class MCPServer {
424424
{ name: this.serverName, version: this.serverVersion },
425425
{ capabilities: this.capabilities }
426426
);
427+
tools.forEach((tool) => tool.injectServer(this.server));
427428
logger.debug(
428429
`SDK Server instance created with capabilities: ${JSON.stringify(this.capabilities)}`
429430
);

src/tools/BaseTool.ts

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,8 @@
11
import { z } from 'zod';
2-
import { Tool as SDKTool } from '@modelcontextprotocol/sdk/types.js';
2+
import { CreateMessageRequest, CreateMessageResult, Tool as SDKTool } from '@modelcontextprotocol/sdk/types.js';
33
import { ImageContent } from '../transports/utils/image-handler.js';
4-
5-
// Type to check if a Zod type has a description
6-
type HasDescription<T> = T extends { _def: { description: string } } ? T : never;
7-
8-
// Type to ensure all properties in a Zod object have descriptions
9-
type AllFieldsHaveDescriptions<T extends z.ZodRawShape> = {
10-
[K in keyof T]: HasDescription<T[K]>;
11-
};
12-
13-
// Strict Zod object type that requires all fields to have descriptions
14-
type StrictZodObject<T extends z.ZodRawShape> = z.ZodObject<AllFieldsHaveDescriptions<T>>;
4+
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
5+
import { RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.js';
156

167
export type ToolInputSchema<T> = {
178
[K in keyof T]: {
@@ -64,6 +55,7 @@ export interface ToolProtocol extends SDKTool {
6455
toolCall(request: {
6556
params: { name: string; arguments?: Record<string, unknown> };
6657
}): Promise<ToolResponse>;
58+
injectServer(server: Server): void;
6759
}
6860

6961
/**
@@ -100,6 +92,46 @@ export abstract class MCPTool<TInput extends Record<string, any> = any, TSchema
10092
protected useStringify: boolean = true;
10193
[key: string]: unknown;
10294

95+
private server: Server | undefined;
96+
97+
/**
98+
* Injects the server into this tool to allow sampling requests.
99+
* Automatically called by the MCP server when registering the tool.
100+
* Subsequent calls are silently ignored.
101+
*/
102+
public injectServer(server: Server): void {
103+
if (this.server) {
104+
return;
105+
}
106+
this.server = server;
107+
}
108+
109+
/**
110+
* Submit a sampling request to the client via the MCP sampling protocol.
111+
* Can only be called from within a tool's execute() method after the server
112+
* has been injected.
113+
*
114+
* @example
115+
* ```typescript
116+
* const result = await this.samplingRequest({
117+
* messages: [{ role: "user", content: { type: "text", text: "Hello!" } }],
118+
* maxTokens: 100
119+
* });
120+
* ```
121+
*/
122+
protected async samplingRequest(
123+
request: CreateMessageRequest['params'],
124+
options?: RequestOptions,
125+
): Promise<CreateMessageResult> {
126+
if (!this.server) {
127+
throw new Error(
128+
`Cannot make sampling request: server not available in tool '${this.name}'. ` +
129+
`Sampling is only available during tool execution within an MCPServer.`,
130+
);
131+
}
132+
return this.server.createMessage(request, options);
133+
}
134+
103135
/**
104136
* Validates the tool schema. This is called automatically when the tool is registered
105137
* with an MCP server, but can also be called manually for testing.

tests/tools/BaseTool.test.ts

Lines changed: 154 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
1-
import { describe, it, expect, beforeEach } from '@jest/globals';
1+
import { describe, it, expect, beforeEach, jest } from '@jest/globals';
22
import { z } from 'zod';
33
import { MCPTool } from '../../src/tools/BaseTool.js';
4+
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
5+
import { CreateMessageRequest, CreateMessageResult } from '@modelcontextprotocol/sdk/types.js';
6+
import {RequestOptions} from '@modelcontextprotocol/sdk/shared/protocol.js';
7+
8+
// Mock the Server class
9+
jest.mock('@modelcontextprotocol/sdk/server/index.js', () => ({
10+
Server: jest.fn().mockImplementation(() => ({
11+
createMessage: jest.fn(),
12+
})),
13+
}));
414

515
describe('BaseTool', () => {
616
describe('Legacy Pattern (Separate Schema Definition)', () => {
@@ -488,4 +498,147 @@ describe('BaseTool', () => {
488498
console.log(JSON.stringify(definition, null, 2));
489499
});
490500
});
501+
502+
describe('Sampling', () => {
503+
// Expose the protected samplingRequest for direct testing
504+
class SamplingTestTool extends MCPTool {
505+
name = 'sampling_tool';
506+
description = 'A tool that uses sampling';
507+
schema = z.object({
508+
prompt: z.string().describe('The prompt to sample'),
509+
});
510+
511+
protected async execute(input: { prompt: string }): Promise<unknown> {
512+
const result = await this.samplingRequest({
513+
messages: [
514+
{
515+
role: 'user',
516+
content: { type: 'text', text: input.prompt },
517+
},
518+
],
519+
maxTokens: 100,
520+
});
521+
return { sampledText: result.content.text };
522+
}
523+
524+
// Expose protected method for testing
525+
public testSamplingRequest(
526+
request: CreateMessageRequest['params'],
527+
options?: RequestOptions,
528+
) {
529+
return this.samplingRequest(request, options);
530+
}
531+
}
532+
533+
let tool: SamplingTestTool;
534+
let mockServer: jest.Mocked<Server>;
535+
536+
beforeEach(() => {
537+
tool = new SamplingTestTool();
538+
mockServer = new Server(
539+
{ name: 'test-server', version: '1.0.0' },
540+
{ capabilities: {} },
541+
) as jest.Mocked<Server>;
542+
mockServer.createMessage = jest.fn();
543+
});
544+
545+
it('should inject server without throwing', () => {
546+
expect(() => tool.injectServer(mockServer)).not.toThrow();
547+
});
548+
549+
it('should silently handle double injection', () => {
550+
tool.injectServer(mockServer);
551+
expect(() => tool.injectServer(mockServer)).not.toThrow();
552+
});
553+
554+
it('should throw when samplingRequest called without server', async () => {
555+
await expect(
556+
tool.testSamplingRequest({
557+
messages: [{ role: 'user', content: { type: 'text', text: 'test' } }],
558+
maxTokens: 100,
559+
}),
560+
).rejects.toThrow(
561+
"Cannot make sampling request: server not available in tool 'sampling_tool'.",
562+
);
563+
});
564+
565+
it('should call server.createMessage with correct params', async () => {
566+
const mockResult: CreateMessageResult = {
567+
model: 'test-model',
568+
role: 'assistant',
569+
content: { type: 'text', text: 'Sampled response' },
570+
};
571+
mockServer.createMessage.mockResolvedValue(mockResult);
572+
tool.injectServer(mockServer);
573+
574+
const request: CreateMessageRequest['params'] = {
575+
messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }],
576+
maxTokens: 100,
577+
temperature: 0.7,
578+
systemPrompt: 'Be helpful',
579+
};
580+
581+
const result = await tool.testSamplingRequest(request);
582+
583+
expect(mockServer.createMessage).toHaveBeenCalledWith(request, undefined);
584+
expect(result).toEqual(mockResult);
585+
});
586+
587+
it('should propagate createMessage errors', async () => {
588+
tool.injectServer(mockServer);
589+
mockServer.createMessage.mockRejectedValue(new Error('Sampling failed'));
590+
591+
await expect(
592+
tool.testSamplingRequest({
593+
messages: [{ role: 'user', content: { type: 'text', text: 'test' } }],
594+
maxTokens: 100,
595+
}),
596+
).rejects.toThrow('Sampling failed');
597+
});
598+
599+
it('should pass request options to createMessage', async () => {
600+
const mockResult: CreateMessageResult = {
601+
model: 'claude-3-sonnet',
602+
role: 'assistant',
603+
content: { type: 'text', text: 'Complex response' },
604+
stopReason: 'endTurn',
605+
};
606+
mockServer.createMessage.mockResolvedValue(mockResult);
607+
tool.injectServer(mockServer);
608+
609+
const request: CreateMessageRequest['params'] = {
610+
messages: [
611+
{ role: 'user', content: { type: 'text', text: 'First message' } },
612+
{ role: 'assistant', content: { type: 'text', text: 'Assistant response' } },
613+
{ role: 'user', content: { type: 'text', text: 'Follow up' } },
614+
],
615+
maxTokens: 500,
616+
temperature: 0.8,
617+
systemPrompt: 'You are a helpful assistant',
618+
modelPreferences: {
619+
hints: [{ name: 'claude-3' }],
620+
costPriority: 0.3,
621+
speedPriority: 0.7,
622+
intelligencePriority: 0.9,
623+
},
624+
stopSequences: ['END', 'STOP'],
625+
metadata: { taskType: 'analysis' },
626+
};
627+
628+
const options: RequestOptions = {
629+
timeout: 5000,
630+
maxTotalTimeout: 10000,
631+
signal: new AbortController().signal,
632+
resetTimeoutOnProgress: true,
633+
onprogress: (progress) => {
634+
console.log('Progress:', progress);
635+
},
636+
};
637+
638+
const result = await tool.testSamplingRequest(request, options);
639+
640+
expect(mockServer.createMessage).toHaveBeenCalledWith(request, options);
641+
expect(result).toEqual(mockResult);
642+
});
643+
});
491644
});

0 commit comments

Comments
 (0)