Skip to content

Commit befd894

Browse files
committed
feat(p8.2): MCP v3 sampling 模块 - server 借用 client LLM
新增 src/dbjavagenix/mcp_apps/sampling.py。 sampling 是 elicitation 的镜像 — 不是问用户,是问客户端的 LLM。 server 通过 sampling/createMessage 反向调 client,client 决定是否同意 (通常会请求用户授权),用它自己的 LLM 跑,把结果返回。 3 个核心: - ModelPreferences: 软偏好 (intelligence / speed / cost priority + hints) 序列化成 modelPreferences 字段 - build_sampling_request: 构造 sampling/createMessage payload - SamplingClient: 注入 dispatcher 的薄适配器,服务端代码可以用 `await client.complete(msg)` 像调 anthropic SDK 那样调 最实际用途: ai_infer_business_names 在没有 ANTHROPIC_API_KEY 的 CI 环境 也能跑 LLM 增强 — 通过 client 的 LLM 配额。 15 个 unit test 覆盖 ModelPreferences 序列化 / sampling request 构造 / 错误输入 / 同步+异步 dispatcher / 各种响应形态提取。 Why: ADR-004 (规则优先 LLM 可选) 当前要求 ANTHROPIC_API_KEY。sampling 让 LLM 路径在 CI/无 key 环境也能用,且把 LLM 成本和管控交给 client。
1 parent 3636566 commit befd894

2 files changed

Lines changed: 272 additions & 0 deletions

File tree

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
"""MCP v3 (2025-06-18 spec) sampling helper.
2+
3+
`sampling` is the *inverse* of regular tool calls: the server asks the client
4+
to run an LLM inference on its behalf. The client decides whether to comply
5+
(may prompt the user for approval) and uses its own LLM/budget.
6+
7+
Wire format:
8+
```
9+
{
10+
"method": "sampling/createMessage",
11+
"params": {
12+
"messages": [...],
13+
"systemPrompt": "...",
14+
"modelPreferences": {...},
15+
"maxTokens": 1024
16+
}
17+
}
18+
```
19+
20+
This module:
21+
1. Builds the request payload (`build_sampling_request`)
22+
2. Provides a thin Python adapter `SamplingClient` that mirrors a subset of
23+
the anthropic SDK surface so server-side code can swap between direct API
24+
calls and sampling with minimal changes.
25+
26+
Use case in DBJavaGenix:
27+
- `ai_infer_business_names` would normally call Anthropic API directly
28+
(requires ANTHROPIC_API_KEY).
29+
- With sampling, the same call goes to the client; CI/offline environments
30+
without an API key can still get LLM-enhanced inference via the client
31+
(which has its own model + budget).
32+
33+
Compatibility (2026-06):
34+
- Claude Desktop 4.6+: full support
35+
- Claude Code 2.x: full support
36+
- Cherry Studio: not yet (planned)
37+
- Cursor / Continue.dev: not yet
38+
"""
39+
40+
from __future__ import annotations
41+
42+
from dataclasses import dataclass, field
43+
from typing import Any
44+
45+
46+
@dataclass
class ModelPreferences:
    """Advisory model-selection weights sent along with a sampling request.

    These are hints only: the client may disregard them and pick whatever
    model it wants (for example its own configured default).
    """

    # Relative weight given to model capability.
    intelligence_priority: float = 0.5
    # Relative weight given to response latency.
    speed_priority: float = 0.5
    # Relative weight given to cost.
    cost_priority: float = 0.0
    # Model-name hints; each one is serialized as a {"name": ...} object.
    hints: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to the camelCase mapping used for `modelPreferences`.

        Returns:
            A dict with the three priority keys; a "hints" key is added only
            when at least one hint is present.
        """
        serialized: dict[str, Any] = dict(
            intelligencePriority=self.intelligence_priority,
            speedPriority=self.speed_priority,
            costPriority=self.cost_priority,
        )
        if not self.hints:
            return serialized
        serialized["hints"] = [{"name": hint} for hint in self.hints]
        return serialized
67+
68+
69+
def build_sampling_request(
    user_message: str,
    system_prompt: str | None = None,
    max_tokens: int = 1024,
    model_prefs: ModelPreferences | None = None,
) -> dict[str, Any]:
    """Assemble the `params` dict for an MCP sampling/createMessage call.

    Args:
        user_message: text the client's LLM should respond to; sent as a
            single user-role text message.
        system_prompt: optional system instruction; omitted from the payload
            when falsy.
        max_tokens: output token cap, accepted range 1-8192.
        model_prefs: optional model selection hints; omitted when falsy.

    Returns:
        Dict ready to be sent as `params` of sampling/createMessage.

    Raises:
        ValueError: if user_message is empty or max_tokens is out of range.
    """
    if not user_message:
        raise ValueError("user_message must be non-empty")
    below_minimum = max_tokens <= 0
    if below_minimum or max_tokens > 8192:
        raise ValueError("max_tokens out of range (1-8192)")

    user_turn = {
        "role": "user",
        "content": {"type": "text", "text": user_message},
    }
    request: dict[str, Any] = {"messages": [user_turn], "maxTokens": max_tokens}

    # Optional fields are only emitted when the caller supplied a truthy value.
    if system_prompt:
        request["systemPrompt"] = system_prompt
    if model_prefs:
        request["modelPreferences"] = model_prefs.to_dict()
    return request
105+
106+
107+
class SamplingClient:
    """A thin Python adapter that wraps a sampling-dispatch callable.

    The real MCP server transport injects the dispatcher (it knows how to
    serialize and forward to the client). This adapter lets server-side code
    write `client.complete(...)` the same way it would call the anthropic SDK,
    without depending on the actual transport.
    """

    def __init__(self, dispatcher):
        """Store the dispatcher used to forward sampling requests.

        Args:
            dispatcher: callable(payload_dict) -> response, sync or async.
                It receives the dict produced by `build_sampling_request`
                and returns (or resolves to) the client's raw response;
                `complete` reduces that response to text.

        Raises:
            TypeError: if dispatcher is not callable.
        """
        if not callable(dispatcher):
            raise TypeError("dispatcher must be callable")
        self._dispatch = dispatcher

    async def complete(
        self,
        user_message: str,
        system_prompt: str | None = None,
        max_tokens: int = 1024,
        model_prefs: ModelPreferences | None = None,
    ) -> str:
        """Send a sampling request and return the text response.

        Args:
            user_message: text the client's LLM should respond to.
            system_prompt: optional system instruction.
            max_tokens: output token cap (validated by
                `build_sampling_request`).
            model_prefs: optional model selection hints.

        Returns:
            The response text as extracted by `_extract_text`.

        Raises:
            ValueError: propagated from `build_sampling_request` on invalid
                input.
        """
        # Local import keeps this block self-contained.
        import inspect

        payload = build_sampling_request(
            user_message=user_message,
            system_prompt=system_prompt,
            max_tokens=max_tokens,
            model_prefs=model_prefs,
        )
        result = self._dispatch(payload)
        # isawaitable covers coroutines, futures and generator-based
        # coroutines, so both sync and async dispatchers are supported.
        if inspect.isawaitable(result):
            result = await result
        return _extract_text(result)
143+
144+
145+
def _extract_text(response: Any) -> str:
    """Pull the response text out of various plausible response shapes.

    Shapes handled, in order:
      * ``None`` -> empty string
      * ``str`` -> returned unchanged
      * dict whose ``content`` is a single ``{"type": "text", ...}`` object
        (the shape of an MCP sampling/createMessage result)
      * dict whose ``content`` is a non-empty list with a text block first
        (anthropic-SDK-style response)
      * dict with a top-level ``text`` key
      * anything else -> ``str(response)``

    Args:
        response: whatever the dispatcher returned.

    Returns:
        The extracted text, or a string rendering of the response when no
        known shape matches.
    """
    if response is None:
        return ""
    if isinstance(response, str):
        return response
    if isinstance(response, dict):
        content = response.get("content")
        # MCP wire format: `content` is one content object, not a list.
        if isinstance(content, dict) and content.get("type") == "text":
            return str(content.get("text", ""))
        if isinstance(content, list) and content:
            first = content[0]
            if isinstance(first, dict) and first.get("type") == "text":
                return str(first.get("text", ""))
        if "text" in response:
            return str(response["text"])
    return str(response)

tests/unit/test_sampling.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
"""Unit tests for mcp_apps.sampling."""
2+
import asyncio
3+
4+
import pytest
5+
6+
from dbjavagenix.mcp_apps.sampling import (
7+
ModelPreferences,
8+
SamplingClient,
9+
build_sampling_request,
10+
)
11+
12+
13+
def _run(coro):
    """Run the coroutine *coro* to completion and return its result.

    Uses ``asyncio.run`` so each call gets a fresh event loop that is closed
    afterwards; calling ``asyncio.get_event_loop()`` when no loop is current
    is deprecated and fails on newer Python versions.
    """
    return asyncio.run(coro)
15+
16+
17+
class TestModelPreferences:
    """Serialization checks for ``ModelPreferences.to_dict``."""

    def test_default_serialization(self):
        # Defaults: both priorities at 0.5 and no "hints" key emitted.
        serialized = ModelPreferences().to_dict()
        assert serialized["intelligencePriority"] == 0.5
        assert serialized["speedPriority"] == 0.5
        assert "hints" not in serialized

    def test_hints_serialized_as_name_objects(self):
        # Each hint string becomes a {"name": ...} object.
        prefs = ModelPreferences(hints=["claude-sonnet-4-6"])
        assert prefs.to_dict()["hints"] == [{"name": "claude-sonnet-4-6"}]

    def test_custom_priorities(self):
        prefs = ModelPreferences(intelligence_priority=0.9, cost_priority=0.1)
        serialized = prefs.to_dict()
        assert serialized["intelligencePriority"] == 0.9
        assert serialized["costPriority"] == 0.1
35+
36+
37+
class TestBuildSamplingRequest:
    """Payload-shape and validation checks for ``build_sampling_request``."""

    def test_minimal(self):
        # Only the message is given: defaults apply, optional keys are absent.
        request = build_sampling_request("infer names for sys_user")
        messages = request["messages"]
        assert request["maxTokens"] == 1024
        assert len(messages) == 1
        assert messages[0]["role"] == "user"
        assert "systemPrompt" not in request
        assert "modelPreferences" not in request

    def test_with_system_prompt(self):
        request = build_sampling_request("hello", system_prompt="You are an expert.")
        assert request["systemPrompt"] == "You are an expert."

    def test_with_model_prefs(self):
        prefs = ModelPreferences(intelligence_priority=0.9)
        request = build_sampling_request("hello", model_prefs=prefs)
        assert request["modelPreferences"]["intelligencePriority"] == 0.9

    def test_empty_message_raises(self):
        with pytest.raises(ValueError):
            build_sampling_request("")

    def test_invalid_max_tokens(self):
        # One value below and one above the accepted range.
        for out_of_range in (0, 10000):
            with pytest.raises(ValueError):
                build_sampling_request("hi", max_tokens=out_of_range)
63+
64+
65+
class TestSamplingClient:
    """Behavior checks for ``SamplingClient`` with sync and async dispatchers."""

    def test_sync_dispatcher(self):
        # A plain function dispatcher: the payload is recorded and the text
        # of the first content block is returned.
        seen = []

        def dispatcher(payload):
            seen.append(payload)
            return {"content": [{"type": "text", "text": "ok"}]}

        answer = _run(SamplingClient(dispatcher).complete("ping"))
        assert answer == "ok"
        assert seen[-1]["messages"][0]["content"]["text"] == "ping"

    def test_async_dispatcher(self):
        # A coroutine-function dispatcher must be awaited by the client.
        async def dispatcher(payload):
            return {"content": [{"type": "text", "text": "async-ok"}]}

        answer = _run(SamplingClient(dispatcher).complete("ping"))
        assert answer == "async-ok"

    def test_non_callable_raises(self):
        with pytest.raises(TypeError):
            SamplingClient("not a function")

    def test_extracts_text_from_string_response(self):
        sampler = SamplingClient(lambda p: "plain string")
        assert _run(sampler.complete("ping")) == "plain string"

    def test_extracts_text_from_text_key(self):
        sampler = SamplingClient(lambda p: {"text": "from-text-key"})
        assert _run(sampler.complete("ping")) == "from-text-key"

    def test_none_response_returns_empty(self):
        sampler = SamplingClient(lambda p: None)
        assert _run(sampler.complete("ping")) == ""

    def test_propagates_system_prompt(self):
        # The system prompt passed to complete() must reach the dispatcher.
        seen = []

        def dispatcher(p):
            seen.append(p)
            return {"text": "ok"}

        _run(SamplingClient(dispatcher).complete("ping", system_prompt="sys"))
        assert seen[-1]["systemPrompt"] == "sys"

0 commit comments

Comments
 (0)