|
1 | | -import { describe, it, expect, beforeEach } from '@jest/globals'; |
| 1 | +import { describe, it, expect, beforeEach, jest } from '@jest/globals'; |
2 | 2 | import { z } from 'zod'; |
3 | 3 | import { MCPTool } from '../../src/tools/BaseTool.js'; |
| 4 | +import { Server } from '@modelcontextprotocol/sdk/server/index.js'; |
| 5 | +import { CreateMessageRequest, CreateMessageResult } from '@modelcontextprotocol/sdk/types.js'; |
| 6 | +import {RequestOptions} from '@modelcontextprotocol/sdk/shared/protocol.js'; |
| 7 | + |
| 8 | +// Mock the Server class |
| 9 | +jest.mock('@modelcontextprotocol/sdk/server/index.js', () => ({ |
| 10 | + Server: jest.fn().mockImplementation(() => ({ |
| 11 | + createMessage: jest.fn(), |
| 12 | + })), |
| 13 | +})); |
4 | 14 |
|
5 | 15 | describe('BaseTool', () => { |
6 | 16 | describe('Legacy Pattern (Separate Schema Definition)', () => { |
@@ -488,4 +498,147 @@ describe('BaseTool', () => { |
488 | 498 | console.log(JSON.stringify(definition, null, 2)); |
489 | 499 | }); |
490 | 500 | }); |
| 501 | + |
| 502 | + describe('Sampling', () => { |
| 503 | + // Expose the protected samplingRequest for direct testing |
| 504 | + class SamplingTestTool extends MCPTool { |
| 505 | + name = 'sampling_tool'; |
| 506 | + description = 'A tool that uses sampling'; |
| 507 | + schema = z.object({ |
| 508 | + prompt: z.string().describe('The prompt to sample'), |
| 509 | + }); |
| 510 | + |
| 511 | + protected async execute(input: { prompt: string }): Promise<unknown> { |
| 512 | + const result = await this.samplingRequest({ |
| 513 | + messages: [ |
| 514 | + { |
| 515 | + role: 'user', |
| 516 | + content: { type: 'text', text: input.prompt }, |
| 517 | + }, |
| 518 | + ], |
| 519 | + maxTokens: 100, |
| 520 | + }); |
| 521 | + return { sampledText: result.content.text }; |
| 522 | + } |
| 523 | + |
| 524 | + // Expose protected method for testing |
| 525 | + public testSamplingRequest( |
| 526 | + request: CreateMessageRequest['params'], |
| 527 | + options?: RequestOptions, |
| 528 | + ) { |
| 529 | + return this.samplingRequest(request, options); |
| 530 | + } |
| 531 | + } |
| 532 | + |
| 533 | + let tool: SamplingTestTool; |
| 534 | + let mockServer: jest.Mocked<Server>; |
| 535 | + |
| 536 | + beforeEach(() => { |
| 537 | + tool = new SamplingTestTool(); |
| 538 | + mockServer = new Server( |
| 539 | + { name: 'test-server', version: '1.0.0' }, |
| 540 | + { capabilities: {} }, |
| 541 | + ) as jest.Mocked<Server>; |
| 542 | + mockServer.createMessage = jest.fn(); |
| 543 | + }); |
| 544 | + |
| 545 | + it('should inject server without throwing', () => { |
| 546 | + expect(() => tool.injectServer(mockServer)).not.toThrow(); |
| 547 | + }); |
| 548 | + |
| 549 | + it('should silently handle double injection', () => { |
| 550 | + tool.injectServer(mockServer); |
| 551 | + expect(() => tool.injectServer(mockServer)).not.toThrow(); |
| 552 | + }); |
| 553 | + |
| 554 | + it('should throw when samplingRequest called without server', async () => { |
| 555 | + await expect( |
| 556 | + tool.testSamplingRequest({ |
| 557 | + messages: [{ role: 'user', content: { type: 'text', text: 'test' } }], |
| 558 | + maxTokens: 100, |
| 559 | + }), |
| 560 | + ).rejects.toThrow( |
| 561 | + "Cannot make sampling request: server not available in tool 'sampling_tool'.", |
| 562 | + ); |
| 563 | + }); |
| 564 | + |
| 565 | + it('should call server.createMessage with correct params', async () => { |
| 566 | + const mockResult: CreateMessageResult = { |
| 567 | + model: 'test-model', |
| 568 | + role: 'assistant', |
| 569 | + content: { type: 'text', text: 'Sampled response' }, |
| 570 | + }; |
| 571 | + mockServer.createMessage.mockResolvedValue(mockResult); |
| 572 | + tool.injectServer(mockServer); |
| 573 | + |
| 574 | + const request: CreateMessageRequest['params'] = { |
| 575 | + messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], |
| 576 | + maxTokens: 100, |
| 577 | + temperature: 0.7, |
| 578 | + systemPrompt: 'Be helpful', |
| 579 | + }; |
| 580 | + |
| 581 | + const result = await tool.testSamplingRequest(request); |
| 582 | + |
| 583 | + expect(mockServer.createMessage).toHaveBeenCalledWith(request, undefined); |
| 584 | + expect(result).toEqual(mockResult); |
| 585 | + }); |
| 586 | + |
| 587 | + it('should propagate createMessage errors', async () => { |
| 588 | + tool.injectServer(mockServer); |
| 589 | + mockServer.createMessage.mockRejectedValue(new Error('Sampling failed')); |
| 590 | + |
| 591 | + await expect( |
| 592 | + tool.testSamplingRequest({ |
| 593 | + messages: [{ role: 'user', content: { type: 'text', text: 'test' } }], |
| 594 | + maxTokens: 100, |
| 595 | + }), |
| 596 | + ).rejects.toThrow('Sampling failed'); |
| 597 | + }); |
| 598 | + |
| 599 | + it('should pass request options to createMessage', async () => { |
| 600 | + const mockResult: CreateMessageResult = { |
| 601 | + model: 'claude-3-sonnet', |
| 602 | + role: 'assistant', |
| 603 | + content: { type: 'text', text: 'Complex response' }, |
| 604 | + stopReason: 'endTurn', |
| 605 | + }; |
| 606 | + mockServer.createMessage.mockResolvedValue(mockResult); |
| 607 | + tool.injectServer(mockServer); |
| 608 | + |
| 609 | + const request: CreateMessageRequest['params'] = { |
| 610 | + messages: [ |
| 611 | + { role: 'user', content: { type: 'text', text: 'First message' } }, |
| 612 | + { role: 'assistant', content: { type: 'text', text: 'Assistant response' } }, |
| 613 | + { role: 'user', content: { type: 'text', text: 'Follow up' } }, |
| 614 | + ], |
| 615 | + maxTokens: 500, |
| 616 | + temperature: 0.8, |
| 617 | + systemPrompt: 'You are a helpful assistant', |
| 618 | + modelPreferences: { |
| 619 | + hints: [{ name: 'claude-3' }], |
| 620 | + costPriority: 0.3, |
| 621 | + speedPriority: 0.7, |
| 622 | + intelligencePriority: 0.9, |
| 623 | + }, |
| 624 | + stopSequences: ['END', 'STOP'], |
| 625 | + metadata: { taskType: 'analysis' }, |
| 626 | + }; |
| 627 | + |
| 628 | + const options: RequestOptions = { |
| 629 | + timeout: 5000, |
| 630 | + maxTotalTimeout: 10000, |
| 631 | + signal: new AbortController().signal, |
| 632 | + resetTimeoutOnProgress: true, |
| 633 | + onprogress: (progress) => { |
| 634 | + console.log('Progress:', progress); |
| 635 | + }, |
| 636 | + }; |
| 637 | + |
| 638 | + const result = await tool.testSamplingRequest(request, options); |
| 639 | + |
| 640 | + expect(mockServer.createMessage).toHaveBeenCalledWith(request, options); |
| 641 | + expect(result).toEqual(mockResult); |
| 642 | + }); |
| 643 | + }); |
491 | 644 | }); |
0 commit comments