@@ -11,6 +11,7 @@ import (
1111 "sync"
1212
1313 "github.com/modelcontextprotocol/go-sdk/jsonrpc"
14+ "golang.org/x/sync/errgroup"
1415)
1516
1617const defaultMRTRMaxRetries = 3
@@ -37,6 +38,9 @@ type mrtrResult interface {
3738}
3839
3940func handleMRTRResult (ss * ServerSession , logger * slog.Logger , res mrtrResult ) error {
41+ if res == nil {
42+ return nil
43+ }
4044 hasInputRequests := res .inputRequests () != nil
4145
4246 if hasInputRequests && res .hasContent () {
@@ -169,28 +173,23 @@ func serverMRTRInputRequests(res Result) InputRequestMap {
169173}
170174
171175func fulfillServerInputRequests (ctx context.Context , ss * ServerSession , requests InputRequestMap ) (InputResponseMap , error ) {
172- type result struct {
173- id string
174- resp InputResponse
175- err error
176- }
177- results := make (chan result , len (requests ))
178- var wg sync.WaitGroup
176+ g , ctx := errgroup .WithContext (ctx )
177+ var mu sync.Mutex
178+ responses := make (InputResponseMap , len (requests ))
179179 for id , ir := range requests {
180- wg .Go (func () {
180+ g .Go (func () error {
181181 resp , err := fulfillServerInputRequest (ctx , ss , ir )
182- results <- result {id , resp , err }
182+ if err != nil {
183+ return fmt .Errorf ("fulfilling input request %q: %w" , id , err )
184+ }
185+ mu .Lock ()
186+ responses [id ] = resp
187+ mu .Unlock ()
188+ return nil
183189 })
184190 }
185- wg .Wait ()
186- close (results )
187-
188- responses := make (InputResponseMap , len (requests ))
189- for r := range results {
190- if r .err != nil {
191- return nil , fmt .Errorf ("MRTR: fulfilling input request %q: %w" , r .id , r .err )
192- }
193- responses [r .id ] = r .resp
191+ if err := g .Wait (); err != nil {
192+ return nil , fmt .Errorf ("MRTR: %w" , err )
194193 }
195194 return responses , nil
196195}
@@ -200,14 +199,34 @@ func fulfillServerInputRequest(ctx context.Context, ss *ServerSession, ir InputR
200199 case * ElicitParams :
201200 return ss .Elicit (ctx , p )
202201 case * CreateMessageParams :
203- return ss .CreateMessage (ctx , p )
202+ return ss .CreateMessageWithTools (ctx , createMessageParamsToWithTools (p ))
203+ case * CreateMessageWithToolsParams :
204+ return ss .CreateMessageWithTools (ctx , p )
204205 case * ListRootsParams :
205206 return ss .ListRoots (ctx , p )
206207 default :
207208 return nil , fmt .Errorf ("unknown input request type: %T" , ir )
208209 }
209210}
210211
212+ func createMessageParamsToWithTools (p * CreateMessageParams ) * CreateMessageWithToolsParams {
213+ var msgs []* SamplingMessageV2
214+ for _ , m := range p .Messages {
215+ msgs = append (msgs , & SamplingMessageV2 {Content : []Content {m .Content }, Role : m .Role })
216+ }
217+ return & CreateMessageWithToolsParams {
218+ Meta : p .Meta ,
219+ IncludeContext : p .IncludeContext ,
220+ MaxTokens : p .MaxTokens ,
221+ Messages : msgs ,
222+ Metadata : p .Metadata ,
223+ ModelPreferences : p .ModelPreferences ,
224+ StopSequences : p .StopSequences ,
225+ SystemPrompt : p .SystemPrompt ,
226+ Temperature : p .Temperature ,
227+ }
228+ }
229+
211230func mrtrInputRequests (res Result ) InputRequestMap {
212231 if res == nil {
213232 return nil
@@ -262,28 +281,23 @@ func setMRTRRetryParams(req Request, responses InputResponseMap, state string) {
262281}
263282
264283func fulfillInputRequests (ctx context.Context , cs * ClientSession , requests InputRequestMap ) (InputResponseMap , error ) {
265- type result struct {
266- id string
267- resp InputResponse
268- err error
269- }
270- results := make (chan result , len (requests ))
271- var wg sync.WaitGroup
284+ g , ctx := errgroup .WithContext (ctx )
285+ var mu sync.Mutex
286+ responses := make (InputResponseMap , len (requests ))
272287 for id , ir := range requests {
273- wg .Go (func () {
288+ g .Go (func () error {
274289 resp , err := fulfillInputRequest (ctx , cs , ir )
275- results <- result {id , resp , err }
290+ if err != nil {
291+ return fmt .Errorf ("fulfilling input request %q: %w" , id , err )
292+ }
293+ mu .Lock ()
294+ responses [id ] = resp
295+ mu .Unlock ()
296+ return nil
276297 })
277298 }
278- wg .Wait ()
279- close (results )
280-
281- responses := make (InputResponseMap , len (requests ))
282- for r := range results {
283- if r .err != nil {
284- return nil , fmt .Errorf ("MRTR: fulfilling input request %q: %w" , r .id , r .err )
285- }
286- responses [r .id ] = r .resp
299+ if err := g .Wait (); err != nil {
300+ return nil , fmt .Errorf ("MRTR: %w" , err )
287301 }
288302 return responses , nil
289303}
@@ -293,43 +307,12 @@ func fulfillInputRequest(ctx context.Context, cs *ClientSession, ir InputRequest
293307 case * ElicitParams :
294308 return cs .client .elicit (ctx , newClientRequest (cs , p ))
295309 case * CreateMessageParams :
296- return fulfillCreateMessage (ctx , cs , p )
310+ return cs .client .createMessage (ctx , & CreateMessageWithToolsRequest {Session : cs , Params : createMessageParamsToWithTools (p )})
311+ case * CreateMessageWithToolsParams :
312+ return cs .client .createMessage (ctx , & CreateMessageWithToolsRequest {Session : cs , Params : p })
297313 case * ListRootsParams :
298314 return cs .client .listRoots (ctx , newClientRequest (cs , p ))
299315 default :
300316 return nil , fmt .Errorf ("unknown input request type: %T" , ir )
301317 }
302318}
303-
304- func fulfillCreateMessage (ctx context.Context , cs * ClientSession , p * CreateMessageParams ) (* CreateMessageResult , error ) {
305- var msgs []* SamplingMessageV2
306- for _ , m := range p .Messages {
307- msgs = append (msgs , & SamplingMessageV2 {Content : []Content {m .Content }, Role : m .Role })
308- }
309- wtp := & CreateMessageWithToolsParams {
310- Meta : p .Meta ,
311- IncludeContext : p .IncludeContext ,
312- MaxTokens : p .MaxTokens ,
313- Messages : msgs ,
314- Metadata : p .Metadata ,
315- ModelPreferences : p .ModelPreferences ,
316- StopSequences : p .StopSequences ,
317- SystemPrompt : p .SystemPrompt ,
318- Temperature : p .Temperature ,
319- }
320- result , err := cs .client .createMessage (ctx , & CreateMessageWithToolsRequest {Session : cs , Params : wtp })
321- if err != nil {
322- return nil , err
323- }
324- var content Content
325- if len (result .Content ) > 0 {
326- content = result .Content [0 ]
327- }
328- return & CreateMessageResult {
329- Meta : result .Meta ,
330- Content : content ,
331- Model : result .Model ,
332- Role : result .Role ,
333- StopReason : result .StopReason ,
334- }, nil
335- }
0 commit comments