Skip to content

Commit 2b2bbd3

Browse files
committed
feat(ai): support generateContent function auto-calling
1 parent f1b9eff commit 2b2bbd3

2 files changed

Lines changed: 254 additions & 13 deletions

File tree

packages/ai/__tests__/generative-model.test.ts

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import { AI, FunctionCallingMode, ThinkingLevel } from '../lib/public-types';
2121
import * as request from '../lib/requests/request';
2222
import { BackendName, getMockResponse } from './test-utils/mock-response';
2323
import { VertexAIBackend } from '../lib/backend';
24+
import { GenerateContentResponse } from '../lib';
2425

2526
const fakeAI: AI = {
2627
app: {
@@ -36,6 +37,10 @@ const fakeAI: AI = {
3637
location: 'us-central1',
3738
};
3839

40+
function responseFromJson(json: GenerateContentResponse): Response {
41+
return { json: async () => json } as Response;
42+
}
43+
3944
describe('GenerativeModel', () => {
4045
it('passes CodeExecutionTool and URLContextTool with other tools through to generateContent', async function () {
4146
const genModel = new GenerativeModel(fakeAI, {
@@ -362,6 +367,151 @@ describe('GenerativeModel', () => {
362367
makeRequestStub.mockRestore();
363368
});
364369

370+
it('automatically calls functionReference from generateContent function calls', async () => {
371+
const getWeather = jest.fn<(args: object) => object>().mockReturnValue({ temperature: 72 });
372+
const genModel = new GenerativeModel(fakeAI, {
373+
model: 'my-model',
374+
tools: [
375+
{
376+
functionDeclarations: [
377+
{
378+
name: 'getWeather',
379+
description: 'Gets weather for a city.',
380+
functionReference: getWeather,
381+
},
382+
],
383+
},
384+
],
385+
});
386+
const makeRequestStub = jest
387+
.spyOn(request, 'makeRequest')
388+
.mockResolvedValueOnce(
389+
responseFromJson({
390+
candidates: [
391+
{
392+
index: 0,
393+
content: {
394+
role: 'model',
395+
parts: [
396+
{
397+
functionCall: {
398+
id: 'call-1',
399+
name: 'getWeather',
400+
args: { city: 'London' },
401+
},
402+
},
403+
],
404+
},
405+
},
406+
],
407+
}),
408+
)
409+
.mockResolvedValueOnce(
410+
responseFromJson({
411+
candidates: [
412+
{
413+
index: 0,
414+
content: {
415+
role: 'model',
416+
parts: [{ text: 'It is 72 degrees.' }],
417+
},
418+
},
419+
],
420+
}),
421+
);
422+
423+
const result = await genModel.generateContent('weather in London');
424+
425+
expect(result.response.text()).toBe('It is 72 degrees.');
426+
expect(getWeather).toHaveBeenCalledWith({ city: 'London' });
427+
expect(makeRequestStub).toHaveBeenCalledTimes(2);
428+
const followUpBody = JSON.parse(makeRequestStub.mock.calls[1]![1] as string);
429+
expect(followUpBody.contents).toEqual([
430+
{
431+
role: 'user',
432+
parts: [{ text: 'weather in London' }],
433+
},
434+
{
435+
role: 'model',
436+
parts: [
437+
{
438+
functionCall: {
439+
id: 'call-1',
440+
name: 'getWeather',
441+
args: { city: 'London' },
442+
},
443+
},
444+
],
445+
},
446+
{
447+
role: 'function',
448+
parts: [
449+
{
450+
functionResponse: {
451+
id: 'call-1',
452+
name: 'getWeather',
453+
response: { temperature: 72 },
454+
},
455+
},
456+
],
457+
},
458+
]);
459+
makeRequestStub.mockRestore();
460+
});
461+
462+
it('returns the latest response when maxSequentialFunctionCalls is reached', async () => {
463+
const getWeather = jest.fn<(args: object) => object>().mockReturnValue({ temperature: 72 });
464+
const genModel = new GenerativeModel(
465+
fakeAI,
466+
{
467+
model: 'my-model',
468+
tools: [
469+
{
470+
functionDeclarations: [
471+
{
472+
name: 'getWeather',
473+
description: 'Gets weather for a city.',
474+
functionReference: getWeather,
475+
},
476+
],
477+
},
478+
],
479+
},
480+
{ maxSequentialFunctionCalls: 1 },
481+
);
482+
const functionCallResponse: GenerateContentResponse = {
483+
candidates: [
484+
{
485+
index: 0,
486+
content: {
487+
role: 'model',
488+
parts: [
489+
{
490+
functionCall: {
491+
name: 'getWeather',
492+
args: { city: 'London' },
493+
},
494+
},
495+
],
496+
},
497+
},
498+
],
499+
};
500+
const makeRequestStub = jest
501+
.spyOn(request, 'makeRequest')
502+
.mockResolvedValueOnce(responseFromJson(functionCallResponse))
503+
.mockResolvedValueOnce(responseFromJson(functionCallResponse));
504+
505+
const result = await genModel.generateContent('weather in London');
506+
507+
expect(result.response.functionCalls()).toEqual([
508+
{ name: 'getWeather', args: { city: 'London' } },
509+
]);
510+
expect(getWeather).toHaveBeenCalledTimes(1);
511+
expect(makeRequestStub).toHaveBeenCalledTimes(2);
512+
makeRequestStub.mockRestore();
513+
});
514+
365515
it('passes base model params through to ChatSession when there are no startChatParams', async () => {
366516
const genModel = new GenerativeModel(fakeAI, {
367517
model: 'my-model',

packages/ai/lib/models/generative-model.ts

Lines changed: 104 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ import {
2020
Content,
2121
CountTokensRequest,
2222
CountTokensResponse,
23+
FunctionCall,
24+
FunctionDeclaration,
25+
FunctionResponse,
2326
GenerateContentRequest,
2427
GenerateContentResult,
2528
GenerateContentStreamResult,
@@ -40,6 +43,8 @@ import { mergeRequestOptions } from '../requests/request-options';
4043
import { AIModel } from './ai-model';
4144
import { AI } from '../public-types';
4245

46+
const DEFAULT_MAX_SEQUENTIAL_FUNCTION_CALLS = 10;
47+
4348
/**
4449
* Class for generative model APIs.
4550
* @public
@@ -71,19 +76,17 @@ export class GenerativeModel extends AIModel {
7176
singleRequestOptions?: SingleRequestOptions,
7277
): Promise<GenerateContentResult> {
7378
const formattedParams = formatGenerateContentInput(request);
74-
return generateContent(
75-
this._apiSettings,
76-
this.model,
77-
{
78-
generationConfig: this.generationConfig,
79-
safetySettings: this.safetySettings,
80-
tools: this.tools,
81-
toolConfig: this.toolConfig,
82-
systemInstruction: this.systemInstruction,
83-
...formattedParams,
84-
},
85-
mergeRequestOptions(this.requestOptions, singleRequestOptions),
86-
);
79+
const params: GenerateContentRequest = {
80+
generationConfig: this.generationConfig,
81+
safetySettings: this.safetySettings,
82+
tools: this.tools,
83+
toolConfig: this.toolConfig,
84+
systemInstruction: this.systemInstruction,
85+
...formattedParams,
86+
};
87+
const requestOptions = mergeRequestOptions(this.requestOptions, singleRequestOptions);
88+
const result = await generateContent(this._apiSettings, this.model, params, requestOptions);
89+
return this._generateContentWithAutomaticFunctionCalling(params, result, requestOptions);
8790
}
8891

8992
/**
@@ -152,4 +155,92 @@ export class GenerativeModel extends AIModel {
152155
mergeRequestOptions(this.requestOptions, singleRequestOptions),
153156
);
154157
}
158+
159+
private async _generateContentWithAutomaticFunctionCalling(
160+
params: GenerateContentRequest,
161+
result: GenerateContentResult,
162+
requestOptions?: SingleRequestOptions,
163+
): Promise<GenerateContentResult> {
164+
let remainingFunctionCalls =
165+
requestOptions?.maxSequentialFunctionCalls ?? DEFAULT_MAX_SEQUENTIAL_FUNCTION_CALLS;
166+
let currentParams = params;
167+
let currentResult = result;
168+
169+
while (remainingFunctionCalls > 0) {
170+
const functionCalls = currentResult.response.functionCalls?.();
171+
if (!functionCalls?.length) {
172+
return currentResult;
173+
}
174+
175+
const functionResponses = await this._callFunctionReferences(
176+
currentParams.tools,
177+
functionCalls,
178+
);
179+
if (!functionResponses) {
180+
return currentResult;
181+
}
182+
183+
const responseContent = currentResult.response.candidates?.[0]?.content;
184+
if (!responseContent) {
185+
return currentResult;
186+
}
187+
188+
remainingFunctionCalls -= 1;
189+
currentParams = {
190+
...currentParams,
191+
contents: [
192+
...currentParams.contents,
193+
responseContent,
194+
{
195+
role: 'function',
196+
parts: functionResponses.map(functionResponse => ({ functionResponse })),
197+
},
198+
],
199+
};
200+
currentResult = await generateContent(
201+
this._apiSettings,
202+
this.model,
203+
currentParams,
204+
requestOptions,
205+
);
206+
}
207+
208+
return currentResult;
209+
}
210+
211+
private async _callFunctionReferences(
212+
tools: Tool[] | undefined,
213+
functionCalls: FunctionCall[],
214+
): Promise<FunctionResponse[] | undefined> {
215+
const declarations = this._getFunctionDeclarationsWithReferences(tools);
216+
if (!declarations.length) {
217+
return undefined;
218+
}
219+
220+
const functionResponses: FunctionResponse[] = [];
221+
for (const functionCall of functionCalls) {
222+
const declaration = declarations.find(candidate => candidate.name === functionCall.name);
223+
if (!declaration?.functionReference) {
224+
return undefined;
225+
}
226+
227+
const response = (await declaration.functionReference(functionCall.args)) as object;
228+
functionResponses.push({
229+
id: functionCall.id,
230+
name: functionCall.name,
231+
response,
232+
});
233+
}
234+
return functionResponses;
235+
}
236+
237+
private _getFunctionDeclarationsWithReferences(tools: Tool[] | undefined): FunctionDeclaration[] {
238+
return (
239+
tools?.flatMap(tool =>
240+
'functionDeclarations' in tool
241+
? (tool.functionDeclarations?.filter(declaration => declaration.functionReference) ?? [])
242+
: [],
243+
) ?? []
244+
);
245+
}
155246
}

0 commit comments

Comments
 (0)