@@ -20,6 +20,9 @@ import {
2020 Content ,
2121 CountTokensRequest ,
2222 CountTokensResponse ,
23+ FunctionCall ,
24+ FunctionDeclaration ,
25+ FunctionResponse ,
2326 GenerateContentRequest ,
2427 GenerateContentResult ,
2528 GenerateContentStreamResult ,
@@ -40,6 +43,8 @@ import { mergeRequestOptions } from '../requests/request-options';
4043import { AIModel } from './ai-model' ;
4144import { AI } from '../public-types' ;
4245
46+ const DEFAULT_MAX_SEQUENTIAL_FUNCTION_CALLS = 10 ;
47+
4348/**
4449 * Class for generative model APIs.
4550 * @public
@@ -71,19 +76,17 @@ export class GenerativeModel extends AIModel {
7176 singleRequestOptions ?: SingleRequestOptions ,
7277 ) : Promise < GenerateContentResult > {
7378 const formattedParams = formatGenerateContentInput ( request ) ;
74- return generateContent (
75- this . _apiSettings ,
76- this . model ,
77- {
78- generationConfig : this . generationConfig ,
79- safetySettings : this . safetySettings ,
80- tools : this . tools ,
81- toolConfig : this . toolConfig ,
82- systemInstruction : this . systemInstruction ,
83- ...formattedParams ,
84- } ,
85- mergeRequestOptions ( this . requestOptions , singleRequestOptions ) ,
86- ) ;
79+ const params : GenerateContentRequest = {
80+ generationConfig : this . generationConfig ,
81+ safetySettings : this . safetySettings ,
82+ tools : this . tools ,
83+ toolConfig : this . toolConfig ,
84+ systemInstruction : this . systemInstruction ,
85+ ...formattedParams ,
86+ } ;
87+ const requestOptions = mergeRequestOptions ( this . requestOptions , singleRequestOptions ) ;
88+ const result = await generateContent ( this . _apiSettings , this . model , params , requestOptions ) ;
89+ return this . _generateContentWithAutomaticFunctionCalling ( params , result , requestOptions ) ;
8790 }
8891
8992 /**
@@ -152,4 +155,92 @@ export class GenerativeModel extends AIModel {
152155 mergeRequestOptions ( this . requestOptions , singleRequestOptions ) ,
153156 ) ;
154157 }
158+
159+ private async _generateContentWithAutomaticFunctionCalling (
160+ params : GenerateContentRequest ,
161+ result : GenerateContentResult ,
162+ requestOptions ?: SingleRequestOptions ,
163+ ) : Promise < GenerateContentResult > {
164+ let remainingFunctionCalls =
165+ requestOptions ?. maxSequentialFunctionCalls ?? DEFAULT_MAX_SEQUENTIAL_FUNCTION_CALLS ;
166+ let currentParams = params ;
167+ let currentResult = result ;
168+
169+ while ( remainingFunctionCalls > 0 ) {
170+ const functionCalls = currentResult . response . functionCalls ?.( ) ;
171+ if ( ! functionCalls ?. length ) {
172+ return currentResult ;
173+ }
174+
175+ const functionResponses = await this . _callFunctionReferences (
176+ currentParams . tools ,
177+ functionCalls ,
178+ ) ;
179+ if ( ! functionResponses ) {
180+ return currentResult ;
181+ }
182+
183+ const responseContent = currentResult . response . candidates ?. [ 0 ] ?. content ;
184+ if ( ! responseContent ) {
185+ return currentResult ;
186+ }
187+
188+ remainingFunctionCalls -= 1 ;
189+ currentParams = {
190+ ...currentParams ,
191+ contents : [
192+ ...currentParams . contents ,
193+ responseContent ,
194+ {
195+ role : 'function' ,
196+ parts : functionResponses . map ( functionResponse => ( { functionResponse } ) ) ,
197+ } ,
198+ ] ,
199+ } ;
200+ currentResult = await generateContent (
201+ this . _apiSettings ,
202+ this . model ,
203+ currentParams ,
204+ requestOptions ,
205+ ) ;
206+ }
207+
208+ return currentResult ;
209+ }
210+
211+ private async _callFunctionReferences (
212+ tools : Tool [ ] | undefined ,
213+ functionCalls : FunctionCall [ ] ,
214+ ) : Promise < FunctionResponse [ ] | undefined > {
215+ const declarations = this . _getFunctionDeclarationsWithReferences ( tools ) ;
216+ if ( ! declarations . length ) {
217+ return undefined ;
218+ }
219+
220+ const functionResponses : FunctionResponse [ ] = [ ] ;
221+ for ( const functionCall of functionCalls ) {
222+ const declaration = declarations . find ( candidate => candidate . name === functionCall . name ) ;
223+ if ( ! declaration ?. functionReference ) {
224+ return undefined ;
225+ }
226+
227+ const response = ( await declaration . functionReference ( functionCall . args ) ) as object ;
228+ functionResponses . push ( {
229+ id : functionCall . id ,
230+ name : functionCall . name ,
231+ response,
232+ } ) ;
233+ }
234+ return functionResponses ;
235+ }
236+
237+ private _getFunctionDeclarationsWithReferences ( tools : Tool [ ] | undefined ) : FunctionDeclaration [ ] {
238+ return (
239+ tools ?. flatMap ( tool =>
240+ 'functionDeclarations' in tool
241+ ? ( tool . functionDeclarations ?. filter ( declaration => declaration . functionReference ) ?? [ ] )
242+ : [ ] ,
243+ ) ?? [ ]
244+ ) ;
245+ }
155246}
0 commit comments