1- import { AnthropicProviderOptions } from '@ai-sdk/anthropic' ;
2- import { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google' ;
3- import { OpenAIResponsesProviderOptions } from '@ai-sdk/openai' ;
41import {
52 FilePart ,
63 generateText ,
7- LanguageModel ,
84 ModelMessage ,
95 Output ,
106 SystemModelMessage ,
117 TextPart ,
8+ ToolSet ,
129} from 'ai' ;
10+ import { createMCPClient , MCPClient } from '@ai-sdk/mcp' ;
11+ import { Experimental_StdioMCPTransport as StdioClientTransport } from '@ai-sdk/mcp/mcp-stdio' ;
1312import z from 'zod' ;
1413import { combineAbortSignals } from '../../utils/abort-signal.js' ;
1514import { callWithTimeout } from '../../utils/timeout.js' ;
@@ -21,24 +20,33 @@ import {
2120 LocalLlmGenerateFilesResponse ,
2221 LocalLlmGenerateTextRequestOptions ,
2322 LocalLlmGenerateTextResponse ,
23+ McpServerDetails ,
24+ McpServerOptions ,
2425 PromptDataMessage ,
2526} from '../llm-runner.js' ;
2627import { ANTHROPIC_MODELS , getAiSdkModelOptionsForAnthropic } from './anthropic.js' ;
2728import { getAiSdkModelOptionsForGoogle , GOOGLE_MODELS } from './google.js' ;
2829import { getAiSdkModelOptionsForOpenAI , OPENAI_MODELS } from './openai.js' ;
2930import { AiSdkModelOptions } from './ai-sdk-model-options.js' ;
31+ import { getAiSdkModelOptionsForXai , XAI_MODELS } from './xai.js' ;
3032
// Every model identifier this runner accepts, aggregated across all
// configured providers (Google, Anthropic, OpenAI, xAI).
const SUPPORTED_MODELS = [
  ...GOOGLE_MODELS,
  ...ANTHROPIC_MODELS,
  ...OPENAI_MODELS,
  ...XAI_MODELS,
] as const;
3239
// Retries are effectively unbounded on purpose: a separate hard timeout
// (see `callWithTimeout`) is what aborts stuck LLM requests. We favor
// stability over fast failure here, even when that means sitting through
// many rounds of exponential-backoff waiting.
const DEFAULT_MAX_RETRIES = 100000;
3744
export class AiSdkRunner implements LlmRunner {
  readonly displayName = 'AI SDK';
  readonly id = 'ai-sdk';
  // Signals to callers that this runner performs output repair itself,
  // so no external repair loop should be layered on top.
  readonly hasBuiltInRepairLoop = true;
  // Clients for MCP servers launched via `startMcpServerHost`; `null`
  // until the first server is started. Closed in `dispose`.
  private mcpClients: MCPClient[] | null = null;
4250
4351 async generateText (
4452 options : LocalLlmGenerateTextRequestOptions ,
@@ -49,6 +57,7 @@ export class AiSdkRunner implements LlmRunner {
4957 abortSignal : abortSignal ,
5058 messages : this . convertRequestToMessagesList ( options ) ,
5159 maxRetries : DEFAULT_MAX_RETRIES ,
60+ tools : await this . getTools ( ) ,
5261 } ) ,
5362 ) ;
5463
@@ -75,6 +84,7 @@ export class AiSdkRunner implements LlmRunner {
7584 output : Output . object < z . infer < T > > ( { schema : options . schema } ) ,
7685 abortSignal : abortSignal ,
7786 maxRetries : DEFAULT_MAX_RETRIES ,
87+ tools : await this . getTools ( ) ,
7888 } ) ,
7989 ) ;
8090
@@ -120,7 +130,42 @@ export class AiSdkRunner implements LlmRunner {
120130 return [ ...SUPPORTED_MODELS ] ;
121131 }
122132
123- async dispose ( ) : Promise < void > { }
133+ async dispose ( ) : Promise < void > {
134+ if ( this . mcpClients ) {
135+ for ( const client of this . mcpClients ) {
136+ try {
137+ await client . close ( ) ;
138+ } catch ( error ) {
139+ console . error ( `Failed to close MCP client` , error ) ;
140+ }
141+ }
142+ }
143+ }
144+
145+ async startMcpServerHost (
146+ _hostName : string ,
147+ servers : McpServerOptions [ ] ,
148+ ) : Promise < McpServerDetails > {
149+ const details : McpServerDetails = { resources : [ ] , tools : [ ] } ;
150+
151+ for ( const server of servers ) {
152+ const client = await createMCPClient ( {
153+ transport : new StdioClientTransport ( {
154+ command : server . command ,
155+ args : server . args ,
156+ env : server . env ,
157+ } ) ,
158+ } ) ;
159+
160+ const [ resources , tools ] = await Promise . all ( [ client . listResources ( ) , client . tools ( ) ] ) ;
161+ resources . resources . forEach ( r => details . resources . push ( r . name ) ) ;
162+ details . tools . push ( ...Object . keys ( tools ) ) ;
163+ this . mcpClients ??= [ ] ;
164+ this . mcpClients . push ( client ) ;
165+ }
166+
167+ return details ;
168+ }
124169
125170 private async _wrapRequestWithTimeoutAndRateLimiting < T > (
126171 request : LocalLlmGenerateTextRequestOptions | LocalLlmConstrainedOutputGenerateRequestOptions ,
@@ -145,7 +190,8 @@ export class AiSdkRunner implements LlmRunner {
145190 const result =
146191 ( await getAiSdkModelOptionsForGoogle ( request . model ) ) ??
147192 ( await getAiSdkModelOptionsForAnthropic ( request . model ) ) ??
148- ( await getAiSdkModelOptionsForOpenAI ( request . model ) ) ;
193+ ( await getAiSdkModelOptionsForOpenAI ( request . model ) ) ??
194+ ( await getAiSdkModelOptionsForXai ( request . model ) ) ;
149195 if ( result === null ) {
150196 throw new Error ( `Unexpected unsupported model: ${ request . model } ` ) ;
151197 }
@@ -198,4 +244,18 @@ export class AiSdkRunner implements LlmRunner {
198244 }
199245 return result ;
200246 }
247+
248+ private async getTools ( ) : Promise < ToolSet | undefined > {
249+ let tools : ToolSet | undefined ;
250+
251+ if ( this . mcpClients ) {
252+ for ( const client of this . mcpClients ) {
253+ const clientTools = ( await client . tools ( ) ) as ToolSet ;
254+ tools ??= { } ;
255+ Object . keys ( clientTools ) . forEach ( name => ( tools ! [ name ] = clientTools [ name ] ) ) ;
256+ }
257+ }
258+
259+ return tools ;
260+ }
201261}
0 commit comments