Skip to content

Commit 670a247

Browse files
committed
feat(ai): support streaming chat function auto-calling
1 parent 0f5417c commit 670a247

3 files changed

Lines changed: 189 additions & 9 deletions

File tree

packages/ai/__tests__/chat-session.test.ts

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import { describe, expect, it, afterEach, jest } from '@jest/globals';
1818

1919
import * as generateContentMethods from '../lib/methods/generate-content';
20-
import { GenerateContentStreamResult } from '../lib/types';
20+
import { EnhancedGenerateContentResponse, GenerateContentStreamResult } from '../lib/types';
2121
import { ChatSession } from '../lib/methods/chat-session';
2222
import { ApiSettings } from '../lib/types/internal';
2323
import { RequestOptions } from '../lib/types/requests';
@@ -35,6 +35,15 @@ const requestOptions: RequestOptions = {
3535
timeout: 1000,
3636
};
3737

38+
function streamResult(response: EnhancedGenerateContentResponse): GenerateContentStreamResult {
39+
return {
40+
stream: (async function* () {
41+
yield response;
42+
})(),
43+
response: Promise.resolve(response),
44+
};
45+
}
46+
3847
describe('ChatSession', () => {
3948
afterEach(() => {
4049
jest.restoreAllMocks();
@@ -129,6 +138,102 @@ describe('ChatSession', () => {
129138
);
130139
});
131140

141+
it('automatically calls functionReference from stream function calls', async () => {
142+
const getWeather = jest.fn<(args: object) => object>().mockReturnValue({ temperature: 72 });
143+
const functionCallResponse = {
144+
candidates: [
145+
{
146+
index: 0,
147+
content: {
148+
role: 'model',
149+
parts: [
150+
{
151+
functionCall: {
152+
name: 'getWeather',
153+
args: { city: 'London' },
154+
},
155+
},
156+
],
157+
},
158+
},
159+
],
160+
functionCalls: () => [{ name: 'getWeather', args: { city: 'London' } }],
161+
} as EnhancedGenerateContentResponse;
162+
const finalResponse = {
163+
candidates: [
164+
{
165+
index: 0,
166+
content: {
167+
role: 'model',
168+
parts: [{ text: 'It is 72 degrees.' }],
169+
},
170+
},
171+
],
172+
functionCalls: () => undefined,
173+
} as EnhancedGenerateContentResponse;
174+
const generateContentStreamStub = jest
175+
.spyOn(generateContentMethods, 'generateContentStream')
176+
.mockResolvedValueOnce(streamResult(functionCallResponse))
177+
.mockResolvedValueOnce(streamResult(finalResponse));
178+
const chatSession = new ChatSession(
179+
fakeApiSettings,
180+
'a-model',
181+
{
182+
tools: [
183+
{
184+
functionDeclarations: [
185+
{
186+
name: 'getWeather',
187+
description: 'Gets weather for a city.',
188+
functionReference: getWeather,
189+
},
190+
],
191+
},
192+
],
193+
},
194+
requestOptions,
195+
);
196+
197+
const result = await chatSession.sendMessageStream('weather in London');
198+
await result.response;
199+
const history = await chatSession.getHistory();
200+
201+
expect(getWeather).toHaveBeenCalledWith({ city: 'London' });
202+
expect(generateContentStreamStub).toHaveBeenCalledTimes(2);
203+
expect(history).toEqual([
204+
{
205+
role: 'user',
206+
parts: [{ text: 'weather in London' }],
207+
},
208+
{
209+
role: 'model',
210+
parts: [
211+
{
212+
functionCall: {
213+
name: 'getWeather',
214+
args: { city: 'London' },
215+
},
216+
},
217+
],
218+
},
219+
{
220+
role: 'function',
221+
parts: [
222+
{
223+
functionResponse: {
224+
name: 'getWeather',
225+
response: { temperature: 72 },
226+
},
227+
},
228+
],
229+
},
230+
{
231+
role: 'model',
232+
parts: [{ text: 'It is 72 degrees.' }],
233+
},
234+
]);
235+
});
236+
132237
it('downstream sendPromise errors should log but not throw', async () => {
133238
const consoleStub = jest.spyOn(console, 'error').mockImplementation(() => {});
134239
// make response undefined so that response.candidates errors

packages/ai/lib/methods/automatic-function-calling.ts

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@ import {
2222
FunctionResponse,
2323
GenerateContentRequest,
2424
GenerateContentResult,
25+
GenerateContentStreamResult,
26+
EnhancedGenerateContentResponse,
2527
SingleRequestOptions,
2628
Tool,
2729
} from '../types';
2830
import { ApiSettings } from '../types/internal';
29-
import { generateContent } from './generate-content';
31+
import { generateContent, generateContentStream } from './generate-content';
3032

3133
const DEFAULT_MAX_SEQUENTIAL_FUNCTION_CALLS = 10;
3234

@@ -35,6 +37,11 @@ export interface AutomaticFunctionCallingResult {
3537
addedContents: Content[];
3638
}
3739

40+
export interface AutomaticFunctionCallingStreamResult {
41+
result: GenerateContentStreamResult;
42+
addedContents: Content[];
43+
}
44+
3845
export async function generateContentWithAutomaticFunctionCalling(
3946
apiSettings: ApiSettings,
4047
model: string,
@@ -80,8 +87,61 @@ export async function generateContentWithAutomaticFunctionCalling(
8087
return { result: currentResult, addedContents };
8188
}
8289

83-
function getModelResponseContent(result: GenerateContentResult): Content | undefined {
84-
const responseContent = result.response.candidates?.[0]?.content;
90+
export async function generateContentStreamWithAutomaticFunctionCalling(
91+
apiSettings: ApiSettings,
92+
model: string,
93+
params: GenerateContentRequest,
94+
result: GenerateContentStreamResult,
95+
requestOptions?: SingleRequestOptions,
96+
): Promise<AutomaticFunctionCallingStreamResult> {
97+
if (!getFunctionDeclarationsWithReferences(params.tools).length) {
98+
return { result, addedContents: [] };
99+
}
100+
101+
let remainingFunctionCalls =
102+
requestOptions?.maxSequentialFunctionCalls ?? DEFAULT_MAX_SEQUENTIAL_FUNCTION_CALLS;
103+
let currentParams = params;
104+
let currentResult = result;
105+
const addedContents: Content[] = [];
106+
107+
while (remainingFunctionCalls > 0) {
108+
const response = await currentResult.response;
109+
const functionCalls = response.functionCalls?.();
110+
if (!functionCalls?.length) {
111+
return { result: currentResult, addedContents };
112+
}
113+
114+
const functionResponses = await callFunctionReferences(currentParams.tools, functionCalls);
115+
if (!functionResponses) {
116+
return { result: currentResult, addedContents };
117+
}
118+
119+
const responseContent = getModelResponseContent(response);
120+
if (!responseContent) {
121+
return { result: currentResult, addedContents };
122+
}
123+
124+
remainingFunctionCalls -= 1;
125+
const functionResponseContent: Content = {
126+
role: 'function',
127+
parts: functionResponses.map(functionResponse => ({ functionResponse })),
128+
};
129+
addedContents.push(responseContent, functionResponseContent);
130+
currentParams = {
131+
...currentParams,
132+
contents: [...currentParams.contents, responseContent, functionResponseContent],
133+
};
134+
currentResult = await generateContentStream(apiSettings, model, currentParams, requestOptions);
135+
}
136+
137+
return { result: currentResult, addedContents };
138+
}
139+
140+
function getModelResponseContent(
141+
responseOrResult: GenerateContentResult | EnhancedGenerateContentResponse,
142+
): Content | undefined {
143+
const response = 'response' in responseOrResult ? responseOrResult.response : responseOrResult;
144+
const responseContent = response.candidates?.[0]?.content;
85145
if (!responseContent) {
86146
return undefined;
87147
}

packages/ai/lib/methods/chat-session.ts

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@ import {
3636
templateGenerateContent,
3737
templateGenerateContentStream,
3838
} from './generate-content';
39-
import { generateContentWithAutomaticFunctionCalling } from './automatic-function-calling';
39+
import {
40+
generateContentStreamWithAutomaticFunctionCalling,
41+
generateContentWithAutomaticFunctionCalling,
42+
} from './automatic-function-calling';
4043
import { ApiSettings } from '../types/internal';
4144
import { logger } from '../logger';
4245
import { mergeRequestOptions } from '../requests/request-options';
@@ -176,11 +179,20 @@ export class ChatSession extends ChatSessionBase<StartChatParams> {
176179
systemInstruction: this.params?.systemInstruction,
177180
contents: [...this._history, newContent],
178181
};
182+
const requestOptions = mergeRequestOptions(this.requestOptions, singleRequestOptions);
179183
const streamPromise = generateContentStream(
180184
this._apiSettings,
181185
this.model,
182186
generateContentRequest,
183-
mergeRequestOptions(this.requestOptions, singleRequestOptions),
187+
requestOptions,
188+
).then(result =>
189+
generateContentStreamWithAutomaticFunctionCalling(
190+
this._apiSettings,
191+
this.model,
192+
generateContentRequest,
193+
result,
194+
requestOptions,
195+
),
184196
);
185197

186198
// Add onto the chain.
@@ -191,10 +203,13 @@ export class ChatSession extends ChatSessionBase<StartChatParams> {
191203
.catch(_ignored => {
192204
throw new Error(SILENT_ERROR);
193205
})
194-
.then(streamResult => streamResult.response)
195-
.then((response: EnhancedGenerateContentResponse) => {
206+
.then(({ result, addedContents }) =>
207+
result.response.then(response => ({ response, addedContents })),
208+
)
209+
.then(({ response, addedContents }) => {
196210
if (response.candidates && response.candidates.length > 0) {
197211
this._history.push(newContent);
212+
this._history.push(...addedContents);
198213
const responseContent = { ...response.candidates[0]?.content };
199214
// Response seems to come back without a role set.
200215
if (!responseContent.role) {
@@ -220,7 +235,7 @@ export class ChatSession extends ChatSessionBase<StartChatParams> {
220235
logger.error(e);
221236
}
222237
});
223-
return streamPromise;
238+
return (await streamPromise).result;
224239
}
225240
}
226241

0 commit comments

Comments
 (0)