Skip to content

Commit 60f1c70

Browse files
committed
use errgroup, handle createMessageWithTools
1 parent 7507222 commit 60f1c70

4 files changed

Lines changed: 67 additions & 80 deletions

File tree

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@ require (
1515

1616
require (
1717
github.com/segmentio/asm v1.1.3 // indirect
18+
golang.org/x/sync v0.20.0 // indirect
1819
golang.org/x/sys v0.41.0 // indirect
1920
)

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI
1212
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
1313
golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ=
1414
golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
15+
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
16+
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
1517
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
1618
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
1719
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=

mcp/mrtr.go

Lines changed: 54 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"sync"
1212

1313
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
14+
"golang.org/x/sync/errgroup"
1415
)
1516

1617
const defaultMRTRMaxRetries = 3
@@ -37,6 +38,9 @@ type mrtrResult interface {
3738
}
3839

3940
func handleMRTRResult(ss *ServerSession, logger *slog.Logger, res mrtrResult) error {
41+
if res == nil {
42+
return nil
43+
}
4044
hasInputRequests := res.inputRequests() != nil
4145

4246
if hasInputRequests && res.hasContent() {
@@ -169,28 +173,23 @@ func serverMRTRInputRequests(res Result) InputRequestMap {
169173
}
170174

171175
func fulfillServerInputRequests(ctx context.Context, ss *ServerSession, requests InputRequestMap) (InputResponseMap, error) {
172-
type result struct {
173-
id string
174-
resp InputResponse
175-
err error
176-
}
177-
results := make(chan result, len(requests))
178-
var wg sync.WaitGroup
176+
g, ctx := errgroup.WithContext(ctx)
177+
var mu sync.Mutex
178+
responses := make(InputResponseMap, len(requests))
179179
for id, ir := range requests {
180-
wg.Go(func() {
180+
g.Go(func() error {
181181
resp, err := fulfillServerInputRequest(ctx, ss, ir)
182-
results <- result{id, resp, err}
182+
if err != nil {
183+
return fmt.Errorf("fulfilling input request %q: %w", id, err)
184+
}
185+
mu.Lock()
186+
responses[id] = resp
187+
mu.Unlock()
188+
return nil
183189
})
184190
}
185-
wg.Wait()
186-
close(results)
187-
188-
responses := make(InputResponseMap, len(requests))
189-
for r := range results {
190-
if r.err != nil {
191-
return nil, fmt.Errorf("MRTR: fulfilling input request %q: %w", r.id, r.err)
192-
}
193-
responses[r.id] = r.resp
191+
if err := g.Wait(); err != nil {
192+
return nil, fmt.Errorf("MRTR: %w", err)
194193
}
195194
return responses, nil
196195
}
@@ -200,14 +199,34 @@ func fulfillServerInputRequest(ctx context.Context, ss *ServerSession, ir InputR
200199
case *ElicitParams:
201200
return ss.Elicit(ctx, p)
202201
case *CreateMessageParams:
203-
return ss.CreateMessage(ctx, p)
202+
return ss.CreateMessageWithTools(ctx, createMessageParamsToWithTools(p))
203+
case *CreateMessageWithToolsParams:
204+
return ss.CreateMessageWithTools(ctx, p)
204205
case *ListRootsParams:
205206
return ss.ListRoots(ctx, p)
206207
default:
207208
return nil, fmt.Errorf("unknown input request type: %T", ir)
208209
}
209210
}
210211

212+
func createMessageParamsToWithTools(p *CreateMessageParams) *CreateMessageWithToolsParams {
213+
var msgs []*SamplingMessageV2
214+
for _, m := range p.Messages {
215+
msgs = append(msgs, &SamplingMessageV2{Content: []Content{m.Content}, Role: m.Role})
216+
}
217+
return &CreateMessageWithToolsParams{
218+
Meta: p.Meta,
219+
IncludeContext: p.IncludeContext,
220+
MaxTokens: p.MaxTokens,
221+
Messages: msgs,
222+
Metadata: p.Metadata,
223+
ModelPreferences: p.ModelPreferences,
224+
StopSequences: p.StopSequences,
225+
SystemPrompt: p.SystemPrompt,
226+
Temperature: p.Temperature,
227+
}
228+
}
229+
211230
func mrtrInputRequests(res Result) InputRequestMap {
212231
if res == nil {
213232
return nil
@@ -262,28 +281,23 @@ func setMRTRRetryParams(req Request, responses InputResponseMap, state string) {
262281
}
263282

264283
func fulfillInputRequests(ctx context.Context, cs *ClientSession, requests InputRequestMap) (InputResponseMap, error) {
265-
type result struct {
266-
id string
267-
resp InputResponse
268-
err error
269-
}
270-
results := make(chan result, len(requests))
271-
var wg sync.WaitGroup
284+
g, ctx := errgroup.WithContext(ctx)
285+
var mu sync.Mutex
286+
responses := make(InputResponseMap, len(requests))
272287
for id, ir := range requests {
273-
wg.Go(func() {
288+
g.Go(func() error {
274289
resp, err := fulfillInputRequest(ctx, cs, ir)
275-
results <- result{id, resp, err}
290+
if err != nil {
291+
return fmt.Errorf("fulfilling input request %q: %w", id, err)
292+
}
293+
mu.Lock()
294+
responses[id] = resp
295+
mu.Unlock()
296+
return nil
276297
})
277298
}
278-
wg.Wait()
279-
close(results)
280-
281-
responses := make(InputResponseMap, len(requests))
282-
for r := range results {
283-
if r.err != nil {
284-
return nil, fmt.Errorf("MRTR: fulfilling input request %q: %w", r.id, r.err)
285-
}
286-
responses[r.id] = r.resp
299+
if err := g.Wait(); err != nil {
300+
return nil, fmt.Errorf("MRTR: %w", err)
287301
}
288302
return responses, nil
289303
}
@@ -293,43 +307,12 @@ func fulfillInputRequest(ctx context.Context, cs *ClientSession, ir InputRequest
293307
case *ElicitParams:
294308
return cs.client.elicit(ctx, newClientRequest(cs, p))
295309
case *CreateMessageParams:
296-
return fulfillCreateMessage(ctx, cs, p)
310+
return cs.client.createMessage(ctx, &CreateMessageWithToolsRequest{Session: cs, Params: createMessageParamsToWithTools(p)})
311+
case *CreateMessageWithToolsParams:
312+
return cs.client.createMessage(ctx, &CreateMessageWithToolsRequest{Session: cs, Params: p})
297313
case *ListRootsParams:
298314
return cs.client.listRoots(ctx, newClientRequest(cs, p))
299315
default:
300316
return nil, fmt.Errorf("unknown input request type: %T", ir)
301317
}
302318
}
303-
304-
func fulfillCreateMessage(ctx context.Context, cs *ClientSession, p *CreateMessageParams) (*CreateMessageResult, error) {
305-
var msgs []*SamplingMessageV2
306-
for _, m := range p.Messages {
307-
msgs = append(msgs, &SamplingMessageV2{Content: []Content{m.Content}, Role: m.Role})
308-
}
309-
wtp := &CreateMessageWithToolsParams{
310-
Meta: p.Meta,
311-
IncludeContext: p.IncludeContext,
312-
MaxTokens: p.MaxTokens,
313-
Messages: msgs,
314-
Metadata: p.Metadata,
315-
ModelPreferences: p.ModelPreferences,
316-
StopSequences: p.StopSequences,
317-
SystemPrompt: p.SystemPrompt,
318-
Temperature: p.Temperature,
319-
}
320-
result, err := cs.client.createMessage(ctx, &CreateMessageWithToolsRequest{Session: cs, Params: wtp})
321-
if err != nil {
322-
return nil, err
323-
}
324-
var content Content
325-
if len(result.Content) > 0 {
326-
content = result.Content[0]
327-
}
328-
return &CreateMessageResult{
329-
Meta: result.Meta,
330-
Content: content,
331-
Model: result.Model,
332-
Role: result.Role,
333-
StopReason: result.StopReason,
334-
}, nil
335-
}

mcp/protocol.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ func (m InputRequestMap) MarshalJSON() ([]byte, error) {
4646
switch v.(type) {
4747
case *ElicitParams:
4848
return methodElicit, nil
49-
case *CreateMessageParams:
49+
case *CreateMessageParams, *CreateMessageWithToolsParams:
5050
return methodCreateMessage, nil
5151
case *ListRootsParams:
5252
return methodListRoots, nil
@@ -84,7 +84,7 @@ func (m *InputRequestMap) UnmarshalJSON(data []byte) error {
8484
}
8585
result[k] = &p
8686
case methodCreateMessage:
87-
var p CreateMessageParams
87+
var p CreateMessageWithToolsParams
8888
if err := json.Unmarshal(raw.Params, &p); err != nil {
8989
return err
9090
}
@@ -122,7 +122,7 @@ func (m InputResponseMap) MarshalJSON() ([]byte, error) {
122122
switch v.(type) {
123123
case *ElicitResult:
124124
return methodElicit, nil
125-
case *CreateMessageResult:
125+
case *CreateMessageResult, *CreateMessageWithToolsResult:
126126
return methodCreateMessage, nil
127127
case *ListRootsResult:
128128
return methodListRoots, nil
@@ -160,7 +160,7 @@ func (m *InputResponseMap) UnmarshalJSON(data []byte) error {
160160
}
161161
result[k] = &p
162162
case methodCreateMessage:
163-
var p CreateMessageResult
163+
var p CreateMessageWithToolsResult
164164
if err := json.Unmarshal(raw.Result, &p); err != nil {
165165
return err
166166
}
@@ -172,7 +172,7 @@ func (m *InputResponseMap) UnmarshalJSON(data []byte) error {
172172
}
173173
result[k] = &p
174174
default:
175-
return fmt.Errorf("unsupported InputRequest method: %q", raw.Method)
175+
return fmt.Errorf("unsupported InputResponse method: %q", raw.Method)
176176
}
177177
}
178178
*m = result
@@ -345,9 +345,8 @@ func (r *CallToolResult) GetError() error {
345345

346346
func (*CallToolResult) isResult() {}
347347

348-
func (r *CallToolResult) setResultType(rt ResultType) { r.resultType = rt }
349-
func (r *CallToolResult) inputRequests() map[string]InputRequest { return r.InputRequests }
350-
func (r *CallToolResult) setInputRequest(k string, v InputRequest) { r.InputRequests[k] = v }
348+
func (r *CallToolResult) setResultType(rt ResultType) { r.resultType = rt }
349+
func (r *CallToolResult) inputRequests() map[string]InputRequest { return r.InputRequests }
351350
func (r *CallToolResult) hasContent() bool {
352351
return len(r.Content) > 0 || r.StructuredContent != nil
353352
}
@@ -672,6 +671,7 @@ type CreateMessageWithToolsParams struct {
672671
}
673672

674673
func (x *CreateMessageWithToolsParams) isParams() {}
674+
func (x *CreateMessageWithToolsParams) isInputRequest() {}
675675
func (x *CreateMessageWithToolsParams) GetProgressToken() any { return getProgressToken(x) }
676676
func (x *CreateMessageWithToolsParams) SetProgressToken(t any) { setProgressToken(x, t) }
677677

@@ -817,7 +817,8 @@ var createMessageWithToolsResultAllow = map[string]bool{
817817
"tool_use": true,
818818
}
819819

820-
func (*CreateMessageWithToolsResult) isResult() {}
820+
func (*CreateMessageWithToolsResult) isResult() {}
821+
func (*CreateMessageWithToolsResult) isInputResponse() {}
821822

822823
// MarshalJSON marshals the result. When Content has a single element, it is
823824
// marshaled as a single object for compatibility with pre-2025-11-25

0 commit comments

Comments
 (0)