Commit e1b72f9

[ML] Add AnthropicModelHandler for direct Claude inference
Add ModelHandler for Anthropic Claude models using the Messages API. This enables direct inference without requiring Vertex AI/GCP setup.
1 parent 2eb71e9 commit e1b72f9

5 files changed, 591 additions & 0 deletions

Lines changed: 234 additions & 0 deletions
@@ -0,0 +1,234 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A ModelHandler for Anthropic Claude models using the Messages API.

This module provides an integration between Apache Beam's RunInference
transform and Anthropic's Claude models, enabling batch inference in
Beam pipelines.

Example usage::

  import apache_beam as beam
  from apache_beam.ml.inference.anthropic_inference import (
      AnthropicModelHandler,
      message_from_string,
  )
  from apache_beam.ml.inference.base import RunInference

  model_handler = AnthropicModelHandler(
      model_name='claude-haiku-4-5',
      api_key='your-api-key',
      request_fn=message_from_string,
  )

  with beam.Pipeline() as p:
    results = (
        p
        | beam.Create(['What is Apache Beam?', 'Explain MapReduce.'])
        | RunInference(model_handler)
    )
"""

import logging
from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Sequence
from typing import Any
from typing import Optional

from anthropic import Anthropic
from anthropic import APIStatusError

from apache_beam.ml.inference import utils
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RemoteModelHandler

LOGGER = logging.getLogger("AnthropicModelHandler")


def _retry_on_appropriate_error(exception: Exception) -> bool:
  """Retry filter that returns True for 5xx errors or 429 (rate limiting).

  Args:
    exception: the exception encountered during the request/response loop.

  Returns:
    True if the exception is retriable (429 or 5xx), False otherwise.
  """
  if not isinstance(exception, APIStatusError):
    return False
  return exception.status_code == 429 or exception.status_code >= 500


def message_from_string(
    model_name: str,
    batch: Sequence[str],
    client: Anthropic,
    inference_args: dict[str, Any]):
  """Request function that sends string prompts to Claude as user messages.

  Each string in the batch is sent as a separate request. The results are
  returned as a list of response objects.

  Args:
    model_name: the Claude model to use (e.g. 'claude-haiku-4-5').
    batch: the string inputs to send to Claude.
    client: the Anthropic client instance.
    inference_args: additional arguments passed to the messages.create call.
      Common args include 'max_tokens', 'system', 'temperature', 'top_p'.
  """
  max_tokens = inference_args.pop('max_tokens', 1024)
  responses = []
  for prompt in batch:
    response = client.messages.create(
        model=model_name,
        max_tokens=max_tokens,
        messages=[{"role": "user", "content": prompt}],
        **inference_args)
    responses.append(response)
  return responses


def message_from_conversation(
    model_name: str,
    batch: Sequence[list[dict[str, str]]],
    client: Anthropic,
    inference_args: dict[str, Any]):
  """Request function that sends multi-turn conversations to Claude.

  Each element in the batch is a list of message dicts with 'role' and
  'content' keys, representing a multi-turn conversation.

  Args:
    model_name: the Claude model to use.
    batch: a sequence of conversations (each a list of message dicts).
    client: the Anthropic client instance.
    inference_args: additional arguments passed to the messages.create call.
  """
  max_tokens = inference_args.pop('max_tokens', 1024)
  responses = []
  for messages in batch:
    response = client.messages.create(
        model=model_name,
        max_tokens=max_tokens,
        messages=messages,
        **inference_args)
    responses.append(response)
  return responses


class AnthropicModelHandler(RemoteModelHandler[Any, PredictionResult,
                                               Anthropic]):
  def __init__(
      self,
      model_name: str,
      request_fn: Callable[
          [str, Sequence[Any], Anthropic, dict[str, Any]], Any],
      api_key: Optional[str] = None,
      *,
      min_batch_size: Optional[int] = None,
      max_batch_size: Optional[int] = None,
      max_batch_duration_secs: Optional[int] = None,
      max_batch_weight: Optional[int] = None,
      element_size_fn: Optional[Callable[[Any], int]] = None,
      **kwargs):
    """Implementation of the ModelHandler interface for Anthropic Claude.

    **NOTE:** This API and its implementation are under development and
    do not provide backward compatibility guarantees.

    This handler connects to the Anthropic Messages API to run inference
    using Claude models. It supports text generation from string prompts
    or multi-turn conversations.

    Args:
      model_name: the Claude model to send requests to (e.g.
        'claude-sonnet-4-6', 'claude-haiku-4-5').
      request_fn: the function to use to send requests. Should take the
        model name, batch, client, and inference_args and return the
        responses from Claude. Built-in options are message_from_string
        and message_from_conversation.
      api_key: the Anthropic API key. If not provided, the client will
        look for the ANTHROPIC_API_KEY environment variable.
      min_batch_size: optional. the minimum batch size to use when
        batching inputs.
      max_batch_size: optional. the maximum batch size to use when
        batching inputs.
      max_batch_duration_secs: optional. the maximum amount of time to
        buffer a batch before emitting; used in streaming contexts.
      max_batch_weight: optional. the maximum total weight of a batch.
      element_size_fn: optional. a function that returns the size
        (weight) of an element.
    """
    self._batching_kwargs = {}
    self._env_vars = kwargs.get('env_vars', {})
    if min_batch_size is not None:
      self._batching_kwargs["min_batch_size"] = min_batch_size
    if max_batch_size is not None:
      self._batching_kwargs["max_batch_size"] = max_batch_size
    if max_batch_duration_secs is not None:
      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
    if max_batch_weight is not None:
      self._batching_kwargs["max_batch_weight"] = max_batch_weight
    if element_size_fn is not None:
      self._batching_kwargs['element_size_fn'] = element_size_fn

    self.model_name = model_name
    self.request_fn = request_fn
    self.api_key = api_key

    super().__init__(
        namespace='AnthropicModelHandler',
        retry_filter=_retry_on_appropriate_error,
        **kwargs)

  def batch_elements_kwargs(self):
    return self._batching_kwargs

  def create_client(self) -> Anthropic:
    """Creates the Anthropic client used to send requests.

    If api_key was provided at construction time, it is used directly.
    Otherwise, the client will fall back to the ANTHROPIC_API_KEY
    environment variable.
    """
    if self.api_key:
      return Anthropic(api_key=self.api_key)
    return Anthropic()

  def request(
      self,
      batch: Sequence[Any],
      model: Anthropic,
      inference_args: Optional[dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """Sends a prediction request to the Anthropic API.

    Args:
      batch: a sequence of inputs to be passed to the request function.
      model: an Anthropic client instance.
      inference_args: additional arguments to send as part of the
        prediction request (e.g. max_tokens, temperature, system).

    Returns:
      An iterable of PredictionResults.
    """
    if inference_args is None:
      inference_args = {}
    responses = self.request_fn(
        self.model_name, batch, model, inference_args)
    return utils._convert_to_result(batch, responses, self.model_name)
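
For context (this is not part of the commit's diff), a minimal sketch of the multi-turn path end to end: each pipeline element is a conversation in Messages API format, and the generated reply is read out of the resulting PredictionResult the same way the integration test below does. The extract_reply helper and the example prompts are illustrative only; omitting api_key relies on create_client() falling back to the ANTHROPIC_API_KEY environment variable.

  import apache_beam as beam

  from apache_beam.ml.inference.anthropic_inference import AnthropicModelHandler
  from apache_beam.ml.inference.anthropic_inference import message_from_conversation
  from apache_beam.ml.inference.base import RunInference

  # api_key is omitted, so create_client() uses the ANTHROPIC_API_KEY env var.
  handler = AnthropicModelHandler(
      model_name='claude-haiku-4-5',
      request_fn=message_from_conversation,
      max_batch_size=1,
  )


  def extract_reply(result):
    # Illustrative helper: the input element is available as result.example,
    # the Messages API response as result.inference; the generated text is the
    # first content block of the returned message.
    return result.inference.content[0].text


  conversation = [
      {"role": "user", "content": "What is 2 + 2?"},
      {"role": "assistant", "content": "4"},
      {"role": "user", "content": "Add 3 to that."},
  ]

  with beam.Pipeline() as p:
    _ = (
        p
        | beam.Create([conversation])
        | RunInference(handler)
        | beam.Map(extract_reply)
        | beam.Map(print))
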
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""End-to-End test for Anthropic Claude Remote Inference"""

import logging
import os
import unittest

import pytest

try:
  from anthropic import Anthropic

  from apache_beam.ml.inference.anthropic_inference import AnthropicModelHandler
  from apache_beam.ml.inference.anthropic_inference import message_from_string
  from apache_beam.ml.inference.anthropic_inference import message_from_conversation
except ImportError:
  raise unittest.SkipTest("Anthropic dependencies are not installed")

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import is_not_empty

_ANTHROPIC_API_KEY = os.environ.get('ANTHROPIC_API_KEY', None)
_TEST_MODEL = 'claude-haiku-4-5'


def _extract_text(prediction_result):
  return prediction_result.inference.content[0].text


@pytest.mark.anthropic_postcommit
class AnthropicInferenceIT(unittest.TestCase):
  @unittest.skipIf(
      _ANTHROPIC_API_KEY is None,
      'ANTHROPIC_API_KEY environment variable is not set')
  def test_anthropic_text_generation(self):
    handler = AnthropicModelHandler(
        model_name=_TEST_MODEL,
        request_fn=message_from_string,
        api_key=_ANTHROPIC_API_KEY,
        max_batch_size=1,
    )

    prompts = [
        'What is Apache Beam in one sentence?',
        'Name three distributed computing frameworks.',
    ]

    with TestPipeline() as p:
      results = (
          p
          | beam.Create(prompts)
          | RunInference(handler)
          | beam.Map(_extract_text)
      )
      assert_that(results, is_not_empty())

  @unittest.skipIf(
      _ANTHROPIC_API_KEY is None,
      'ANTHROPIC_API_KEY environment variable is not set')
  def test_anthropic_conversation(self):
    handler = AnthropicModelHandler(
        model_name=_TEST_MODEL,
        request_fn=message_from_conversation,
        api_key=_ANTHROPIC_API_KEY,
        max_batch_size=1,
    )

    conversations = [
        [
            {"role": "user", "content": "What is 2 + 2?"},
            {"role": "assistant", "content": "4"},
            {"role": "user", "content": "Add 3 to that."},
        ],
    ]

    with TestPipeline() as p:
      results = (
          p
          | beam.Create(conversations)
          | RunInference(handler)
          | beam.Map(_extract_text)
      )
      assert_that(results, is_not_empty())


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.DEBUG)
  unittest.main()
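
One more illustrative sketch, also not part of the commit: because request_fn only needs to follow the (model_name, batch, client, inference_args) contract described in the handler's docstring, a pipeline can supply its own request function. The hypothetical example below applies a fixed system prompt to every string prompt before calling the Messages API; the function name and prompt text are assumptions for illustration.

  from collections.abc import Sequence
  from typing import Any

  from anthropic import Anthropic


  def message_with_system_prompt(
      model_name: str,
      batch: Sequence[str],
      client: Anthropic,
      inference_args: dict[str, Any]):
    """Hypothetical request_fn: wraps each prompt with a fixed system prompt."""
    max_tokens = inference_args.pop('max_tokens', 1024)
    responses = []
    for prompt in batch:
      response = client.messages.create(
          model=model_name,
          max_tokens=max_tokens,
          system='You are a concise assistant for data engineering questions.',
          messages=[{"role": "user", "content": prompt}],
          **inference_args)
      responses.append(response)
    return responses


  # Plugged in exactly like the built-in request functions:
  #   AnthropicModelHandler(
  #       model_name='claude-haiku-4-5',
  #       request_fn=message_with_system_prompt)
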
