Skip to content

Commit fefd27b

Browse files
skeshivegemini-cli-robot
authored andcommitted
feat(acp): Add support for AI Gateway auth (#21305)
1 parent b25c813 commit fefd27b

5 files changed

Lines changed: 140 additions & 7 deletions

File tree

packages/cli/src/acp/acpClient.test.ts

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,16 @@ describe('GeminiAgent', () => {
208208
});
209209

210210
expect(response.protocolVersion).toBe(acp.PROTOCOL_VERSION);
211-
expect(response.authMethods).toHaveLength(3);
211+
expect(response.authMethods).toHaveLength(4);
212+
const gatewayAuth = response.authMethods?.find(
213+
(m) => m.id === AuthType.GATEWAY,
214+
);
215+
expect(gatewayAuth?._meta).toEqual({
216+
gateway: {
217+
protocol: 'google',
218+
restartRequired: 'false',
219+
},
220+
});
212221
const geminiAuth = response.authMethods?.find(
213222
(m) => m.id === AuthType.USE_GEMINI,
214223
);
@@ -228,6 +237,8 @@ describe('GeminiAgent', () => {
228237
expect(mockConfig.refreshAuth).toHaveBeenCalledWith(
229238
AuthType.LOGIN_WITH_GOOGLE,
230239
undefined,
240+
undefined,
241+
undefined,
231242
);
232243
expect(mockSettings.setValue).toHaveBeenCalledWith(
233244
SettingScope.User,
@@ -247,6 +258,8 @@ describe('GeminiAgent', () => {
247258
expect(mockConfig.refreshAuth).toHaveBeenCalledWith(
248259
AuthType.USE_GEMINI,
249260
'test-api-key',
261+
undefined,
262+
undefined,
250263
);
251264
expect(mockSettings.setValue).toHaveBeenCalledWith(
252265
SettingScope.User,
@@ -255,6 +268,45 @@ describe('GeminiAgent', () => {
255268
);
256269
});
257270

271+
it('should authenticate correctly with gateway method', async () => {
272+
await agent.authenticate({
273+
methodId: AuthType.GATEWAY,
274+
_meta: {
275+
gateway: {
276+
baseUrl: 'https://example.com',
277+
headers: { Authorization: 'Bearer token' },
278+
},
279+
},
280+
} as unknown as acp.AuthenticateRequest);
281+
282+
expect(mockConfig.refreshAuth).toHaveBeenCalledWith(
283+
AuthType.GATEWAY,
284+
undefined,
285+
'https://example.com',
286+
{ Authorization: 'Bearer token' },
287+
);
288+
expect(mockSettings.setValue).toHaveBeenCalledWith(
289+
SettingScope.User,
290+
'security.auth.selectedType',
291+
AuthType.GATEWAY,
292+
);
293+
});
294+
295+
it('should throw acp.RequestError when gateway payload is malformed', async () => {
296+
await expect(
297+
agent.authenticate({
298+
methodId: AuthType.GATEWAY,
299+
_meta: {
300+
gateway: {
301+
// Invalid baseUrl
302+
baseUrl: 123,
303+
headers: { Authorization: 'Bearer token' },
304+
},
305+
},
306+
} as unknown as acp.AuthenticateRequest),
307+
).rejects.toThrow(/Malformed gateway payload/);
308+
});
309+
258310
it('should create a new session', async () => {
259311
vi.useFakeTimers();
260312
mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({

packages/cli/src/acp/acpClient.ts

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ export class GeminiAgent {
9898
private sessions: Map<string, Session> = new Map();
9999
private clientCapabilities: acp.ClientCapabilities | undefined;
100100
private apiKey: string | undefined;
101+
private baseUrl: string | undefined;
102+
private customHeaders: Record<string, string> | undefined;
101103

102104
constructor(
103105
private config: Config,
@@ -131,6 +133,17 @@ export class GeminiAgent {
131133
name: 'Vertex AI',
132134
description: 'Use an API key with Vertex AI GenAI API',
133135
},
136+
{
137+
id: AuthType.GATEWAY,
138+
name: 'AI API Gateway',
139+
description: 'Use a custom AI API Gateway',
140+
_meta: {
141+
gateway: {
142+
protocol: 'google',
143+
restartRequired: 'false',
144+
},
145+
},
146+
},
134147
];
135148

136149
await this.config.initialize();
@@ -179,7 +192,38 @@ export class GeminiAgent {
179192
if (apiKey) {
180193
this.apiKey = apiKey;
181194
}
182-
await this.config.refreshAuth(method, apiKey ?? this.apiKey);
195+
196+
// Extract gateway details if present
197+
const gatewaySchema = z.object({
198+
baseUrl: z.string().optional(),
199+
headers: z.record(z.string()).optional(),
200+
});
201+
202+
let baseUrl: string | undefined;
203+
let headers: Record<string, string> | undefined;
204+
205+
if (meta?.['gateway']) {
206+
const result = gatewaySchema.safeParse(meta['gateway']);
207+
if (result.success) {
208+
baseUrl = result.data.baseUrl;
209+
headers = result.data.headers;
210+
} else {
211+
throw new acp.RequestError(
212+
-32602,
213+
`Malformed gateway payload: ${result.error.message}`,
214+
);
215+
}
216+
}
217+
218+
this.baseUrl = baseUrl;
219+
this.customHeaders = headers;
220+
221+
await this.config.refreshAuth(
222+
method,
223+
apiKey ?? this.apiKey,
224+
baseUrl,
225+
headers,
226+
);
183227
} catch (e) {
184228
throw new acp.RequestError(-32000, getAcpErrorMessage(e));
185229
}
@@ -209,7 +253,12 @@ export class GeminiAgent {
209253
let isAuthenticated = false;
210254
let authErrorMessage = '';
211255
try {
212-
await config.refreshAuth(authType, this.apiKey);
256+
await config.refreshAuth(
257+
authType,
258+
this.apiKey,
259+
this.baseUrl,
260+
this.customHeaders,
261+
);
213262
isAuthenticated = true;
214263

215264
// Extra validation for Gemini API key
@@ -371,7 +420,12 @@ export class GeminiAgent {
371420
// This satisfies the security requirement to verify the user before executing
372421
// potentially unsafe server definitions.
373422
try {
374-
await config.refreshAuth(selectedAuthType, this.apiKey);
423+
await config.refreshAuth(
424+
selectedAuthType,
425+
this.apiKey,
426+
this.baseUrl,
427+
this.customHeaders,
428+
);
375429
} catch (e) {
376430
debugLogger.error(`Authentication failed: ${e}`);
377431
throw acp.RequestError.authRequired();

packages/core/src/config/config.test.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,8 @@ describe('Server Config (config.ts)', () => {
500500
config,
501501
authType,
502502
undefined,
503+
undefined,
504+
undefined,
503505
);
504506
// Verify that contentGeneratorConfig is updated
505507
expect(config.getContentGeneratorConfig()).toEqual(mockContentConfig);

packages/core/src/config/config.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1190,7 +1190,12 @@ export class Config implements McpContext {
11901190
return this.contentGenerator;
11911191
}
11921192

1193-
async refreshAuth(authMethod: AuthType, apiKey?: string) {
1193+
async refreshAuth(
1194+
authMethod: AuthType,
1195+
apiKey?: string,
1196+
baseUrl?: string,
1197+
customHeaders?: Record<string, string>,
1198+
) {
11941199
// Reset availability service when switching auth
11951200
this.modelAvailabilityService.reset();
11961201

@@ -1217,6 +1222,8 @@ export class Config implements McpContext {
12171222
this,
12181223
authMethod,
12191224
apiKey,
1225+
baseUrl,
1226+
customHeaders,
12201227
);
12211228
this.contentGenerator = await createContentGenerator(
12221229
newContentGeneratorConfig,

packages/core/src/core/contentGenerator.ts

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ export enum AuthType {
5959
USE_VERTEX_AI = 'vertex-ai',
6060
LEGACY_CLOUD_SHELL = 'cloud-shell',
6161
COMPUTE_ADC = 'compute-default-credentials',
62+
GATEWAY = 'gateway',
6263
}
6364

6465
/**
@@ -93,12 +94,16 @@ export type ContentGeneratorConfig = {
9394
vertexai?: boolean;
9495
authType?: AuthType;
9596
proxy?: string;
97+
baseUrl?: string;
98+
customHeaders?: Record<string, string>;
9699
};
97100

98101
export async function createContentGeneratorConfig(
99102
config: Config,
100103
authType: AuthType | undefined,
101104
apiKey?: string,
105+
baseUrl?: string,
106+
customHeaders?: Record<string, string>,
102107
): Promise<ContentGeneratorConfig> {
103108
const geminiApiKey =
104109
apiKey ||
@@ -115,6 +120,8 @@ export async function createContentGeneratorConfig(
115120
const contentGeneratorConfig: ContentGeneratorConfig = {
116121
authType,
117122
proxy: config?.getProxy(),
123+
baseUrl,
124+
customHeaders,
118125
};
119126

120127
// If we are using Google auth or we are in Cloud Shell, there is nothing else to validate for now
@@ -203,9 +210,13 @@ export async function createContentGenerator(
203210

204211
if (
205212
config.authType === AuthType.USE_GEMINI ||
206-
config.authType === AuthType.USE_VERTEX_AI
213+
config.authType === AuthType.USE_VERTEX_AI ||
214+
config.authType === AuthType.GATEWAY
207215
) {
208216
let headers: Record<string, string> = { ...baseHeaders };
217+
if (config.customHeaders) {
218+
headers = { ...headers, ...config.customHeaders };
219+
}
209220
if (gcConfig?.getUsageStatisticsEnabled()) {
210221
const installationManager = new InstallationManager();
211222
const installationId = installationManager.getInstallationId();
@@ -214,7 +225,14 @@ export async function createContentGenerator(
214225
'x-gemini-api-privileged-user-id': `${installationId}`,
215226
};
216227
}
217-
const httpOptions = { headers };
228+
const httpOptions: {
229+
baseUrl?: string;
230+
headers: Record<string, string>;
231+
} = { headers };
232+
233+
if (config.baseUrl) {
234+
httpOptions.baseUrl = config.baseUrl;
235+
}
218236

219237
const googleGenAI = new GoogleGenAI({
220238
apiKey: config.apiKey === '' ? undefined : config.apiKey,

0 commit comments

Comments
 (0)