Skip to content

Commit f1d9fda

Browse files
author
benbot
committed
Adds support for the gemini-pro model
1 parent 8faec39 commit f1d9fda

3 files changed

Lines changed: 79 additions & 18 deletions

File tree

packages/plugins/googleai/server/src/constants.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ export const GOOGLEAI_ENDPOINT =
88
process.env['NEXT_GOOGLEAI_ENDPOINT'] ||
99
process.env['REACT_APP_GOOGLEAI_ENDPOINT'] ||
1010
process.env['GOOGLEAI_ENDPOINT'] ||
11-
'https://generativelanguage.googleapis.com/v1beta2/models'
11+
'https://generativelanguage.googleapis.com/v1beta/models'

packages/plugins/googleai/server/src/functions/makeChatCompletion.ts

Lines changed: 77 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,52 @@ type ChatMessage = {
1010
content: string
1111
}
1212

13-
/**
14-
* Generate a completion text based on prior chat conversation input.
15-
* @param data - CompletionHandlerInputData object.
16-
* @returns An object with success status and either a result or an error message.
17-
*/
18-
export async function makeChatCompletion(
19-
data: CompletionHandlerInputData
20-
): Promise<{
21-
success: boolean
22-
result?: string | null
23-
error?: string | null
24-
}> {
13+
interface ContentMessage {
14+
contents: Content[]
15+
}
16+
interface Content {
17+
parts: { text: string }[]
18+
role: string
19+
}
20+
21+
async function makeGeminiChatCompletion(data: CompletionHandlerInputData): Promise<any> {
2522
const { node, inputs, context } = data
2623

2724
// Get the system message and conversation inputs
2825
const system = inputs['system']?.[0] as ChatMessage
2926
const conversation = inputs['conversation']?.[0] as any
3027

28+
// Initialize conversationMessages array
29+
const conversationMessages: Content[] = []
30+
31+
// Add elements to conversationMessages
32+
conversation?.forEach(event => {
33+
conversationMessages.push({ role: event.observer === event.sender ? 'model' : 'user', parts: [{ text: event.content }] })
34+
});
35+
// Get the user input
36+
const input = inputs['input']?.[0] as string
37+
38+
conversationMessages.push({ role: 'user', parts: [{text: input}]})
39+
40+
const examples = (inputs['examples']?.[0] as string[]) || []
41+
42+
// Get or set default settings
43+
const request = {
44+
contents: conversationMessages,
45+
generationConfig: {
46+
temperature: parseFloat((node?.data?.temperature as string) ?? '0.0'),
47+
top_p: parseFloat((node?.data?.top_p as string) ?? '0.95'),
48+
top_k: parseFloat((node?.data?.top_k as string) ?? '40'),
49+
}
50+
} as any
51+
}
52+
53+
async function makePalmChatCompletion(data: CompletionHandlerInputData): Promise<any> {
54+
const { node, inputs, context } = data
55+
// Get the system message and conversation inputs
56+
const system = inputs['system']?.[0] as ChatMessage
57+
const conversation = inputs['conversation']?.[0] as any
58+
3159
// Initialize conversationMessages array
3260
const conversationMessages: ChatMessage[] = []
3361

@@ -56,6 +84,30 @@ export async function makeChatCompletion(
5684
top_k: parseFloat((node?.data?.top_k as string) ?? '40'),
5785
} as any
5886

87+
return settings;
88+
}
89+
90+
/**
91+
* Generate a completion text based on prior chat conversation input.
92+
* @param data - CompletionHandlerInputData object.
93+
* @returns An object with success status and either a result or an error message.
94+
*/
95+
export async function makeChatCompletion(
96+
data: CompletionHandlerInputData
97+
): Promise<{
98+
success: boolean
99+
result?: string | null
100+
error?: string | null
101+
}> {
102+
const { node, inputs, context } = data
103+
104+
let requestData: any = null;
105+
if ((node?.data?.model as string).includes('gemini')) {
106+
requestData = await makeGeminiChatCompletion(data);
107+
} else {
108+
requestData = await makePalmChatCompletion(data);
109+
}
110+
59111
const apiKey =
60112
(context?.module?.secrets &&
61113
context?.module?.secrets['googleai_api_key']) ||
@@ -70,18 +122,27 @@ export async function makeChatCompletion(
70122
}
71123

72124
try {
125+
126+
let commandType = 'generateMessage';
127+
128+
// palm generates "messages" but gemini generates "content"
129+
if ((node?.data?.model as string).includes('gemini')) {
130+
commandType = 'generateContent';
131+
}
132+
73133
const start = Date.now()
74-
const endpoint = `${GOOGLEAI_ENDPOINT}/${node?.data?.model}:generateMessage?key=${apiKey}`
134+
const endpoint = `${GOOGLEAI_ENDPOINT}/${node?.data?.model}:${commandType}?key=${apiKey}`
75135
// Make the API call to GoogleAI
76136
const completion = await fetch(endpoint, {
77137
method: 'POST',
78138
headers: {
79139
'Content-Type': 'application/json',
80140
},
81-
body: JSON.stringify(settings),
141+
body: JSON.stringify(requestData),
82142
})
83143

84144
const completionData = await completion.json()
145+
console.log(completionData)
85146

86147
if (completionData.error) {
87148
console.error('GoogleAI Error', completionData.error)
@@ -97,13 +158,13 @@ export async function makeChatCompletion(
97158
saveRequest({
98159
projectId: context.projectId,
99160
agentId: context.agent?.id || 'preview',
100-
requestData: JSON.stringify(settings),
161+
requestData: JSON.stringify(requestData),
101162
responseData: JSON.stringify(completionData),
102163
startTime: start,
103164
statusCode: completion.status,
104165
status: completion.statusText,
105166
model: node?.data?.model as string,
106-
parameters: JSON.stringify(settings),
167+
parameters: JSON.stringify(requestData),
107168
type: 'completion',
108169
provider: 'googleai',
109170
totalTokens: 0, // usage.total_tokens,

packages/plugins/googleai/shared/src/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ const completionProviders: CompletionProvider[] = [
6060
type: stringSocket,
6161
},
6262
],
63-
models: ['chat-bison-001'],
63+
models: ['chat-bison-001', 'gemini-pro'],
6464
},
6565
{
6666
type: 'text',

0 commit comments

Comments (0)