3737
3838from apache_beam .ml .inference .base import PredictionResult
3939
40- # Configure logger for debugging tests related to _retry_on_appropriate_openai_error
41- # This gets the logger instance used in openai_inference.py
42- logger_to_debug = logging .getLogger ("OpenAIModelHandler" )
43- logger_to_debug .setLevel (logging .DEBUG )
44- # Add a handler to see the output during tests, e.g., stream to stderr
45- # Check if a handler already exists to avoid duplicate messages if tests are run multiple times
46- if not any (isinstance (h , logging .StreamHandler )
47- for h in logger_to_debug .handlers ):
48- stream_handler = logging .StreamHandler ()
49- stream_handler .setLevel (logging .DEBUG )
50- formatter = logging .Formatter (
51- '%(asctime)s - %(name)s - %(levelname)s - %(message)s' )
52- stream_handler .setFormatter (formatter )
53- logger_to_debug .addHandler (stream_handler )
54-
5540
5641class RetryOnAPIErrorTest (unittest .TestCase ):
5742 def _create_mock_error_with_status (self , status_code , error_class = APIError ):
5843 """
5944 Helper to create a mock error object (APIError or RateLimitError)
6045 with a given status code.
61- The key is to ensure that `getattr(err, 'status_code', None)` works as expected.
46+ The key is to ensure that `getattr(err, 'status_code', None)` works as
47+ expected.
6248 For real OpenAI errors:
63- - RateLimitError (and other APIStatusErrors) have `err.status_code` as a direct attribute.
64- - APIError (the base) has `err.status_code` as a property that inspects `err.request.response.status_code`.
49+ - RateLimitError (and other APIStatusErrors) have `err.status_code` as a
50+ direct attribute.
51+ - APIError (the base) has `err.status_code` as a property that inspects
52+ `err.request.response.status_code`.
6553 """
6654 mock_response = MagicMock (spec = httpx .Response )
6755 # mock_response.status_code will be set below.
@@ -92,8 +80,9 @@ def _create_mock_error_with_status(self, status_code, error_class=APIError):
9280 mock_request_that_failed .response = response_for_api_error_property
9381
9482 err = APIError ("API error" , request = mock_request_that_failed , body = None )
95- # Directly set status_code on the instance for getattr in the retry function to pick up.
96- # This is simpler than ensuring the nested property mock works perfectly.
83+ # Directly set status_code on the instance for getattr in the retry
84+ # function to pick up. This is simpler than ensuring the nested
85+ # property mock works perfectly.
9786 # Note: This shadows the property for this instance.
9887 err .status_code = status_code
9988 return err
@@ -161,7 +150,7 @@ def test_request_completion_model_success(
161150 CompletionChoice (
162151 text = " World!" , index = 0 , finish_reason = "length" , logprobs = None )
163152 ])
164- mock_openai_client .completions .create .return_value = mock_completion_response
153+ mock_openai_client .completions .create .return_value = mock_completion_response # pylint: disable=line-too-long
165154
166155 handler = OpenAIModelHandler (api_key = self .api_key , model = self .model_name )
167156 # Initialize client by calling create_client or load_model
@@ -188,9 +177,10 @@ def test_request_completion_model_success(
188177 @patch ('openai.OpenAI' )
189178 def test_request_chat_model_success (self , mock_openai_client_constructor ):
190179 mock_openai_client = MagicMock ()
191- # Simulate chat model by checking a mock attribute on the client's chat completions path
192- # This is a bit of a hack for testing the path in generate_completion
193- mock_openai_client .chat .completions .with_raw_response .create .binary_relative_path = "chat.completions"
180+ # Simulate chat model by checking a mock attribute on the client's chat
181+ # completions path. This is a bit of a hack for testing the path in
182+ # generate_completion.
183+ mock_openai_client .chat .completions .with_raw_response .create .binary_relative_path = "chat.completions" # pylint: disable=line-too-long
194184 mock_openai_client_constructor .return_value = mock_openai_client
195185
196186 # Mock the response from client.chat.completions.create
@@ -206,7 +196,7 @@ def test_request_chat_model_success(self, mock_openai_client_constructor):
206196 role = "assistant" , content = "There!" ),
207197 finish_reason = "stop" )
208198 ])
209- mock_openai_client .chat .completions .create .return_value = mock_chat_response
199+ mock_openai_client .chat .completions .create .return_value = mock_chat_response # pylint: disable=line-too-long
210200
211201 handler = OpenAIModelHandler (
212202 api_key = self .api_key , model = self .chat_model_name )
0 commit comments