diff --git a/README.md b/README.md index a6823236..334f0211 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ It implements the Model Context Protocol specification, handling model context r - Supports resource registration and retrieval - Supports stdio & Streamable HTTP (including SSE) transports - Supports notifications for list changes (tools, prompts, resources) +- Supports sampling (server-to-client LLM completion requests) ### Supported Methods @@ -51,6 +52,7 @@ It implements the Model Context Protocol specification, handling model context r - `resources/read` - Retrieves a specific resource by name - `resources/templates/list` - Lists all registered resource templates and their schemas - `completion/complete` - Returns autocompletion suggestions for prompt arguments and resource URIs +- `sampling/createMessage` - Requests LLM completion from the client (server-to-client) ### Custom Methods @@ -103,6 +105,163 @@ end - Raises `MCP::Server::MethodAlreadyDefinedError` if trying to override an existing method - Supports the same exception reporting and instrumentation as standard methods +### Sampling + +The Model Context Protocol allows servers to request LLM completions from clients through the `sampling/createMessage` method. +This enables servers to leverage the client's LLM capabilities without needing direct access to AI models. + +**Key Concepts:** + +- **Server-to-Client Request**: Unlike typical MCP methods (client→server), sampling is initiated by the server +- **Client Capability**: Clients must declare `sampling` capability during initialization +- **Tool Support**: When using tools in sampling requests, clients must declare `sampling.tools` capability +- **Human-in-the-Loop**: Clients can implement user approval before forwarding requests to LLMs + +**Usage Example (Stdio transport):** + +`Server#create_sampling_message` is for single-client transports (e.g., `StdioTransport`). 
+For multi-client transports (e.g., `StreamableHTTPTransport`), use `server_context.create_sampling_message` inside tools instead,
+which routes the request to the correct client session.
+
+```ruby
+server = MCP::Server.new(name: "my_server")
+transport = MCP::Server::Transports::StdioTransport.new(server)
+server.transport = transport
+```
+
+The client must declare the `sampling` capability during initialization;
+this happens automatically when the client connects.
+
+```ruby
+result = server.create_sampling_message(
+  messages: [
+    { role: "user", content: { type: "text", text: "What is the capital of France?" } }
+  ],
+  max_tokens: 100,
+  system_prompt: "You are a helpful assistant.",
+  temperature: 0.7
+)
+```
+
+The result contains the LLM response:
+
+```ruby
+{
+  role: "assistant",
+  content: { type: "text", text: "The capital of France is Paris." },
+  model: "claude-3-sonnet-20240229",
+  stopReason: "endTurn"
+}
+```
+
+**Parameters:**
+
+Required:
+
+- `messages:` (Array) - Array of message objects with `role` and `content`
+- `max_tokens:` (Integer) - Maximum tokens in the response
+
+Optional:
+
+- `system_prompt:` (String) - System prompt for the LLM
+- `model_preferences:` (Hash) - Model selection preferences (e.g., `{ intelligencePriority: 0.8 }`)
+- `include_context:` (String) - Context inclusion: `"none"`, `"thisServer"`, or `"allServers"` (soft-deprecated)
+- `temperature:` (Float) - Sampling temperature
+- `stop_sequences:` (Array) - Sequences that stop generation
+- `metadata:` (Hash) - Additional metadata
+- `tools:` (Array) - Tools available to the LLM (requires `sampling.tools` capability)
+- `tool_choice:` (Hash) - Tool selection mode (e.g., `{ mode: "auto" }`)
+
+**Using Sampling in Tools (works with both Stdio and HTTP transports):**
+
+Tools that accept a `server_context:` parameter can call `create_sampling_message` on it.
+The request is automatically routed to the correct client session. 
+Set `server.server_context = server` so that `server_context.create_sampling_message` delegates to the server: + +```ruby +class SummarizeTool < MCP::Tool + description "Summarize text using LLM" + input_schema( + properties: { + text: { type: "string" } + }, + required: ["text"] + ) + + def self.call(text:, server_context:) + result = server_context.create_sampling_message( + messages: [ + { role: "user", content: { type: "text", text: "Please summarize: #{text}" } } + ], + max_tokens: 500 + ) + + MCP::Tool::Response.new([{ + type: "text", + text: result[:content][:text] + }]) + end +end + +server = MCP::Server.new(name: "my_server", tools: [SummarizeTool]) +server.server_context = server +``` + +**Tool Use in Sampling:** + +When tools are provided in a sampling request, the LLM can call them during generation. +The server must handle tool calls and continue the conversation with tool results: + +```ruby +result = server.create_sampling_message( + messages: [ + { role: "user", content: { type: "text", text: "What's the weather in Paris?" } } + ], + max_tokens: 1000, + tools: [ + { + name: "get_weather", + description: "Get weather for a city", + inputSchema: { + type: "object", + properties: { city: { type: "string" } }, + required: ["city"] + } + } + ], + tool_choice: { mode: "auto" } +) + +if result[:stopReason] == "toolUse" + tool_results = result[:content].map do |tool_use| + weather_data = get_weather(tool_use[:input][:city]) + + { + type: "tool_result", + toolUseId: tool_use[:id], + content: [{ type: "text", text: weather_data.to_json }] + } + end + + final_result = server.create_sampling_message( + messages: [ + { role: "user", content: { type: "text", text: "What's the weather in Paris?" } }, + { role: "assistant", content: result[:content] }, + { role: "user", content: tool_results } + ], + max_tokens: 1000, + tools: [...] 
+ ) +end +``` + +**Error Handling:** + +- Raises `RuntimeError` if transport is not set +- Raises `RuntimeError` if client does not support `sampling` capability +- Raises `RuntimeError` if `tools` are used but client lacks `sampling.tools` capability +- Raises `StandardError` if client returns an error response + ### Notifications The server supports sending notifications to clients when lists of tools, prompts, or resources change. This enables real-time updates without polling. diff --git a/conformance/expected_failures.yml b/conformance/expected_failures.yml index c380d270..cb7a3fa8 100644 --- a/conformance/expected_failures.yml +++ b/conformance/expected_failures.yml @@ -1,7 +1,5 @@ server: - # TODO: Server-to-client requests (sampling/createMessage, elicitation/create) are not implemented. - # `Transport#send_request` does not exist in the current SDK. - - tools-call-sampling + # TODO: Server-to-client requests (elicitation/create) are not implemented. - tools-call-elicitation - elicitation-sep1034-defaults - elicitation-sep1330-enums diff --git a/conformance/server.rb b/conformance/server.rb index 9e5fd5ec..4bfac2ad 100644 --- a/conformance/server.rb +++ b/conformance/server.rb @@ -156,7 +156,6 @@ def call(server_context:, **_args) end end - # TODO: Implement when `Transport` supports server-to-client requests. 
class TestSampling < MCP::Tool tool_name "test_sampling" description "A tool that requests LLM sampling from the client" @@ -166,11 +165,15 @@ class TestSampling < MCP::Tool ) class << self - def call(prompt:) - MCP::Tool::Response.new( - [MCP::Content::Text.new("Sampling not supported in this SDK version").to_h], - error: true, + def call(prompt:, server_context:) + result = server_context.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: prompt } }], + max_tokens: 100, ) + model = result[:model] || "unknown" + text = result.dig(:content, :text) || "" + + MCP::Tool::Response.new([MCP::Content::Text.new("LLM response: #{text} (model: #{model})").to_h]) end end end diff --git a/lib/mcp/server.rb b/lib/mcp/server.rb index 1b026a13..484b29b5 100644 --- a/lib/mcp/server.rb +++ b/lib/mcp/server.rb @@ -54,6 +54,7 @@ def initialize(method_name) include Instrumentation attr_accessor :description, :icons, :name, :title, :version, :website_url, :instructions, :tools, :prompts, :resources, :server_context, :configuration, :capabilities, :transport, :logging_message_notification + attr_reader :client_capabilities def initialize( description: nil, @@ -92,6 +93,7 @@ def initialize( validate! @capabilities = capabilities || default_capabilities + @client_capabilities = nil @logging_message_notification = nil @handlers = { @@ -204,6 +206,43 @@ def notify_log_message(data:, level:, logger: nil) report_exception(e, { notification: "log_message" }) end + # Sends a `sampling/createMessage` request to the client. + # For single-client transports (e.g., `StdioTransport`). For multi-client transports + # (e.g., `StreamableHTTPTransport`), use `ServerSession#create_sampling_message` instead + # to ensure the request is routed to the correct client. 
+ def create_sampling_message( + messages:, + max_tokens:, + system_prompt: nil, + model_preferences: nil, + include_context: nil, + temperature: nil, + stop_sequences: nil, + metadata: nil, + tools: nil, + tool_choice: nil + ) + unless @transport + raise "Cannot send sampling request without a transport." + end + + params = build_sampling_params( + @client_capabilities, + messages: messages, + max_tokens: max_tokens, + system_prompt: system_prompt, + model_preferences: model_preferences, + include_context: include_context, + temperature: temperature, + stop_sequences: stop_sequences, + metadata: metadata, + tools: tools, + tool_choice: tool_choice, + ) + + @transport.send_request(Methods::SAMPLING_CREATE_MESSAGE, params) + end + # Sets a custom handler for `resources/read` requests. # The block receives the parsed request params and should return resource # contents. The return value is set as the `contents` field of the response. @@ -223,6 +262,45 @@ def completion_handler(&block) @handlers[Methods::COMPLETION_COMPLETE] = block end + def build_sampling_params( + capabilities, + messages:, + max_tokens:, + system_prompt: nil, + model_preferences: nil, + include_context: nil, + temperature: nil, + stop_sequences: nil, + metadata: nil, + tools: nil, + tool_choice: nil + ) + unless capabilities&.dig(:sampling) + raise "Client does not support sampling." + end + + if tools && !capabilities.dig(:sampling, :tools) + raise "Client does not support sampling with tools." + end + + if tool_choice && !capabilities.dig(:sampling, :tools) + raise "Client does not support sampling with tool_choice." + end + + { + messages: messages, + maxTokens: max_tokens, + systemPrompt: system_prompt, + modelPreferences: model_preferences, + includeContext: include_context, + temperature: temperature, + stopSequences: stop_sequences, + metadata: metadata, + tools: tools, + toolChoice: tool_choice, + }.compact + end + private def validate! 
@@ -371,10 +449,11 @@ def init(params, session: nil) session.store_client_info(client: params[:clientInfo], capabilities: params[:capabilities]) else @client = params[:clientInfo] + @client_capabilities = params[:capabilities] end + protocol_version = params[:protocolVersion] end - protocol_version = params[:protocolVersion] if params negotiated_version = if Configuration::SUPPORTED_STABLE_PROTOCOL_VERSIONS.include?(protocol_version) protocol_version else diff --git a/lib/mcp/server/transports/stdio_transport.rb b/lib/mcp/server/transports/stdio_transport.rb index 9d54cec5..3d6a4ff1 100644 --- a/lib/mcp/server/transports/stdio_transport.rb +++ b/lib/mcp/server/transports/stdio_transport.rb @@ -53,6 +53,41 @@ def send_notification(method, params = nil) MCP.configuration.exception_reporter.call(e, { error: "Failed to send notification" }) false end + + def send_request(method, params = nil) + request_id = generate_request_id + request = { jsonrpc: "2.0", id: request_id, method: method } + request[:params] = params if params + + begin + send_response(request) + rescue => e + MCP.configuration.exception_reporter.call(e, { error: "Failed to send request" }) + raise + end + + while @open && (line = $stdin.gets) + begin + parsed = JSON.parse(line.strip, symbolize_names: true) + rescue JSON::ParserError => e + MCP.configuration.exception_reporter.call(e, { error: "Failed to parse response" }) + raise + end + + if parsed[:id] == request_id && !parsed.key?(:method) + if parsed[:error] + raise StandardError, "Client returned an error for #{method} request (code: #{parsed[:error][:code]}): #{parsed[:error][:message]}" + end + + return parsed[:result] + else + response = @session ? @session.handle(parsed) : @server.handle(parsed) + send_response(response) if response + end + end + + raise "Transport closed while waiting for response to #{method} request." 
+ end end end end diff --git a/lib/mcp/server/transports/streamable_http_transport.rb b/lib/mcp/server/transports/streamable_http_transport.rb index 14b6906c..31ddc896 100644 --- a/lib/mcp/server/transports/streamable_http_transport.rb +++ b/lib/mcp/server/transports/streamable_http_transport.rb @@ -1,7 +1,6 @@ # frozen_string_literal: true require "json" -require "securerandom" require_relative "../../transport" module MCP @@ -16,6 +15,7 @@ def initialize(server, stateless: false, session_idle_timeout: nil) @stateless = stateless @session_idle_timeout = session_idle_timeout + @pending_responses = {} if @session_idle_timeout if @stateless @@ -132,6 +132,77 @@ def send_notification(method, params = nil, session_id: nil) result end + # Sends a server-to-client JSON-RPC request (e.g., `sampling/createMessage`) and + # blocks until the client responds. + # + # Uses a `Queue` for cross-thread synchronization. This method creates a `Queue`, + # sends the request via SSE stream, then blocks on `queue.pop`. + # When the client POSTs a response, `handle_response` matches it by `request_id` + # and pushes the result onto the queue, unblocking this thread. + def send_request(method, params = nil, session_id: nil) + if @stateless + raise "Stateless mode does not support server-to-client requests." + end + + unless session_id + raise "session_id is required for server-to-client requests." + end + + request_id = generate_request_id + queue = Queue.new + + request = { jsonrpc: "2.0", id: request_id, method: method } + request[:params] = params if params + + sent = false + + @mutex.synchronize do + unless (session = @sessions[session_id]) + raise "Session not found: #{session_id}." 
+ end + + @pending_responses[request_id] = { queue: queue, session_id: session_id } + + if (stream = session[:stream]) + begin + send_to_stream(stream, request) + sent = true + rescue *STREAM_WRITE_ERRORS + cleanup_session_unsafe(session_id) + end + end + end + + # TODO: Replace with event store + replay when resumability is implemented. + # Resumability is a separate MCP specification feature (SSE event IDs, Last-Event-ID replay, + # event store management) independent of sampling. + # See: https://modelcontextprotocol.io/specification/latest/basic/transports#resumability-and-redelivery + # + # The TypeScript and Python SDKs buffer messages and replay on reconnect. + # Until then, raise to prevent queue.pop from blocking indefinitely. + unless sent + raise "No active SSE stream for #{method} request." + end + + response = queue.pop + + if response.is_a?(Hash) && response.key?(:error) + raise StandardError, "Client returned an error for #{method} request (code: #{response[:error][:code]}): #{response[:error][:message]}" + end + + if response == :session_closed + raise "SSE session closed while waiting for #{method} response." + end + + response + ensure + if request_id + @mutex.synchronize do + @pending_responses.delete(request_id) + end + end + end + private def start_reaper_thread @@ -187,8 +258,12 @@ def handle_post(request) else return missing_session_id_response if !@stateless && !session_id - if notification?(body) || response?(body) + if notification?(body) handle_accepted + elsif response?(body) + return session_not_found_response if !@stateless && !session_exists?(session_id) + + handle_response(body, session_id: session_id) else handle_regular_request(body_string, session_id) end @@ -245,7 +320,16 @@ def cleanup_session(session_id) # Callers must close the stream outside the mutex to avoid holding the lock during # potentially blocking I/O. 
def cleanup_session_unsafe(session_id) - @sessions.delete(session_id) + session = @sessions.delete(session_id) + + # Unblock threads waiting on pending responses for this session. + @pending_responses.each_value do |pending_response| + if pending_response[:session_id] == session_id + pending_response[:queue].push(:session_closed) + end + end + + session end def cleanup_and_collect_stream(session_id, streams_to_close) @@ -305,6 +389,24 @@ def response?(body) !!body[:id] && !body[:method] end + # Verifies that the response came from the expected session to prevent + # cross-session response injection if request IDs are ever leaked. + def handle_response(body, session_id:) + request_id = body[:id] + @mutex.synchronize do + if (pending_response = @pending_responses[request_id]) && pending_response[:session_id] == session_id + if body.key?(:error) + error = body[:error] + pending_response[:queue].push(error: { code: error[:code], message: error[:message] }) + else + pending_response[:queue].push(body[:result]) + end + end + end + + handle_accepted + end + def handle_initialization(body_string, body) session_id = nil server_session = nil diff --git a/lib/mcp/server_context.rb b/lib/mcp/server_context.rb index 7d1f2d9a..b532555b 100644 --- a/lib/mcp/server_context.rb +++ b/lib/mcp/server_context.rb @@ -29,6 +29,19 @@ def notify_log_message(data:, level:, logger: nil) @notification_target.notify_log_message(data: data, level: level, logger: logger) end + # Delegates to the session so the request is scoped to the originating client. + # Falls back to `@context` (via `method_missing`) when `@notification_target` + # does not support sampling. 
+ def create_sampling_message(**kwargs) + if @notification_target.respond_to?(:create_sampling_message) + @notification_target.create_sampling_message(**kwargs) + elsif @context.respond_to?(:create_sampling_message) + @context.create_sampling_message(**kwargs) + else + raise NoMethodError, "undefined method 'create_sampling_message' for #{self}" + end + end + def method_missing(name, ...) if @context.respond_to?(name) @context.public_send(name, ...) diff --git a/lib/mcp/server_session.rb b/lib/mcp/server_session.rb index a1cbfd09..93e823fb 100644 --- a/lib/mcp/server_session.rb +++ b/lib/mcp/server_session.rb @@ -13,7 +13,7 @@ def initialize(server:, transport:, session_id: nil) @transport = transport @session_id = session_id @client = nil - @client_capabilities = nil # TODO: Use for per-session capability validation. + @client_capabilities = nil @logging_message_notification = nil end @@ -36,6 +36,17 @@ def configure_logging(logging_message_notification) @logging_message_notification = logging_message_notification end + # Returns per-session client capabilities, falling back to global. + def client_capabilities + @client_capabilities || @server.client_capabilities + end + + # Sends a `sampling/createMessage` request scoped to this session. + def create_sampling_message(**kwargs) + params = @server.build_sampling_params(client_capabilities, **kwargs) + send_to_transport_request(Methods::SAMPLING_CREATE_MESSAGE, params) + end + # Sends a progress notification to this session only. def notify_progress(progress_token:, progress:, total: nil, message: nil) params = { @@ -65,6 +76,9 @@ def notify_log_message(data:, level:, logger: nil) private + # Branches on `@session_id` because `StdioTransport` creates a `ServerSession` without + # a `session_id` (`session_id: nil`), while `StreamableHTTPTransport` always provides one. 
+ # # TODO: When Ruby 2.7 support is dropped, replace with a direct call: # `@transport.send_notification(method, params, session_id: @session_id)` and # add `**` to `Transport#send_notification` and `StdioTransport#send_notification`. @@ -75,5 +89,19 @@ def send_to_transport(method, params) @transport.send_notification(method, params) end end + + # Branches on `@session_id` because `StdioTransport` creates a `ServerSession` without + # a `session_id` (`session_id: nil`), while `StreamableHTTPTransport` always provides one. + # + # TODO: When Ruby 2.7 support is dropped, replace with a direct call: + # `@transport.send_request(method, params, session_id: @session_id)` and + # add `**` to `Transport#send_request` and `StdioTransport#send_request`. + def send_to_transport_request(method, params) + if @session_id + @transport.send_request(method, params, session_id: @session_id) + else + @transport.send_request(method, params) + end + end end end diff --git a/lib/mcp/transport.rb b/lib/mcp/transport.rb index 99b88921..4e3f85f3 100644 --- a/lib/mcp/transport.rb +++ b/lib/mcp/transport.rb @@ -1,5 +1,7 @@ # frozen_string_literal: true +require "securerandom" + module MCP class Transport # Initialize the transport with the server instance @@ -41,5 +43,16 @@ def handle_request(request) def send_notification(method, params = nil) raise NotImplementedError, "Subclasses must implement send_notification" end + + # Send a JSON-RPC request to the client and wait for a response. 
+ def send_request(method, params = nil) + raise NotImplementedError, "Subclasses must implement send_request" + end + + private + + def generate_request_id + SecureRandom.uuid + end end end diff --git a/test/mcp/server/transports/stdio_transport_test.rb b/test/mcp/server/transports/stdio_transport_test.rb index 59d3abe2..e9b1011b 100644 --- a/test/mcp/server/transports/stdio_transport_test.rb +++ b/test/mcp/server/transports/stdio_transport_test.rb @@ -151,6 +151,247 @@ class StdioTransportTest < ActiveSupport::TestCase $stdout = original_stdout end end + + test "send_request sends request to stdout and waits for response" do + reader, writer = IO.pipe + output = StringIO.new + original_stdin = $stdin + original_stdout = $stdout + + begin + $stdin = reader + $stdout = output + @transport.instance_variable_set(:@open, true) + + # Send response from client in a thread. + Thread.new do + sleep(0.05) # Wait for request to be written to `StringIO`. + request = JSON.parse(output.string.lines.first, symbolize_names: true) + response = { + jsonrpc: "2.0", + id: request[:id], + result: { content: "test response" }, + } + writer.puts(response.to_json) + writer.flush + end + + result = @transport.send_request("test/method", { param: "value" }) + + assert_equal({ content: "test response" }, result) + + # Verify request was sent. 
+ request = JSON.parse(output.string.lines.first, symbolize_names: true) + assert_equal("2.0", request[:jsonrpc]) + assert_equal("test/method", request[:method]) + assert_equal({ param: "value" }, request[:params]) + assert(request[:id]) + ensure + $stdin = original_stdin + $stdout = original_stdout + begin + writer.close + rescue + nil + end + begin + reader.close + rescue + nil + end + end + end + + test "send_request raises on error response from client" do + reader, writer = IO.pipe + output = StringIO.new + original_stdin = $stdin + original_stdout = $stdout + + begin + $stdin = reader + $stdout = output + @transport.instance_variable_set(:@open, true) + + Thread.new do + sleep(0.05) # Wait for request to be written to `StringIO`. + request = JSON.parse(output.string.lines.first, symbolize_names: true) + error_response = { + jsonrpc: "2.0", + id: request[:id], + error: { code: -1, message: "User rejected sampling request" }, + } + writer.puts(error_response.to_json) + writer.flush + end + + error = assert_raises(StandardError) do + @transport.send_request("sampling/createMessage", { messages: [] }) + end + + assert_equal("Client returned an error for sampling/createMessage request (code: -1): User rejected sampling request", error.message) + ensure + $stdin = original_stdin + $stdout = original_stdout + begin + writer.close + rescue + nil + end + begin + reader.close + rescue + nil + end + end + end + + test "send_request does not double-report intentional raises via exception_reporter" do + reader, writer = IO.pipe + output = StringIO.new + original_stdin = $stdin + original_stdout = $stdout + reported_errors = [] + original_reporter = MCP.configuration.exception_reporter + + begin + MCP.configuration.exception_reporter = ->(e, ctx) { reported_errors << [e, ctx] } + $stdin = reader + $stdout = output + @transport.instance_variable_set(:@open, true) + + Thread.new do + sleep(0.05) # Wait for request to be written to `StringIO`. 
+ request = JSON.parse(output.string.lines.first, symbolize_names: true) + error_response = { + jsonrpc: "2.0", + id: request[:id], + error: { code: -1, message: "rejected" }, + } + writer.puts(error_response.to_json) + writer.flush + end + + assert_raises(StandardError) do + @transport.send_request("sampling/createMessage", { messages: [] }) + end + + assert_empty(reported_errors) + ensure + MCP.configuration.exception_reporter = original_reporter + $stdin = original_stdin + $stdout = original_stdout + begin + writer.close + rescue + nil + end + begin + reader.close + rescue + nil + end + end + end + + test "send_request processes interleaved requests via session" do + reader, writer = IO.pipe + output = StringIO.new + original_stdin = $stdin + original_stdout = $stdout + + begin + $stdin = reader + $stdout = output + @transport.instance_variable_set(:@open, true) + + # Initialize a session so @session is set. + session = MCP::ServerSession.new(server: @server, transport: @transport) + @transport.instance_variable_set(:@session, session) + + Thread.new do + sleep(0.05) # Wait for request to be written to `StringIO`. + request = JSON.parse(output.string.lines.first, symbolize_names: true) + + # Send an interleaved ping request before the response. + ping = { jsonrpc: "2.0", method: "ping", id: "ping-1" } + writer.puts(ping.to_json) + writer.flush + + sleep(0.05) # Wait for the ping to be processed. + + # Then send the actual response. + response = { + jsonrpc: "2.0", + id: request[:id], + result: { content: "done" }, + } + writer.puts(response.to_json) + writer.flush + end + + result = @transport.send_request("test/method", { param: "value" }) + + assert_equal({ content: "done" }, result) + + # Verify the interleaved ping was handled (response sent to output). 
+ lines = output.string.lines + ping_response = lines.find { |l| l.include?("ping-1") } + assert(ping_response, "Interleaved ping request should have been handled") + ensure + $stdin = original_stdin + $stdout = original_stdout + begin + writer.close + rescue + nil + end + begin + reader.close + rescue + nil + end + end + end + + test "send_request raises when transport is closed while waiting" do + reader, writer = IO.pipe + output = StringIO.new + original_stdin = $stdin + original_stdout = $stdout + + begin + $stdin = reader + $stdout = output + @transport.instance_variable_set(:@open, true) + + # Close transport while waiting for response. + Thread.new do + sleep(0.05) # Wait for request to be written to `StringIO`. + @transport.instance_variable_set(:@open, false) + writer.close + end + + error = assert_raises(RuntimeError) do + @transport.send_request("sampling/createMessage", { messages: [] }) + end + + assert_equal("Transport closed while waiting for response to sampling/createMessage request.", error.message) + ensure + $stdin = original_stdin + $stdout = original_stdout + begin + writer.close + rescue IOError + nil + end + begin + reader.close + rescue IOError + nil + end + end + end end end end diff --git a/test/mcp/server/transports/streamable_http_transport_test.rb b/test/mcp/server/transports/streamable_http_transport_test.rb index d4faa8d4..507ac45b 100644 --- a/test/mcp/server/transports/streamable_http_transport_test.rb +++ b/test/mcp/server/transports/streamable_http_transport_test.rb @@ -1395,6 +1395,294 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase assert_equal "Internal server error", body["error"] end + test "send_request raises error in stateless mode" do + stateless_transport = StreamableHTTPTransport.new(@server, stateless: true) + + error = assert_raises(RuntimeError) do + stateless_transport.send_request("sampling/createMessage", { "messages" => [] }) + end + + assert_equal("Stateless mode does not support 
server-to-client requests.", error.message) + end + + test "send_request raises error when session_id is not provided" do + error = assert_raises(RuntimeError) do + @transport.send_request("sampling/createMessage", { "messages" => [] }) + end + + assert_equal("session_id is required for server-to-client requests.", error.message) + end + + test "send_request raises error when session is not found" do + error = assert_raises(RuntimeError) do + @transport.send_request("sampling/createMessage", { "messages" => [] }, session_id: "nonexistent") + end + + assert_equal("Session not found: nonexistent.", error.message) + end + + test "send_request raises error when no active SSE streams" do + # Create session but do NOT connect SSE. + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = @transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + error = assert_raises(RuntimeError) do + @transport.send_request("sampling/createMessage", { "messages" => [] }, session_id: session_id) + end + + assert_equal("No active SSE stream for sampling/createMessage request.", error.message) + end + + test "send_request sends via SSE and waits for response" do + # Create session. + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = @transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + # Connect SSE. + io = StringIO.new + get_request = create_rack_request( + "GET", + "/", + { "HTTP_MCP_SESSION_ID" => session_id }, + ) + response = @transport.handle_request(get_request) + response[2].call(io) if response[2].is_a?(Proc) + + sleep(0.1) # Give the stream time to set up. + + # Send request in background. 
+ result_queue = Queue.new + Thread.new do + result = @transport.send_request( + "sampling/createMessage", + { messages: [{ role: "user", content: { type: "text", text: "Hello" } }], maxTokens: 100 }, + session_id: session_id, + ) + result_queue.push(result) + end + + sleep(0.1) # Wait for the request to be sent. + + # Verify request was sent to stream. + io.rewind + output = io.read + assert_includes output, "sampling/createMessage" + + # Parse the sent request to get its ID. + data_lines = output.lines.select { |line| line.start_with?("data: ") } + request_data = JSON.parse(data_lines.first.sub("data: ", "")) + request_id = request_data["id"] + + # Simulate client response. + client_response = create_rack_request( + "POST", + "/", + { + "CONTENT_TYPE" => "application/json", + "HTTP_MCP_SESSION_ID" => session_id, + }, + { + jsonrpc: "2.0", + id: request_id, + result: { role: "assistant", content: { type: "text", text: "Hi there" } }, + }.to_json, + ) + @transport.handle_request(client_response) + + # Get result. + result = result_queue.pop + assert_equal "assistant", result[:role] + assert_equal "Hi there", result[:content][:text] + end + + test "send_request ignores response from wrong session" do + # Create two sessions. + init_a = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init-a" }.to_json, + ) + resp_a = @transport.handle_request(init_a) + session_a = resp_a[1]["Mcp-Session-Id"] + + init_b = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init-b" }.to_json, + ) + resp_b = @transport.handle_request(init_b) + session_b = resp_b[1]["Mcp-Session-Id"] + + # Connect SSE for session A. 
+ io_a = StringIO.new + get_a = create_rack_request("GET", "/", { "HTTP_MCP_SESSION_ID" => session_a }) + response_a = @transport.handle_request(get_a) + response_a[2].call(io_a) if response_a[2].is_a?(Proc) + + sleep(0.1) # Give the stream time to set up. + + # Send sampling request targeting session A. + result_queue = Queue.new + Thread.new do + result = @transport.send_request( + "sampling/createMessage", + { messages: [{ role: "user", content: { type: "text", text: "Hello" } }], maxTokens: 100 }, + session_id: session_a, + ) + result_queue.push(result) + end + + sleep(0.1) # Wait for the request to be sent. + + # Get the request ID from session A's stream. + io_a.rewind + data_lines = io_a.read.lines.select { |line| line.start_with?("data: ") } + request_data = JSON.parse(data_lines.first.sub("data: ", "")) + request_id = request_data["id"] + + # Session B tries to respond (cross-session injection attempt). + cross_session_response = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json", "HTTP_MCP_SESSION_ID" => session_b }, + { jsonrpc: "2.0", id: request_id, result: { role: "assistant", content: { type: "text", text: "injected" } } }.to_json, + ) + @transport.handle_request(cross_session_response) + + # The request should still be pending (not resolved by wrong session). + assert_empty(result_queue, "Response from wrong session should be ignored") + + # Now send the correct response from session A. + correct_response = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json", "HTTP_MCP_SESSION_ID" => session_a }, + { jsonrpc: "2.0", id: request_id, result: { role: "assistant", content: { type: "text", text: "correct" } } }.to_json, + ) + @transport.handle_request(correct_response) + + result = result_queue.pop + assert_equal "correct", result[:content][:text] + end + + test "send_request raises on error response from client" do + # Create session. 
+ init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = @transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + # Connect SSE. + io = StringIO.new + get_request = create_rack_request( + "GET", + "/", + { "HTTP_MCP_SESSION_ID" => session_id }, + ) + response = @transport.handle_request(get_request) + response[2].call(io) if response[2].is_a?(Proc) + + sleep(0.1) # Give the stream time to set up. + + error_queue = Queue.new + Thread.new do + @transport.send_request("sampling/createMessage", { messages: [] }, session_id: session_id) + rescue => e + error_queue.push(e) + end + + sleep(0.1) # Wait for the request to be sent. + + # Get request ID from stream. + io.rewind + data_lines = io.read.lines.select { |line| line.start_with?("data: ") } + request_data = JSON.parse(data_lines.first.sub("data: ", "")) + request_id = request_data["id"] + + # Send error response. + error_response = create_rack_request( + "POST", + "/", + { + "CONTENT_TYPE" => "application/json", + "HTTP_MCP_SESSION_ID" => session_id, + }, + { + jsonrpc: "2.0", + id: request_id, + error: { code: -1, message: "User rejected" }, + }.to_json, + ) + @transport.handle_request(error_response) + + error = error_queue.pop + assert_kind_of StandardError, error + assert_equal("Client returned an error for sampling/createMessage request (code: -1): User rejected", error.message) + end + + test "send_request unblocks when session is cleaned up" do + # Create session. + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = @transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + # Connect SSE. 
+ io = StringIO.new + get_request = create_rack_request( + "GET", + "/", + { "HTTP_MCP_SESSION_ID" => session_id }, + ) + response = @transport.handle_request(get_request) + response[2].call(io) if response[2].is_a?(Proc) + + sleep(0.1) # Give the stream time to set up. + + error_queue = Queue.new + Thread.new do + @transport.send_request("sampling/createMessage", { messages: [] }, session_id: session_id) + rescue => e + error_queue.push(e) + end + + sleep(0.1) # Wait for the request to be sent. + + # Delete the session to trigger cleanup (simulates client disconnect). + delete_request = create_rack_request( + "DELETE", + "/", + { "HTTP_MCP_SESSION_ID" => session_id }, + ) + @transport.handle_request(delete_request) + + error = error_queue.pop + assert_kind_of RuntimeError, error + assert_equal("SSE session closed while waiting for sampling/createMessage response.", error.message) + end + test "POST notifications/initialized returns 202 with no body" do # Create a session first (optional for notification, but keep consistent with flow) init_request = create_rack_request( @@ -1821,6 +2109,45 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase assert_equal([], response[2]) end + test "POST response without session ID returns 400" do + request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", result: "success", id: "123" }.to_json, + ) + + response = @transport.handle_request(request) + assert_equal 400, response[0] + body = JSON.parse(response[2][0]) + assert_equal "Missing session ID", body["error"] + end + + test "POST response with invalid session ID returns 404" do + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + @transport.handle_request(init_request) + + request = create_rack_request( + "POST", + "/", + { + "CONTENT_TYPE" => "application/json", + "HTTP_MCP_SESSION_ID" => 
"invalid-session-id", + }, + { jsonrpc: "2.0", result: "success", id: "123" }.to_json, + ) + + response = @transport.handle_request(request) + assert_equal 404, response[0] + body = JSON.parse(response[2][0]) + assert_equal "Session not found", body["error"] + end + test "handle_regular_request returns 404 for unknown session_id" do request = create_rack_request( "POST", diff --git a/test/mcp/server_context_test.rb b/test/mcp/server_context_test.rb index dd967961..605e3852 100644 --- a/test/mcp/server_context_test.rb +++ b/test/mcp/server_context_test.rb @@ -41,6 +41,46 @@ class ServerContextTest < ActiveSupport::TestCase assert_raises(NoMethodError) { server_context.nonexistent_method } end + test "ServerContext#create_sampling_message delegates to notification_target over context" do + notification_target = mock + notification_target.expects(:create_sampling_message).with( + messages: [{ role: "user", content: { type: "text", text: "Hello" } }], + max_tokens: 100, + ).returns({ role: "assistant", content: { type: "text", text: "Hi" } }) + + context = mock + progress = Progress.new(notification_target: notification_target, progress_token: nil) + + server_context = ServerContext.new(context, progress: progress, notification_target: notification_target) + + result = server_context.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: "Hello" } }], + max_tokens: 100, + ) + + assert_equal "Hi", result[:content][:text] + end + + test "ServerContext#create_sampling_message falls back to context when notification_target does not respond" do + notification_target = mock + context = mock + context.expects(:create_sampling_message).with( + messages: [{ role: "user", content: { type: "text", text: "Hello" } }], + max_tokens: 100, + ).returns({ role: "assistant", content: { type: "text", text: "Fallback" } }) + + progress = Progress.new(notification_target: notification_target, progress_token: nil) + + server_context = ServerContext.new(context, 
progress: progress, notification_target: notification_target) + + result = server_context.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: "Hello" } }], + max_tokens: 100, + ) + + assert_equal "Fallback", result[:content][:text] + end + test "ServerContext delegates to custom object context" do context = Object.new def context.custom_method diff --git a/test/mcp/server_sampling_test.rb b/test/mcp/server_sampling_test.rb new file mode 100644 index 00000000..57c488dc --- /dev/null +++ b/test/mcp/server_sampling_test.rb @@ -0,0 +1,398 @@ +# frozen_string_literal: true + +require "test_helper" + +module MCP + class ServerSamplingTest < ActiveSupport::TestCase + include InstrumentationTestHelper + + class MockTransport < Transport + attr_reader :requests + + def initialize(server) + super + @requests = [] + end + + def send_request(method, params = nil) + @requests << { method: method, params: params } + { + role: "assistant", + content: { type: "text", text: "Response from LLM" }, + model: "test-model", + stopReason: "endTurn", + } + end + + def send_response(response); end + def send_notification(method, params = nil); end + def open; end + def close; end + end + + setup do + configuration = MCP::Configuration.new + configuration.instrumentation_callback = instrumentation_helper.callback + + @server = Server.new( + name: "test_server", + version: "1.0.0", + configuration: configuration, + ) + + @mock_transport = MockTransport.new(@server) + @server.transport = @mock_transport + + # Simulate client initialization with sampling capability. 
+ @server.handle({ + jsonrpc: "2.0", + method: "initialize", + id: 1, + params: { + protocolVersion: "2025-11-25", + capabilities: { sampling: {} }, + clientInfo: { name: "test-client", version: "1.0" }, + }, + }) + end + + test "create_sampling_message sends request with required params" do + result = @server.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: "Hello" } }], + max_tokens: 100, + ) + + assert_equal 1, @mock_transport.requests.size + request = @mock_transport.requests.first + assert_equal Methods::SAMPLING_CREATE_MESSAGE, request[:method] + assert_equal 100, request[:params][:maxTokens] + assert_equal [{ role: "user", content: { type: "text", text: "Hello" } }], request[:params][:messages] + + assert_equal "assistant", result[:role] + assert_equal "Response from LLM", result[:content][:text] + end + + test "create_sampling_message sends all optional params" do + @server.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: "Hello" } }], + max_tokens: 100, + system_prompt: "You are helpful", + model_preferences: { intelligencePriority: 0.8 }, + include_context: "none", + temperature: 0.7, + stop_sequences: ["STOP"], + metadata: { key: "value" }, + ) + + request = @mock_transport.requests.first + params = request[:params] + + assert_equal "You are helpful", params[:systemPrompt] + assert_equal({ intelligencePriority: 0.8 }, params[:modelPreferences]) + assert_equal "none", params[:includeContext] + assert_equal 0.7, params[:temperature] + assert_equal ["STOP"], params[:stopSequences] + assert_equal({ key: "value" }, params[:metadata]) + end + + test "create_sampling_message raises error when transport is not set" do + server_without_transport = Server.new(name: "test", version: "1.0") + + # Initialize with sampling capability but no transport. 
+ server_without_transport.handle({ + jsonrpc: "2.0", + method: "initialize", + id: 1, + params: { + protocolVersion: "2025-11-25", + capabilities: { sampling: {} }, + clientInfo: { name: "test-client", version: "1.0" }, + }, + }) + + error = assert_raises(RuntimeError) do + server_without_transport.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: "Hello" } }], + max_tokens: 100, + ) + end + + assert_equal("Cannot send sampling request without a transport.", error.message) + end + + test "create_sampling_message raises error when client does not support sampling" do + # Re-initialize without sampling capability. + @server.handle({ + jsonrpc: "2.0", + method: "initialize", + id: 2, + params: { + protocolVersion: "2025-11-25", + capabilities: {}, + clientInfo: { name: "test-client", version: "1.0" }, + }, + }) + + error = assert_raises(RuntimeError) do + @server.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: "Hello" } }], + max_tokens: 100, + ) + end + + assert_equal("Client does not support sampling.", error.message) + end + + test "create_sampling_message raises error when tools used but client lacks sampling.tools" do + error = assert_raises(RuntimeError) do + @server.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: "Hello" } }], + max_tokens: 100, + tools: [{ name: "test_tool", inputSchema: { type: "object" } }], + ) + end + + assert_equal("Client does not support sampling with tools.", error.message) + end + + test "create_sampling_message raises error when tool_choice used alone but client lacks sampling.tools" do + error = assert_raises(RuntimeError) do + @server.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: "Hello" } }], + max_tokens: 100, + tool_choice: { mode: "auto" }, + ) + end + + assert_equal("Client does not support sampling with tool_choice.", error.message) + end + + test "create_sampling_message allows 
tools when client has sampling.tools capability" do + # Re-initialize with sampling.tools capability. + @server.handle({ + jsonrpc: "2.0", + method: "initialize", + id: 3, + params: { + protocolVersion: "2025-11-25", + capabilities: { sampling: { tools: {} } }, + clientInfo: { name: "test-client", version: "1.0" }, + }, + }) + + result = @server.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: "Hello" } }], + max_tokens: 100, + tools: [{ name: "test_tool", inputSchema: { type: "object" } }], + tool_choice: { mode: "auto" }, + ) + + request = @mock_transport.requests.first + params = request[:params] + + assert_equal [{ name: "test_tool", inputSchema: { type: "object" } }], params[:tools] + assert_equal({ mode: "auto" }, params[:toolChoice]) + assert_equal "Response from LLM", result[:content][:text] + end + + test "init with sampling capability allows create_sampling_message" do + server = Server.new(name: "test", version: "1.0") + mock_transport = MockTransport.new(server) + server.transport = mock_transport + + server.handle({ + jsonrpc: "2.0", + method: "initialize", + id: 1, + params: { + protocolVersion: "2025-11-25", + capabilities: { sampling: { tools: {} } }, + clientInfo: { name: "test-client", version: "1.0" }, + }, + }) + + result = server.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: "Hello" } }], + max_tokens: 100, + tools: [{ name: "t", inputSchema: { type: "object" } }], + ) + + assert_equal "assistant", result[:role] + end + + test "init without capabilities rejects create_sampling_message" do + server = Server.new(name: "test", version: "1.0") + mock_transport = MockTransport.new(server) + server.transport = mock_transport + + server.handle({ + jsonrpc: "2.0", + method: "initialize", + id: 1, + params: { + protocolVersion: "2025-11-25", + clientInfo: { name: "test-client", version: "1.0" }, + }, + }) + + error = assert_raises(RuntimeError) do + 
server.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: "Hello" } }], + max_tokens: 100, + ) + end + + assert_equal("Client does not support sampling.", error.message) + end + + test "create_sampling_message uses per-session capabilities via ServerSession" do + transport = MCP::Server::Transports::StreamableHTTPTransport.new(@server) + @server.transport = transport + + # Session with sampling capability passes validation (fails at send_request due to no stream). + session_with_sampling = ServerSession.new(server: @server, transport: transport, session_id: "s1") + session_with_sampling.store_client_info(client: { name: "capable" }, capabilities: { sampling: {} }) + transport.instance_variable_get(:@sessions)["s1"] = { stream: nil, server_session: session_with_sampling } + + error_with_sampling = assert_raises(RuntimeError) do + session_with_sampling.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: "Hello" } }], + max_tokens: 100, + ) + end + assert_equal("No active SSE stream for sampling/createMessage request.", error_with_sampling.message) + + # Session without sampling capability should be rejected. + session_without_sampling = ServerSession.new(server: @server, transport: transport, session_id: "s2") + session_without_sampling.store_client_info(client: { name: "incapable" }, capabilities: {}) + + error = assert_raises(RuntimeError) do + session_without_sampling.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: "Hello" } }], + max_tokens: 100, + ) + end + assert_equal("Client does not support sampling.", error.message) + end + + test "ServerSession#client_capabilities falls back to server global capabilities" do + transport = MCP::Server::Transports::StreamableHTTPTransport.new(@server) + @server.transport = transport + + # Session without capabilities stored falls back to @server.client_capabilities. 
+ session = ServerSession.new(server: @server, transport: transport, session_id: "s3") + transport.instance_variable_get(:@sessions)["s3"] = { stream: nil, server_session: session } + + # Server was initialized with sampling capability in setup, so fallback should pass validation. + error = assert_raises(RuntimeError) do + session.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: "Hello" } }], + max_tokens: 100, + ) + end + assert_equal("No active SSE stream for sampling/createMessage request.", error.message) + end + + test "session init does not overwrite server global client_capabilities" do + server = Server.new(name: "test", version: "1.0") + mock_transport = MockTransport.new(server) + server.transport = mock_transport + + # Non-session init sets global capabilities. + server.handle({ + jsonrpc: "2.0", + method: "initialize", + id: 1, + params: { + protocolVersion: "2025-11-25", + capabilities: { sampling: {} }, + clientInfo: { name: "first-client", version: "1.0" }, + }, + }) + + assert_equal({ sampling: {} }, server.client_capabilities) + + # Session-scoped init must NOT overwrite global capabilities. + transport = MCP::Server::Transports::StreamableHTTPTransport.new(server) + server.transport = transport + session = ServerSession.new(server: server, transport: transport, session_id: "s1") + + server.handle( + { + jsonrpc: "2.0", + method: "initialize", + id: 2, + params: { + protocolVersion: "2025-11-25", + capabilities: {}, + clientInfo: { name: "second-client", version: "1.0" }, + }, + }, + session: session, + ) + + # Global must still have sampling. + assert_equal({ sampling: {} }, server.client_capabilities) + # Session must have its own (empty) capabilities. 
+ assert_equal({}, session.client_capabilities) + end + + test "Server#create_sampling_message does not see session-scoped capabilities from HTTP init" do + server = Server.new(name: "test", version: "1.0") + transport = MCP::Server::Transports::StreamableHTTPTransport.new(server) + server.transport = transport + + # HTTP init stores capabilities on the session, not on the server. + session = ServerSession.new(server: server, transport: transport, session_id: "s1") + server.handle( + { + jsonrpc: "2.0", + method: "initialize", + id: 1, + params: { + protocolVersion: "2025-11-25", + capabilities: { sampling: {} }, + clientInfo: { name: "http-client", version: "1.0" }, + }, + }, + session: session, + ) + + # Server-level API should not see session-scoped capabilities. + error = assert_raises(RuntimeError) do + server.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: "Hello" } }], + max_tokens: 100, + ) + end + assert_equal("Client does not support sampling.", error.message) + + # Session-scoped API should work (fails at transport level, not capability). + transport.instance_variable_get(:@sessions)["s1"] = { stream: nil, server_session: session } + error = assert_raises(RuntimeError) do + session.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: "Hello" } }], + max_tokens: 100, + ) + end + assert_equal("No active SSE stream for sampling/createMessage request.", error.message) + end + + test "create_sampling_message omits nil optional params" do + @server.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: "Hello" } }], + max_tokens: 100, + system_prompt: nil, + temperature: nil, + ) + + request = @mock_transport.requests.first + params = request[:params] + + refute params.key?(:systemPrompt) + refute params.key?(:temperature) + assert params.key?(:messages) + assert params.key?(:maxTokens) + end + end +end
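The transport tests above all exercise the same underlying primitive: a server-to-client request is written to the session's SSE stream with a generated ID, the calling thread blocks until a POSTed response with a matching ID arrives, responses from the wrong session are ignored, and session cleanup unblocks any waiters with an error. A minimal, SDK-independent sketch of that correlation pattern (the names `PendingRequests`, `handle_response`, and `close_session` are illustrative, not the SDK's actual API):

```ruby
require "securerandom"

# Illustrative sketch of request/response correlation for server-to-client
# requests. The real transport additionally handles SSE framing, JSON-RPC
# serialization, timeouts, and stream lifecycle.
class PendingRequests
  def initialize
    @pending = {} # request id => { queue:, session_id: }
    @mutex = Mutex.new
  end

  # Register a request and block until a response (or session close) arrives.
  # Yields the generated ID and request payload, standing in for writing the
  # JSON-RPC request to the session's SSE stream.
  def send_request(method, params, session_id:)
    id = SecureRandom.uuid
    queue = Queue.new
    @mutex.synchronize { @pending[id] = { queue: queue, session_id: session_id } }
    yield id, { jsonrpc: "2.0", id: id, method: method, params: params } if block_given?
    outcome = queue.pop
    raise outcome if outcome.is_a?(StandardError)
    outcome
  end

  # Called when a client POSTs a JSON-RPC response. Responses whose session
  # does not match the pending request are ignored (cross-session injection).
  def handle_response(id, result, session_id:)
    entry = @mutex.synchronize do
      e = @pending[id]
      next nil unless e && e[:session_id] == session_id
      @pending.delete(id)
    end
    return false unless entry

    entry[:queue].push(result)
    true
  end

  # Called on session cleanup: unblock every waiter for that session.
  def close_session(session_id)
    @mutex.synchronize do
      @pending.select { |_id, e| e[:session_id] == session_id }.each do |id, e|
        e[:queue].push(RuntimeError.new("session closed"))
        @pending.delete(id)
      end
    end
  end
end
```

This mirrors the behavior the tests assert: `handle_response` returning `false` for a mismatched session corresponds to the "ignores response from wrong session" case, and `close_session` corresponds to the DELETE-triggered cleanup that raises in the waiting thread.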