Skip to content

Commit ea3c8f4

Browse files
committed
feat: centralize maxAttempts configuration via ExperimentFlags
This commit centralizes the retry attempt limits to be driven by the `ExperimentFlags.MAX_ATTEMPTS` flag or the user configuration, rather than being hardcoded throughout the codebase. The retry logic in `baseLlmClient`, `geminiChat`, `client`, and `web-fetch` has been updated to retrieve the `maxAttempts` setting directly from `Config`. It also addresses the removal of the previous 10-attempt cap in the Config initialization to allow tests simulating high retry limits to pass successfully.
1 parent ef0f18e commit ea3c8f4

File tree

11 files changed

+189
-15
lines changed

11 files changed

+189
-15
lines changed

packages/core/src/code_assist/experiments/flagNames.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ export const ExperimentFlags = {
2020
PRO_MODEL_NO_ACCESS: 45768879,
2121
GEMINI_3_1_FLASH_LITE_LAUNCHED: 45771641,
2222
DEFAULT_REQUEST_TIMEOUT: 45773134,
23+
MAX_ATTEMPTS: 45774515,
2324
} as const;
2425

2526
export type ExperimentFlagName =

packages/core/src/config/config.test.ts

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,53 @@ describe('Server Config (config.ts)', () => {
305305
});
306306
expect(config.getMaxAttempts()).toBe(DEFAULT_MAX_ATTEMPTS);
307307
});
308+
309+
it('should use experiment flag if present and valid', () => {
310+
const config = new Config({
311+
...baseParams,
312+
experiments: {
313+
flags: {
314+
[ExperimentFlags.MAX_ATTEMPTS]: {
315+
intValue: '15',
316+
},
317+
},
318+
experimentIds: [],
319+
},
320+
});
321+
expect(config.getMaxAttempts()).toBe(15);
322+
});
323+
324+
it('should fallback to maxAttempts if experiment flag is invalid', () => {
325+
const config = new Config({
326+
...baseParams,
327+
maxAttempts: 5,
328+
experiments: {
329+
flags: {
330+
[ExperimentFlags.MAX_ATTEMPTS]: {
331+
intValue: 'abc',
332+
},
333+
},
334+
experimentIds: [],
335+
},
336+
});
337+
expect(config.getMaxAttempts()).toBe(5);
338+
});
339+
340+
it('should fallback to maxAttempts if experiment flag is non-positive', () => {
341+
const config = new Config({
342+
...baseParams,
343+
maxAttempts: 5,
344+
experiments: {
345+
flags: {
346+
[ExperimentFlags.MAX_ATTEMPTS]: {
347+
intValue: '0',
348+
},
349+
},
350+
experimentIds: [],
351+
},
352+
});
353+
expect(config.getMaxAttempts()).toBe(5);
354+
});
308355
});
309356

310357
beforeEach(() => {

packages/core/src/config/config.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3318,6 +3318,14 @@ export class Config implements McpContext, AgentLoopContext {
33183318
}
33193319

33203320
getMaxAttempts(): number {
3321+
const flagVal =
3322+
this.experiments?.flags?.[ExperimentFlags.MAX_ATTEMPTS]?.intValue;
3323+
if (flagVal !== undefined) {
3324+
const parsed = parseInt(flagVal, 10);
3325+
if (!isNaN(parsed) && parsed > 0) {
3326+
return parsed;
3327+
}
3328+
}
33213329
return this.maxAttempts;
33223330
}
33233331

packages/core/src/core/baseLlmClient.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ describe('BaseLlmClient', () => {
252252
expect(retryWithBackoff).toHaveBeenCalledWith(
253253
expect.any(Function),
254254
expect.objectContaining({
255-
maxAttempts: 5,
255+
maxAttempts: 3,
256256
}),
257257
);
258258
});

packages/core/src/core/baseLlmClient.ts

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ import {
3636
createAvailabilityContextProvider,
3737
} from '../availability/policyHelpers.js';
3838

39-
const DEFAULT_MAX_ATTEMPTS = 5;
40-
4139
/**
4240
* Options for the generateJson utility function.
4341
*/
@@ -328,7 +326,9 @@ export class BaseLlmClient {
328326
return await retryWithBackoff(apiCall, {
329327
shouldRetryOnContent,
330328
maxAttempts:
331-
availabilityMaxAttempts ?? maxAttempts ?? DEFAULT_MAX_ATTEMPTS,
329+
availabilityMaxAttempts ??
330+
maxAttempts ??
331+
this.config.getMaxAttempts(),
332332
getAvailabilityContext,
333333
onPersistent429: this.config.isInteractive()
334334
? (authType, error) =>
@@ -339,7 +339,9 @@ export class BaseLlmClient {
339339
retryFetchErrors: this.config.getRetryFetchErrors(),
340340
onRetry: (attempt, error, delayMs) => {
341341
const actualMaxAttempts =
342-
availabilityMaxAttempts ?? maxAttempts ?? DEFAULT_MAX_ATTEMPTS;
342+
availabilityMaxAttempts ??
343+
maxAttempts ??
344+
this.config.getMaxAttempts();
343345
const modelName = getDisplayString(currentModel);
344346
const errorType = getRetryErrorType(error);
345347

packages/core/src/core/client.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1133,7 +1133,7 @@ export class GeminiClient {
11331133
onPersistent429: onPersistent429Callback,
11341134
onValidationRequired: onValidationRequiredCallback,
11351135
authType: this.config.getContentGeneratorConfig()?.authType,
1136-
maxAttempts: availabilityMaxAttempts,
1136+
maxAttempts: availabilityMaxAttempts ?? this.config.getMaxAttempts(),
11371137
retryFetchErrors: this.config.getRetryFetchErrors(),
11381138
getAvailabilityContext,
11391139
onRetry: (attempt, error, delayMs) => {

packages/core/src/core/geminiChat.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ describe('GeminiChat', () => {
176176
},
177177
getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator),
178178
getRetryFetchErrors: vi.fn().mockReturnValue(false),
179-
getMaxAttempts: vi.fn().mockReturnValue(10),
179+
getMaxAttempts: vi.fn().mockReturnValue(4),
180180
getUserTier: vi.fn().mockReturnValue(undefined),
181181
modelConfigService: {
182182
getResolvedConfig: vi.fn().mockImplementation((modelConfigKey) => {

packages/core/src/core/geminiChat.ts

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,13 @@ export type StreamEvent =
7878
* Options for retrying mid-stream errors (e.g. invalid content or API disconnects).
7979
*/
8080
interface MidStreamRetryOptions {
81-
/** Total number of attempts to make (1 initial + N retries). */
82-
maxAttempts: number;
8381
/** The base delay in milliseconds for backoff. */
8482
initialDelayMs: number;
8583
/** Whether to use exponential backoff instead of linear. */
8684
useExponentialBackoff: boolean;
8785
}
8886

8987
const MID_STREAM_RETRY_OPTIONS: MidStreamRetryOptions = {
90-
maxAttempts: 4, // 1 initial call + 3 retries mid-stream
9188
initialDelayMs: 1000,
9289
useExponentialBackoff: true,
9390
};
@@ -420,10 +417,8 @@ export class GeminiChat {
420417
: getRetryErrorType(error);
421418

422419
if (isContentError || (isRetryable && !signal.aborted)) {
423-
// The issue requests exactly 3 retries (4 attempts) for API errors during stream iteration.
424-
// Regardless of the global maxAttempts (e.g. 10), we only want to retry these mid-stream API errors
425-
// up to 3 times before finally throwing the error to the user.
426-
const maxMidStreamAttempts = MID_STREAM_RETRY_OPTIONS.maxAttempts;
420+
// We retry mid-stream API errors up to maxAttempts times before finally throwing the error to the user.
421+
const maxMidStreamAttempts = this.context.config.getMaxAttempts();
427422

428423
if (
429424
attempt < maxAttempts - 1 &&

packages/core/src/tools/web-fetch.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ class WebFetchToolInvocation extends BaseToolInvocation<
309309
return res;
310310
},
311311
{
312+
maxAttempts: this.context.config.getMaxAttempts(),
312313
retryFetchErrors: this.context.config.getRetryFetchErrors(),
313314
onRetry: (attempt, error, delayMs) =>
314315
this.handleRetry(attempt, error, delayMs),
@@ -643,6 +644,7 @@ ${aggregatedContent}
643644
return res;
644645
},
645646
{
647+
maxAttempts: this.context.config.getMaxAttempts(),
646648
retryFetchErrors: this.context.config.getRetryFetchErrors(),
647649
onRetry: (attempt, error, delayMs) =>
648650
this.handleRetry(attempt, error, delayMs),

packages/core/src/utils/retry.test.ts

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
99
import { ApiError } from '@google/genai';
1010
import { AuthType } from '../core/contentGenerator.js';
1111
import { type HttpError, ModelNotFoundError } from './httpErrors.js';
12-
import { retryWithBackoff } from './retry.js';
12+
import { retryWithBackoff, isRetryableError } from './retry.js';
1313
import { setSimulate429 } from './testUtils.js';
1414
import { debugLogger } from './debugLogger.js';
1515
import {
1616
TerminalQuotaError,
1717
RetryableQuotaError,
18+
ValidationRequiredError,
1819
} from './googleQuotaErrors.js';
1920
import { PREVIEW_GEMINI_MODEL } from '../config/models.js';
2021
import type { ModelPolicy } from '../availability/modelPolicy.js';
@@ -332,6 +333,81 @@ describe('retryWithBackoff', () => {
332333
});
333334
});
334335

336+
it('should call onRetry callback on each retry', async () => {
337+
const mockFn = createFailingFunction(2);
338+
const onRetry = vi.fn();
339+
const promise = retryWithBackoff(mockFn, {
340+
maxAttempts: 3,
341+
initialDelayMs: 10,
342+
onRetry,
343+
});
344+
345+
await vi.runAllTimersAsync();
346+
347+
await promise;
348+
expect(onRetry).toHaveBeenCalledTimes(2);
349+
expect(onRetry).toHaveBeenCalledWith(
350+
1,
351+
expect.any(Error),
352+
expect.any(Number),
353+
);
354+
expect(onRetry).toHaveBeenCalledWith(
355+
2,
356+
expect.any(Error),
357+
expect.any(Number),
358+
);
359+
});
360+
361+
it('should handle ValidationRequiredError using onValidationRequired', async () => {
362+
const error = new ValidationRequiredError('Validation required', {} as any);
363+
let validationCalled = false;
364+
const mockFn = vi.fn().mockImplementation(async () => {
365+
if (!validationCalled) {
366+
throw error;
367+
}
368+
return 'success';
369+
});
370+
371+
const onValidationRequired = vi.fn().mockImplementation(async () => {
372+
validationCalled = true;
373+
return 'verify';
374+
});
375+
376+
const promise = retryWithBackoff(mockFn, {
377+
maxAttempts: 3,
378+
initialDelayMs: 10,
379+
onValidationRequired,
380+
});
381+
382+
await vi.runAllTimersAsync();
383+
384+
const result = await promise;
385+
expect(result).toBe('success');
386+
expect(onValidationRequired).toHaveBeenCalledWith(error);
387+
expect(mockFn).toHaveBeenCalledTimes(2);
388+
});
389+
390+
it('should throw ValidationRequiredError if onValidationRequired returns cancel', async () => {
391+
const error = new ValidationRequiredError('Validation required', {} as any);
392+
const mockFn = vi.fn().mockImplementation(async () => {
393+
throw error;
394+
});
395+
396+
const onValidationRequired = vi.fn().mockResolvedValue('cancel');
397+
398+
const promise = retryWithBackoff(mockFn, {
399+
maxAttempts: 3,
400+
initialDelayMs: 10,
401+
onValidationRequired,
402+
});
403+
404+
await expect(promise).rejects.toThrow('Validation required');
405+
await vi.runAllTimersAsync();
406+
407+
expect(error.userHandled).toBe(true);
408+
expect(mockFn).toHaveBeenCalledTimes(1);
409+
});
410+
335411
describe('Fetch error retries', () => {
336412
it("should retry on 'fetch failed' when retryFetchErrors is true", async () => {
337413
const mockFn = vi.fn();
@@ -886,3 +962,37 @@ describe('retryWithBackoff', () => {
886962
});
887963
});
888964
});
965+
966+
describe('isRetryableError', () => {
967+
it('should return true for 429 errors', () => {
968+
const error = new ApiError({ message: 'Quota exceeded', status: 429 });
969+
expect(isRetryableError(error)).toBe(true);
970+
});
971+
972+
it('should return true for 499 errors', () => {
973+
const error = new ApiError({
974+
message: 'Client closed request',
975+
status: 499,
976+
});
977+
expect(isRetryableError(error)).toBe(true);
978+
});
979+
980+
it('should return true for 500 errors', () => {
981+
const error = new ApiError({
982+
message: 'Internal Server Error',
983+
status: 500,
984+
});
985+
expect(isRetryableError(error)).toBe(true);
986+
});
987+
988+
it('should return false for 400 errors', () => {
989+
const error = new ApiError({ message: 'Bad Request', status: 400 });
990+
expect(isRetryableError(error)).toBe(false);
991+
});
992+
993+
it('should return true for network error codes like ECONNRESET', () => {
994+
const error = new Error('ECONNRESET');
995+
(error as any).code = 'ECONNRESET';
996+
expect(isRetryableError(error)).toBe(true);
997+
});
998+
});

0 commit comments

Comments
 (0)