@@ -9,12 +9,13 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
99import { ApiError } from '@google/genai' ;
1010import { AuthType } from '../core/contentGenerator.js' ;
1111import { type HttpError , ModelNotFoundError } from './httpErrors.js' ;
12- import { retryWithBackoff } from './retry.js' ;
12+ import { retryWithBackoff , isRetryableError } from './retry.js' ;
1313import { setSimulate429 } from './testUtils.js' ;
1414import { debugLogger } from './debugLogger.js' ;
1515import {
1616 TerminalQuotaError ,
1717 RetryableQuotaError ,
18+ ValidationRequiredError ,
1819} from './googleQuotaErrors.js' ;
1920import { PREVIEW_GEMINI_MODEL } from '../config/models.js' ;
2021import type { ModelPolicy } from '../availability/modelPolicy.js' ;
@@ -332,6 +333,81 @@ describe('retryWithBackoff', () => {
332333 } ) ;
333334 } ) ;
334335
336+ it ( 'should call onRetry callback on each retry' , async ( ) => {
337+ const mockFn = createFailingFunction ( 2 ) ;
338+ const onRetry = vi . fn ( ) ;
339+ const promise = retryWithBackoff ( mockFn , {
340+ maxAttempts : 3 ,
341+ initialDelayMs : 10 ,
342+ onRetry,
343+ } ) ;
344+
345+ await vi . runAllTimersAsync ( ) ;
346+
347+ await promise ;
348+ expect ( onRetry ) . toHaveBeenCalledTimes ( 2 ) ;
349+ expect ( onRetry ) . toHaveBeenCalledWith (
350+ 1 ,
351+ expect . any ( Error ) ,
352+ expect . any ( Number ) ,
353+ ) ;
354+ expect ( onRetry ) . toHaveBeenCalledWith (
355+ 2 ,
356+ expect . any ( Error ) ,
357+ expect . any ( Number ) ,
358+ ) ;
359+ } ) ;
360+
361+ it ( 'should handle ValidationRequiredError using onValidationRequired' , async ( ) => {
362+ const error = new ValidationRequiredError ( 'Validation required' , { } as any ) ;
363+ let validationCalled = false ;
364+ const mockFn = vi . fn ( ) . mockImplementation ( async ( ) => {
365+ if ( ! validationCalled ) {
366+ throw error ;
367+ }
368+ return 'success' ;
369+ } ) ;
370+
371+ const onValidationRequired = vi . fn ( ) . mockImplementation ( async ( ) => {
372+ validationCalled = true ;
373+ return 'verify' ;
374+ } ) ;
375+
376+ const promise = retryWithBackoff ( mockFn , {
377+ maxAttempts : 3 ,
378+ initialDelayMs : 10 ,
379+ onValidationRequired,
380+ } ) ;
381+
382+ await vi . runAllTimersAsync ( ) ;
383+
384+ const result = await promise ;
385+ expect ( result ) . toBe ( 'success' ) ;
386+ expect ( onValidationRequired ) . toHaveBeenCalledWith ( error ) ;
387+ expect ( mockFn ) . toHaveBeenCalledTimes ( 2 ) ;
388+ } ) ;
389+
390+ it ( 'should throw ValidationRequiredError if onValidationRequired returns cancel' , async ( ) => {
391+ const error = new ValidationRequiredError ( 'Validation required' , { } as any ) ;
392+ const mockFn = vi . fn ( ) . mockImplementation ( async ( ) => {
393+ throw error ;
394+ } ) ;
395+
396+ const onValidationRequired = vi . fn ( ) . mockResolvedValue ( 'cancel' ) ;
397+
398+ const promise = retryWithBackoff ( mockFn , {
399+ maxAttempts : 3 ,
400+ initialDelayMs : 10 ,
401+ onValidationRequired,
402+ } ) ;
403+
404+ await expect ( promise ) . rejects . toThrow ( 'Validation required' ) ;
405+ await vi . runAllTimersAsync ( ) ;
406+
407+ expect ( error . userHandled ) . toBe ( true ) ;
408+ expect ( mockFn ) . toHaveBeenCalledTimes ( 1 ) ;
409+ } ) ;
410+
335411 describe ( 'Fetch error retries' , ( ) => {
336412 it ( "should retry on 'fetch failed' when retryFetchErrors is true" , async ( ) => {
337413 const mockFn = vi . fn ( ) ;
@@ -886,3 +962,37 @@ describe('retryWithBackoff', () => {
886962 } ) ;
887963 } ) ;
888964} ) ;
965+
966+ describe ( 'isRetryableError' , ( ) => {
967+ it ( 'should return true for 429 errors' , ( ) => {
968+ const error = new ApiError ( { message : 'Quota exceeded' , status : 429 } ) ;
969+ expect ( isRetryableError ( error ) ) . toBe ( true ) ;
970+ } ) ;
971+
972+ it ( 'should return true for 499 errors' , ( ) => {
973+ const error = new ApiError ( {
974+ message : 'Client closed request' ,
975+ status : 499 ,
976+ } ) ;
977+ expect ( isRetryableError ( error ) ) . toBe ( true ) ;
978+ } ) ;
979+
980+ it ( 'should return true for 500 errors' , ( ) => {
981+ const error = new ApiError ( {
982+ message : 'Internal Server Error' ,
983+ status : 500 ,
984+ } ) ;
985+ expect ( isRetryableError ( error ) ) . toBe ( true ) ;
986+ } ) ;
987+
988+ it ( 'should return false for 400 errors' , ( ) => {
989+ const error = new ApiError ( { message : 'Bad Request' , status : 400 } ) ;
990+ expect ( isRetryableError ( error ) ) . toBe ( false ) ;
991+ } ) ;
992+
993+ it ( 'should return true for network error codes like ECONNRESET' , ( ) => {
994+ const error = new Error ( 'ECONNRESET' ) ;
995+ ( error as any ) . code = 'ECONNRESET' ;
996+ expect ( isRetryableError ( error ) ) . toBe ( true ) ;
997+ } ) ;
998+ } ) ;
0 commit comments