diff --git a/langextract/providers/gemini.py b/langextract/providers/gemini.py
index a82afe1e..97b6bb46 100644
--- a/langextract/providers/gemini.py
+++ b/langextract/providers/gemini.py
@@ -19,6 +19,8 @@
 
 import concurrent.futures
 import dataclasses
+import random
+import time
 from typing import Any, Final, Iterator, Sequence
 
 from absl import logging
@@ -67,6 +69,7 @@ class GeminiLanguageModel(base_model.BaseLanguageModel):  # pylint: disable=too-
   format_type: data.FormatType = data.FormatType.JSON
   temperature: float = 0.0
   max_workers: int = 10
+  max_retries: int = 5
   fence_output: bool = False
   _extra_kwargs: dict[str, Any] = dataclasses.field(
       default_factory=dict, repr=False, compare=False
@@ -104,6 +107,7 @@ def __init__(
       format_type: data.FormatType = data.FormatType.JSON,
       temperature: float = 0.0,
       max_workers: int = 10,
+      max_retries: int = 5,
       fence_output: bool = False,
       **kwargs,
   ) -> None:
@@ -121,6 +125,8 @@ def __init__(
       format_type: Output format (JSON or YAML).
       temperature: Sampling temperature.
       max_workers: Maximum number of parallel API calls.
+      max_retries: Maximum number of retries for rate limit (429) errors.
+        Retries use exponential backoff; negative values mean no retries.
       fence_output: Whether to wrap output in markdown fences (ignored,
         Gemini handles this based on schema).
       **kwargs: Additional Gemini API parameters. Only allowlisted keys are
@@ -148,6 +154,7 @@ def __init__(
     self.format_type = format_type
     self.temperature = temperature
     self.max_workers = max_workers
+    self.max_retries = max_retries
     self.fence_output = fence_output
 
     # Extract batch config before we filter kwargs into _extra_kwargs
@@ -214,15 +221,47 @@ def _process_single_prompt(
         config.setdefault('response_mime_type', 'application/json')
         config.setdefault('response_schema', self.gemini_schema.schema_dict)
 
-      response = self._client.models.generate_content(
-          model=self.model_id, contents=prompt, config=config
-      )
+      base_delay = 1.0  # seconds
+      max_delay = 120.0  # seconds
+      # A negative retry budget still gets exactly one attempt.
+      max_retries = max(0, self.max_retries)
+      attempt = 0
 
-      return core_types.ScoredOutput(score=1.0, output=response.text)
+      while True:
+        try:
+          response = self._client.models.generate_content(
+              model=self.model_id, contents=prompt, config=config
+          )
+          return core_types.ScoredOutput(score=1.0, output=response.text)
+        except Exception as e:
+          message = str(e)
+          # Prefer the structured fields on google.genai API errors; fall
+          # back to the "<code> <status>" message prefix for other types.
+          is_rate_limit = (
+              getattr(e, 'code', None) == 429
+              or getattr(e, 'status', None) == 'RESOURCE_EXHAUSTED'
+              or message.startswith('429')
+              or 'RESOURCE_EXHAUSTED' in message
+          )
+          if not is_rate_limit or attempt >= max_retries:
+            # Let the outer handler wrap it as InferenceRuntimeError.
+            raise
+
+        delay = min(max_delay, base_delay * (2**attempt))
+        sleep_time = delay + random.uniform(0, 0.1 * delay)
+        attempt += 1
+        logging.warning(
+            'Gemini API rate limit hit (429). Retrying in %.2fs'
+            ' (attempt %d/%d)',
+            sleep_time,
+            attempt,
+            max_retries,
+        )
+        time.sleep(sleep_time)
 
     except Exception as e:
       raise exceptions.InferenceRuntimeError(
           f'Gemini API error: {str(e)}', original=e
       ) from e
 
   def infer(
diff --git a/tests/test_gemini_backoff.py b/tests/test_gemini_backoff.py
new file mode 100644
index 00000000..4529b5df
--- /dev/null
+++ b/tests/test_gemini_backoff.py
@@ -0,0 +1,99 @@
+# Copyright 2025 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for Gemini provider exponential backoff."""
+
+from unittest import mock
+
+from absl.testing import absltest
+
+from langextract.core import exceptions
+from langextract.providers import gemini
+
+
+class TestGeminiBackoff(absltest.TestCase):
+
+  @mock.patch("google.genai.Client")
+  @mock.patch("time.sleep")  # Mock sleep to speed up tests
+  def test_gemini_retry_on_429(self, mock_sleep, mock_client_class):
+    """Test that Gemini retries on 429 errors and eventually succeeds."""
+    generate = mock_client_class.return_value.models.generate_content
+    generate.side_effect = [
+        Exception("429 RESOURCE_EXHAUSTED"),
+        mock.Mock(text='{"result": "success"}'),
+    ]
+    model = gemini.GeminiLanguageModel(api_key="test-key", max_retries=3)
+
+    results = list(model.infer(["Test prompt"]))
+
+    self.assertLen(results, 1)
+    self.assertEqual(results[0][0].output, '{"result": "success"}')
+    self.assertEqual(generate.call_count, 2)
+    mock_sleep.assert_called_once()
+    # The first backoff is 1s plus at most 10% jitter.
+    self.assertBetween(mock_sleep.call_args.args[0], 1.0, 1.1)
+
+  @mock.patch("google.genai.Client")
+  @mock.patch("time.sleep")
+  def test_gemini_max_retries_exceeded(self, mock_sleep, mock_client_class):
+    """Test that Gemini fails after exceeding max retries."""
+    generate = mock_client_class.return_value.models.generate_content
+    generate.side_effect = Exception("429 RESOURCE_EXHAUSTED")
+    model = gemini.GeminiLanguageModel(api_key="test-key", max_retries=2)
+
+    with self.assertRaises(exceptions.InferenceRuntimeError) as cm:
+      list(model.infer(["Test prompt"]))
+
+    self.assertIn("Gemini API error", str(cm.exception))
+    self.assertIn("429", str(cm.exception))
+    # 1 initial call + 2 retries = 3 calls
+    self.assertEqual(generate.call_count, 3)
+    # Delays double on each retry (1s, 2s), each with at most 10% jitter.
+    delays = [call.args[0] for call in mock_sleep.call_args_list]
+    self.assertLen(delays, 2)
+    self.assertBetween(delays[0], 1.0, 1.1)
+    self.assertBetween(delays[1], 2.0, 2.2)
+
+  @mock.patch("google.genai.Client")
+  @mock.patch("time.sleep")
+  def test_gemini_zero_retries_fails_fast(self, mock_sleep, mock_client_class):
+    """Test that max_retries=0 makes one attempt and never sleeps."""
+    generate = mock_client_class.return_value.models.generate_content
+    generate.side_effect = Exception("429 RESOURCE_EXHAUSTED")
+    model = gemini.GeminiLanguageModel(api_key="test-key", max_retries=0)
+
+    with self.assertRaises(exceptions.InferenceRuntimeError):
+      list(model.infer(["Test prompt"]))
+
+    self.assertEqual(generate.call_count, 1)
+    mock_sleep.assert_not_called()
+
+  @mock.patch("google.genai.Client")
+  @mock.patch("time.sleep")
+  def test_gemini_no_retry_on_other_errors(self, mock_sleep, mock_client_class):
+    """Test that Gemini does not retry on non-429 errors."""
+    generate = mock_client_class.return_value.models.generate_content
+    generate.side_effect = Exception("500 Internal Server Error")
+    model = gemini.GeminiLanguageModel(api_key="test-key", max_retries=3)
+
+    with self.assertRaises(exceptions.InferenceRuntimeError) as cm:
+      list(model.infer(["Test prompt"]))
+
+    self.assertIn("500", str(cm.exception))
+    self.assertEqual(generate.call_count, 1)
+    mock_sleep.assert_not_called()
+
+
+if __name__ == "__main__":
+  absltest.main()