Skip to content

Commit f118195

Browse files
committed
linting
1 parent 935dd5b commit f118195

3 files changed

Lines changed: 24 additions & 32 deletions

File tree

sdks/python/apache_beam/ml/inference/openai_inference.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ def generate_completion(
7070
# Note: OpenAI's library expects a single prompt for completions.create,
7171
# so we iterate and call. Batching is handled by RunInference.
7272
# For chat models, multiple messages can be part of a single request.
73-
if "chat.completions" in client.chat.completions.with_raw_response.create.binary_relative_path: # rough check
73+
if ("chat.completions" in client.chat.completions.with_raw_response.
74+
create.binary_relative_path):
7475
# Assuming chat model if path indicates chat completions
7576
# User might need to format input as list of messages
7677
# For simplicity, we'll assume a single user message per prompt string

sdks/python/apache_beam/ml/inference/openai_inference_it_test.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
_OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
4141

4242
# Models for testing - one completion, one chat
43-
_COMPLETION_MODEL = "gpt-3.5-turbo-instruct" # A smaller, faster completion model
43+
_COMPLETION_MODEL = "gpt-3.5-turbo-instruct"
4444
_CHAT_MODEL = "gpt-3.5-turbo"
4545

4646

@@ -98,7 +98,7 @@ def process_output_file(readable_file):
9898
self.assertTrue(
9999
any("PredictionResult(example=" in line for line in match_results))
100100

101-
@pytest.mark.openai_postcommit # Mark as postcommit as it makes external calls.
101+
@pytest.mark.openai_postcommit
102102
def test_openai_completion_model(self):
103103
model_handler = OpenAIModelHandler(
104104
api_key=_OPENAI_API_KEY, model=_COMPLETION_MODEL)
@@ -116,7 +116,7 @@ def test_openai_completion_model(self):
116116
def test_openai_chat_model(self):
117117
model_handler = OpenAIModelHandler(
118118
api_key=_OPENAI_API_KEY, model=_CHAT_MODEL)
119-
# Chat models expect a list of messages or a single string (handled as user message)
119+
# Chat models expect a list of messages or a single string
120120
test_data = [
121121
"What is 2+2?", # Single string prompt
122122
[{
@@ -134,10 +134,11 @@ def test_openai_chat_model(self):
134134
def test_openai_chat_model_with_system_message(self):
135135
model_handler = OpenAIModelHandler(
136136
api_key=_OPENAI_API_KEY, model=_CHAT_MODEL)
137-
# Chat models expect a list of messages or a single string (handled as user message)
137+
# Chat models expect a list of messages or a single string
138138
test_data = [
139-
# This requires the OpenAIModelHandler's generate_completion to correctly
140-
# handle list of messages if the input element itself is a list of dicts.
139+
# This requires the OpenAIModelHandler's generate_completion to
140+
# correctly handle list of messages if the input element itself
141+
# is a list of dicts.
141142
[{
142143
"role": "system",
143144
"content": "You are a helpful assistant that speaks like a pirate."

sdks/python/apache_beam/ml/inference/openai_inference_test.py

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -37,31 +37,19 @@
3737

3838
from apache_beam.ml.inference.base import PredictionResult
3939

40-
# Configure logger for debugging tests related to _retry_on_appropriate_openai_error
41-
# This gets the logger instance used in openai_inference.py
42-
logger_to_debug = logging.getLogger("OpenAIModelHandler")
43-
logger_to_debug.setLevel(logging.DEBUG)
44-
# Add a handler to see the output during tests, e.g., stream to stderr
45-
# Check if a handler already exists to avoid duplicate messages if tests are run multiple times
46-
if not any(isinstance(h, logging.StreamHandler)
47-
for h in logger_to_debug.handlers):
48-
stream_handler = logging.StreamHandler()
49-
stream_handler.setLevel(logging.DEBUG)
50-
formatter = logging.Formatter(
51-
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
52-
stream_handler.setFormatter(formatter)
53-
logger_to_debug.addHandler(stream_handler)
54-
5540

5641
class RetryOnAPIErrorTest(unittest.TestCase):
5742
def _create_mock_error_with_status(self, status_code, error_class=APIError):
5843
"""
5944
Helper to create a mock error object (APIError or RateLimitError)
6045
with a given status code.
61-
The key is to ensure that `getattr(err, 'status_code', None)` works as expected.
46+
The key is to ensure that `getattr(err, 'status_code', None)` works as
47+
expected.
6248
For real OpenAI errors:
63-
- RateLimitError (and other APIStatusErrors) have `err.status_code` as a direct attribute.
64-
- APIError (the base) has `err.status_code` as a property that inspects `err.request.response.status_code`.
49+
- RateLimitError (and other APIStatusErrors) have `err.status_code` as a
50+
direct attribute.
51+
- APIError (the base) has `err.status_code` as a property that inspects
52+
`err.request.response.status_code`.
6553
"""
6654
mock_response = MagicMock(spec=httpx.Response)
6755
# mock_response.status_code will be set below.
@@ -92,8 +80,9 @@ def _create_mock_error_with_status(self, status_code, error_class=APIError):
9280
mock_request_that_failed.response = response_for_api_error_property
9381

9482
err = APIError("API error", request=mock_request_that_failed, body=None)
95-
# Directly set status_code on the instance for getattr in the retry function to pick up.
96-
# This is simpler than ensuring the nested property mock works perfectly.
83+
# Directly set status_code on the instance for getattr in the retry
84+
# function to pick up. This is simpler than ensuring the nested
85+
# property mock works perfectly.
9786
# Note: This shadows the property for this instance.
9887
err.status_code = status_code
9988
return err
@@ -161,7 +150,7 @@ def test_request_completion_model_success(
161150
CompletionChoice(
162151
text=" World!", index=0, finish_reason="length", logprobs=None)
163152
])
164-
mock_openai_client.completions.create.return_value = mock_completion_response
153+
mock_openai_client.completions.create.return_value = mock_completion_response # pylint: disable=line-too-long
165154

166155
handler = OpenAIModelHandler(api_key=self.api_key, model=self.model_name)
167156
# Initialize client by calling create_client or load_model
@@ -188,9 +177,10 @@ def test_request_completion_model_success(
188177
@patch('openai.OpenAI')
189178
def test_request_chat_model_success(self, mock_openai_client_constructor):
190179
mock_openai_client = MagicMock()
191-
# Simulate chat model by checking a mock attribute on the client's chat completions path
192-
# This is a bit of a hack for testing the path in generate_completion
193-
mock_openai_client.chat.completions.with_raw_response.create.binary_relative_path = "chat.completions"
180+
# Simulate chat model by checking a mock attribute on the client's chat
181+
# completions path. This is a bit of a hack for testing the path in
182+
# generate_completion.
183+
mock_openai_client.chat.completions.with_raw_response.create.binary_relative_path = "chat.completions" # pylint: disable=line-too-long
194184
mock_openai_client_constructor.return_value = mock_openai_client
195185

196186
# Mock the response from client.chat.completions.create
@@ -206,7 +196,7 @@ def test_request_chat_model_success(self, mock_openai_client_constructor):
206196
role="assistant", content="There!"),
207197
finish_reason="stop")
208198
])
209-
mock_openai_client.chat.completions.create.return_value = mock_chat_response
199+
mock_openai_client.chat.completions.create.return_value = mock_chat_response # pylint: disable=line-too-long
210200

211201
handler = OpenAIModelHandler(
212202
api_key=self.api_key, model=self.chat_model_name)

0 commit comments

Comments
 (0)