2525
2626 from apache_beam .ml .inference .anthropic_inference import AnthropicModelHandler
2727 from apache_beam .ml .inference .anthropic_inference import _retry_on_appropriate_error
28- from apache_beam .ml .inference .anthropic_inference import message_from_string
2928 from apache_beam .ml .inference .anthropic_inference import message_from_conversation
29+ from apache_beam .ml .inference .anthropic_inference import message_from_string
3030except ImportError :
3131 raise unittest .SkipTest ('Anthropic dependencies are not installed' )
3232
@@ -106,23 +106,24 @@ def test_sends_each_prompt(self):
106106 _make_fake_response ("answer 1" ),
107107 _make_fake_response ("answer 2" ),
108108 ]
109- results = message_from_string (
110- _TEST_MODEL , ['hello' , 'world' ], client , {})
109+ results = message_from_string (_TEST_MODEL , ['hello' , 'world' ], client , {})
111110 self .assertEqual (len (results ), 2 )
112111 self .assertEqual (client .messages .create .call_count , 2 )
113112
114113 call_args = client .messages .create .call_args_list [0 ]
115114 self .assertEqual (call_args .kwargs ['model' ], _TEST_MODEL )
116115 self .assertEqual (
117- call_args .kwargs ['messages' ],
118- [{"role" : "user" , "content" : "hello" }])
116+ call_args .kwargs ['messages' ], [{
117+ "role" : "user" , "content" : "hello"
118+ }])
119119
120120 def test_passes_inference_args (self ):
121121 client = mock .MagicMock ()
122122 client .messages .create .return_value = _make_fake_response ("ok" )
123123 message_from_string (
124- _TEST_MODEL , ['test' ], client ,
125- {'max_tokens' : 2048 , 'temperature' : 0.5 })
124+ _TEST_MODEL , ['test' ], client , {
125+ 'max_tokens' : 2048 , 'temperature' : 0.5
126+ })
126127 call_args = client .messages .create .call_args
127128 self .assertEqual (call_args .kwargs ['max_tokens' ], 2048 )
128129 self .assertEqual (call_args .kwargs ['temperature' ], 0.5 )
@@ -140,10 +141,11 @@ def test_sends_conversation(self):
140141 client = mock .MagicMock ()
141142 client .messages .create .return_value = _make_fake_response ("Paris!" )
142143 convo = [
143- {"role" : "user" , "content" : "What is the capital of France?" },
144+ {
145+ "role" : "user" , "content" : "What is the capital of France?"
146+ },
144147 ]
145- results = message_from_conversation (
146- _TEST_MODEL , [convo ], client , {})
148+ results = message_from_conversation (_TEST_MODEL , [convo ], client , {})
147149 self .assertEqual (len (results ), 1 )
148150 call_args = client .messages .create .call_args
149151 self .assertEqual (call_args .kwargs ['messages' ], convo )
@@ -162,16 +164,13 @@ def test_create_client_with_api_key(self, mock_anthropic):
162164 @mock .patch ('apache_beam.ml.inference.anthropic_inference.Anthropic' )
163165 def test_create_client_from_env (self , mock_anthropic ):
164166 handler = AnthropicModelHandler (
165- model_name = _TEST_MODEL ,
166- request_fn = message_from_string )
167+ model_name = _TEST_MODEL , request_fn = message_from_string )
167168 handler .create_client ()
168169 mock_anthropic .assert_called_once_with ()
169170
170171 def test_request_returns_prediction_results (self ):
171172 handler = AnthropicModelHandler (
172- model_name = _TEST_MODEL ,
173- request_fn = message_from_string ,
174- api_key = 'fake' )
173+ model_name = _TEST_MODEL , request_fn = message_from_string , api_key = 'fake' )
175174 mock_client = mock .MagicMock ()
176175 resp1 = _make_fake_response ("answer 1" )
177176 resp2 = _make_fake_response ("answer 2" )
@@ -201,8 +200,10 @@ def test_batch_elements_kwargs(self):
201200
202201def _fake_request_fn (model_name , batch , client , inference_args ):
203202 """A picklable request function that returns fake responses."""
204- return [FakeMessage (content = [FakeContentBlock (text = f'answer for: { p } ' )])
205- for p in batch ]
203+ return [
204+ FakeMessage (content = [FakeContentBlock (text = f'answer for: { p } ' )])
205+ for p in batch
206+ ]
206207
207208
208209class SystemPromptTest (unittest .TestCase ):
@@ -229,18 +230,14 @@ def test_system_prompt_not_overridden_by_handler(self):
229230 mock_client = mock .MagicMock ()
230231 mock_client .messages .create .return_value = _make_fake_response ("ok" )
231232
232- handler .request (
233- ['test' ], mock_client ,
234- {'system' : 'Per-request override.' })
233+ handler .request (['test' ], mock_client , {'system' : 'Per-request override.' })
235234
236235 call_args = mock_client .messages .create .call_args
237236 self .assertEqual (call_args .kwargs ['system' ], 'Per-request override.' )
238237
239238 def test_no_system_prompt_when_none (self ):
240239 handler = AnthropicModelHandler (
241- model_name = _TEST_MODEL ,
242- request_fn = message_from_string ,
243- api_key = 'fake' )
240+ model_name = _TEST_MODEL , request_fn = message_from_string , api_key = 'fake' )
244241 mock_client = mock .MagicMock ()
245242 mock_client .messages .create .return_value = _make_fake_response ("ok" )
246243
@@ -256,7 +253,11 @@ class OutputConfigTest(unittest.TestCase):
256253 'type' : 'json_schema' ,
257254 'schema' : {
258255 'type' : 'object' ,
259- 'properties' : {'answer' : {'type' : 'string' }},
256+ 'properties' : {
257+ 'answer' : {
258+ 'type' : 'string'
259+ }
260+ },
260261 'required' : ['answer' ],
261262 'additionalProperties' : False ,
262263 },
@@ -295,9 +296,7 @@ def test_output_config_not_overridden_by_handler(self):
295296
296297 def test_no_output_config_when_none (self ):
297298 handler = AnthropicModelHandler (
298- model_name = _TEST_MODEL ,
299- request_fn = message_from_string ,
300- api_key = 'fake' )
299+ model_name = _TEST_MODEL , request_fn = message_from_string , api_key = 'fake' )
301300 mock_client = mock .MagicMock ()
302301 mock_client .messages .create .return_value = _make_fake_response ("ok" )
303302
@@ -324,8 +323,7 @@ def test_pipeline_e2e(self):
324323 p
325324 | beam .Create (prompts )
326325 | RunInference (handler )
327- | beam .Map (lambda r : r .example )
328- )
326+ | beam .Map (lambda r : r .example ))
329327 assert_that (results , equal_to (prompts ))
330328
331329 def test_pipeline_with_system_prompt (self ):
@@ -345,8 +343,7 @@ def test_pipeline_with_system_prompt(self):
345343 p
346344 | beam .Create (prompts )
347345 | RunInference (handler )
348- | beam .Map (lambda r : r .example )
349- )
346+ | beam .Map (lambda r : r .example ))
350347 assert_that (results , equal_to (prompts ))
351348
352349
0 commit comments