Skip to content

Commit e322252

Browse files
committed
[ML] Apply yapf formatting and fix isort order for Anthropic inference
Run yapf with Beam's setup.cfg style config to fix all formatting discrepancies. Fix import ordering (message_from_conversation before message_from_string) in both test files to satisfy isort.
1 parent 5beba44 commit e322252

3 files changed

Lines changed: 62 additions & 64 deletions

File tree

sdks/python/apache_beam/ml/inference/anthropic_inference.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,9 @@ def message_from_string(
121121
response = client.messages.create(
122122
model=model_name,
123123
max_tokens=max_tokens,
124-
messages=[{"role": "user", "content": prompt}],
124+
messages=[{
125+
"role": "user", "content": prompt
126+
}],
125127
**inference_args)
126128
responses.append(response)
127129
return responses
@@ -160,8 +162,8 @@ class AnthropicModelHandler(RemoteModelHandler[Any, PredictionResult,
160162
def __init__(
161163
self,
162164
model_name: str,
163-
request_fn: Callable[
164-
[str, Sequence[Any], Anthropic, dict[str, Any]], Any],
165+
request_fn: Callable[[str, Sequence[Any], Anthropic, dict[str, Any]],
166+
Any],
165167
api_key: Optional[str] = None,
166168
*,
167169
system: Optional[Union[str, list[dict[str, str]]]] = None,
@@ -290,6 +292,5 @@ def request(
290292
inference_args['system'] = self.system
291293
if self.output_config is not None and 'output_config' not in inference_args:
292294
inference_args['output_config'] = self.output_config
293-
responses = self.request_fn(
294-
self.model_name, batch, model, inference_args)
295+
responses = self.request_fn(self.model_name, batch, model, inference_args)
295296
return utils._convert_to_result(batch, responses, self.model_name)

sdks/python/apache_beam/ml/inference/anthropic_inference_it_test.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525

2626
try:
2727
from apache_beam.ml.inference.anthropic_inference import AnthropicModelHandler
28-
from apache_beam.ml.inference.anthropic_inference import message_from_string
2928
from apache_beam.ml.inference.anthropic_inference import message_from_conversation
29+
from apache_beam.ml.inference.anthropic_inference import message_from_string
3030
except ImportError:
3131
raise unittest.SkipTest("Anthropic dependencies are not installed")
3232

@@ -67,8 +67,7 @@ def test_anthropic_text_generation(self):
6767
p
6868
| beam.Create(prompts)
6969
| RunInference(handler)
70-
| beam.Map(_extract_text)
71-
)
70+
| beam.Map(_extract_text))
7271
assert_that(results, is_not_empty())
7372

7473
@unittest.skipIf(
@@ -84,9 +83,15 @@ def test_anthropic_conversation(self):
8483

8584
conversations = [
8685
[
87-
{"role": "user", "content": "What is 2 + 2?"},
88-
{"role": "assistant", "content": "4"},
89-
{"role": "user", "content": "Add 3 to that."},
86+
{
87+
"role": "user", "content": "What is 2 + 2?"
88+
},
89+
{
90+
"role": "assistant", "content": "4"
91+
},
92+
{
93+
"role": "user", "content": "Add 3 to that."
94+
},
9095
],
9196
]
9297

@@ -95,8 +100,7 @@ def test_anthropic_conversation(self):
95100
p
96101
| beam.Create(conversations)
97102
| RunInference(handler)
98-
| beam.Map(_extract_text)
99-
)
103+
| beam.Map(_extract_text))
100104
assert_that(results, is_not_empty())
101105

102106
@unittest.skipIf(
@@ -118,8 +122,7 @@ def test_anthropic_with_system_prompt(self):
118122
p
119123
| beam.Create(prompts)
120124
| RunInference(handler)
121-
| beam.Map(_extract_text)
122-
)
125+
| beam.Map(_extract_text))
123126
assert_that(results, is_not_empty())
124127

125128
@unittest.skipIf(
@@ -133,8 +136,7 @@ def test_anthropic_system_prompt_with_structured_output(self):
133136
system=(
134137
"You are a counting bot. When asked to count objects, convert "
135138
"responses such that numbers that are multiples of 3 are written "
136-
"as 'Fizz' instead of the number."
137-
),
139+
"as 'Fizz' instead of the number."),
138140
output_config={
139141
'format': {
140142
'type': 'json_schema',
@@ -146,8 +148,12 @@ def test_anthropic_system_prompt_with_structured_output(self):
146148
'items': {
147149
'type': 'object',
148150
'properties': {
149-
'name': {'type': 'string'},
150-
'count': {'type': 'string'},
151+
'name': {
152+
'type': 'string'
153+
},
154+
'count': {
155+
'type': 'string'
156+
},
151157
},
152158
'required': ['name', 'count'],
153159
'additionalProperties': False,
@@ -169,8 +175,7 @@ def test_anthropic_system_prompt_with_structured_output(self):
169175
p
170176
| beam.Create(prompts)
171177
| RunInference(handler)
172-
| beam.Map(_extract_text)
173-
)
178+
| beam.Map(_extract_text))
174179

175180
def verify_fizz(response_text):
176181
import json
@@ -187,24 +192,19 @@ def verify_fizz(response_text):
187192
if 'banana' in name:
188193
found_banana = True
189194
if count != 'Fizz':
190-
raise ValueError(
191-
'Expected banana count Fizz, '
192-
'got %s' % count)
195+
raise ValueError('Expected banana count Fizz, '
196+
'got %s' % count)
193197
elif 'orange' in name:
194198
found_orange = True
195199
if count != 'Fizz':
196-
raise ValueError(
197-
'Expected orange count Fizz, '
198-
'got %s' % count)
200+
raise ValueError('Expected orange count Fizz, '
201+
'got %s' % count)
199202
elif 'apple' in name:
200203
if count != '2':
201-
raise ValueError(
202-
'Expected apple count 2, '
203-
'got %s' % count)
204+
raise ValueError('Expected apple count 2, '
205+
'got %s' % count)
204206
if not found_banana or not found_orange:
205-
raise ValueError(
206-
'Missing expected items: %s'
207-
% response_text)
207+
raise ValueError('Missing expected items: %s' % response_text)
208208
return response_text
209209

210210
_ = results | beam.Map(verify_fizz)

sdks/python/apache_beam/ml/inference/anthropic_inference_test.py

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525

2626
from apache_beam.ml.inference.anthropic_inference import AnthropicModelHandler
2727
from apache_beam.ml.inference.anthropic_inference import _retry_on_appropriate_error
28-
from apache_beam.ml.inference.anthropic_inference import message_from_string
2928
from apache_beam.ml.inference.anthropic_inference import message_from_conversation
29+
from apache_beam.ml.inference.anthropic_inference import message_from_string
3030
except ImportError:
3131
raise unittest.SkipTest('Anthropic dependencies are not installed')
3232

@@ -106,23 +106,24 @@ def test_sends_each_prompt(self):
106106
_make_fake_response("answer 1"),
107107
_make_fake_response("answer 2"),
108108
]
109-
results = message_from_string(
110-
_TEST_MODEL, ['hello', 'world'], client, {})
109+
results = message_from_string(_TEST_MODEL, ['hello', 'world'], client, {})
111110
self.assertEqual(len(results), 2)
112111
self.assertEqual(client.messages.create.call_count, 2)
113112

114113
call_args = client.messages.create.call_args_list[0]
115114
self.assertEqual(call_args.kwargs['model'], _TEST_MODEL)
116115
self.assertEqual(
117-
call_args.kwargs['messages'],
118-
[{"role": "user", "content": "hello"}])
116+
call_args.kwargs['messages'], [{
117+
"role": "user", "content": "hello"
118+
}])
119119

120120
def test_passes_inference_args(self):
121121
client = mock.MagicMock()
122122
client.messages.create.return_value = _make_fake_response("ok")
123123
message_from_string(
124-
_TEST_MODEL, ['test'], client,
125-
{'max_tokens': 2048, 'temperature': 0.5})
124+
_TEST_MODEL, ['test'], client, {
125+
'max_tokens': 2048, 'temperature': 0.5
126+
})
126127
call_args = client.messages.create.call_args
127128
self.assertEqual(call_args.kwargs['max_tokens'], 2048)
128129
self.assertEqual(call_args.kwargs['temperature'], 0.5)
@@ -140,10 +141,11 @@ def test_sends_conversation(self):
140141
client = mock.MagicMock()
141142
client.messages.create.return_value = _make_fake_response("Paris!")
142143
convo = [
143-
{"role": "user", "content": "What is the capital of France?"},
144+
{
145+
"role": "user", "content": "What is the capital of France?"
146+
},
144147
]
145-
results = message_from_conversation(
146-
_TEST_MODEL, [convo], client, {})
148+
results = message_from_conversation(_TEST_MODEL, [convo], client, {})
147149
self.assertEqual(len(results), 1)
148150
call_args = client.messages.create.call_args
149151
self.assertEqual(call_args.kwargs['messages'], convo)
@@ -162,16 +164,13 @@ def test_create_client_with_api_key(self, mock_anthropic):
162164
@mock.patch('apache_beam.ml.inference.anthropic_inference.Anthropic')
163165
def test_create_client_from_env(self, mock_anthropic):
164166
handler = AnthropicModelHandler(
165-
model_name=_TEST_MODEL,
166-
request_fn=message_from_string)
167+
model_name=_TEST_MODEL, request_fn=message_from_string)
167168
handler.create_client()
168169
mock_anthropic.assert_called_once_with()
169170

170171
def test_request_returns_prediction_results(self):
171172
handler = AnthropicModelHandler(
172-
model_name=_TEST_MODEL,
173-
request_fn=message_from_string,
174-
api_key='fake')
173+
model_name=_TEST_MODEL, request_fn=message_from_string, api_key='fake')
175174
mock_client = mock.MagicMock()
176175
resp1 = _make_fake_response("answer 1")
177176
resp2 = _make_fake_response("answer 2")
@@ -201,8 +200,10 @@ def test_batch_elements_kwargs(self):
201200

202201
def _fake_request_fn(model_name, batch, client, inference_args):
203202
"""A picklable request function that returns fake responses."""
204-
return [FakeMessage(content=[FakeContentBlock(text=f'answer for: {p}')])
205-
for p in batch]
203+
return [
204+
FakeMessage(content=[FakeContentBlock(text=f'answer for: {p}')])
205+
for p in batch
206+
]
206207

207208

208209
class SystemPromptTest(unittest.TestCase):
@@ -229,18 +230,14 @@ def test_system_prompt_not_overridden_by_handler(self):
229230
mock_client = mock.MagicMock()
230231
mock_client.messages.create.return_value = _make_fake_response("ok")
231232

232-
handler.request(
233-
['test'], mock_client,
234-
{'system': 'Per-request override.'})
233+
handler.request(['test'], mock_client, {'system': 'Per-request override.'})
235234

236235
call_args = mock_client.messages.create.call_args
237236
self.assertEqual(call_args.kwargs['system'], 'Per-request override.')
238237

239238
def test_no_system_prompt_when_none(self):
240239
handler = AnthropicModelHandler(
241-
model_name=_TEST_MODEL,
242-
request_fn=message_from_string,
243-
api_key='fake')
240+
model_name=_TEST_MODEL, request_fn=message_from_string, api_key='fake')
244241
mock_client = mock.MagicMock()
245242
mock_client.messages.create.return_value = _make_fake_response("ok")
246243

@@ -256,7 +253,11 @@ class OutputConfigTest(unittest.TestCase):
256253
'type': 'json_schema',
257254
'schema': {
258255
'type': 'object',
259-
'properties': {'answer': {'type': 'string'}},
256+
'properties': {
257+
'answer': {
258+
'type': 'string'
259+
}
260+
},
260261
'required': ['answer'],
261262
'additionalProperties': False,
262263
},
@@ -295,9 +296,7 @@ def test_output_config_not_overridden_by_handler(self):
295296

296297
def test_no_output_config_when_none(self):
297298
handler = AnthropicModelHandler(
298-
model_name=_TEST_MODEL,
299-
request_fn=message_from_string,
300-
api_key='fake')
299+
model_name=_TEST_MODEL, request_fn=message_from_string, api_key='fake')
301300
mock_client = mock.MagicMock()
302301
mock_client.messages.create.return_value = _make_fake_response("ok")
303302

@@ -324,8 +323,7 @@ def test_pipeline_e2e(self):
324323
p
325324
| beam.Create(prompts)
326325
| RunInference(handler)
327-
| beam.Map(lambda r: r.example)
328-
)
326+
| beam.Map(lambda r: r.example))
329327
assert_that(results, equal_to(prompts))
330328

331329
def test_pipeline_with_system_prompt(self):
@@ -345,8 +343,7 @@ def test_pipeline_with_system_prompt(self):
345343
p
346344
| beam.Create(prompts)
347345
| RunInference(handler)
348-
| beam.Map(lambda r: r.example)
349-
)
346+
| beam.Map(lambda r: r.example))
350347
assert_that(results, equal_to(prompts))
351348

352349

0 commit comments

Comments (0)