|
| 1 | +# |
| 2 | +# Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | +# contributor license agreements. See the NOTICE file distributed with |
| 4 | +# this work for additional information regarding copyright ownership. |
| 5 | +# The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | +# (the "License"); you may not use this file except in compliance with |
| 7 | +# the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | +# |
| 17 | + |
| 18 | +"""A ModelHandler for Anthropic Claude models using the Messages API. |
| 19 | +
|
| 20 | +This module provides an integration between Apache Beam's RunInference |
| 21 | +transform and Anthropic's Claude models, enabling batch inference in |
| 22 | +Beam pipelines. |
| 23 | +
|
| 24 | +Example usage:: |
| 25 | +
|
| 26 | + from apache_beam.ml.inference.anthropic_inference import ( |
| 27 | + AnthropicModelHandler, |
| 28 | + message_from_string, |
| 29 | + ) |
| 30 | + from apache_beam.ml.inference.base import RunInference |
| 31 | +
|
| 32 | + model_handler = AnthropicModelHandler( |
| 33 | + model_name='claude-haiku-4-5', |
| 34 | + api_key='your-api-key', |
| 35 | + request_fn=message_from_string, |
| 36 | + ) |
| 37 | +
|
| 38 | + with beam.Pipeline() as p: |
| 39 | + results = ( |
| 40 | + p |
| 41 | + | beam.Create(['What is Apache Beam?', 'Explain MapReduce.']) |
| 42 | + | RunInference(model_handler) |
| 43 | + ) |
| 44 | +""" |
| 45 | + |
| 46 | +import logging |
| 47 | +from collections.abc import Callable |
| 48 | +from collections.abc import Iterable |
| 49 | +from collections.abc import Sequence |
| 50 | +from typing import Any |
| 51 | +from typing import Optional |
| 52 | + |
| 53 | +from anthropic import Anthropic |
| 54 | +from anthropic import APIStatusError |
| 55 | + |
| 56 | +from apache_beam.ml.inference import utils |
| 57 | +from apache_beam.ml.inference.base import PredictionResult |
| 58 | +from apache_beam.ml.inference.base import RemoteModelHandler |
| 59 | + |
# Module-scoped logger, deliberately named after the handler (rather than
# __name__) so it matches the namespace used by the handler class below.
LOGGER = logging.getLogger("AnthropicModelHandler")
| 61 | + |
| 62 | + |
def _retry_on_appropriate_error(exception: Exception) -> bool:
  """Decide whether a failed Anthropic request should be retried.

  Only transient API failures qualify: HTTP 429 (rate limiting) and any
  5xx server-side error. Every other exception is treated as permanent.

  Args:
    exception: the exception encountered during the request/response loop.

  Returns:
    True if the exception is retriable (429 or 5xx), False otherwise.
  """
  if isinstance(exception, APIStatusError):
    status = exception.status_code
    return status == 429 or status >= 500
  return False
| 75 | + |
| 76 | + |
def message_from_string(
    model_name: str,
    batch: Sequence[str],
    client: 'Anthropic',
    inference_args: dict[str, Any]):
  """Request function that sends string prompts to Claude as user messages.

  Each string in the batch is sent as a separate request. The results are
  returned as a list of response objects, one per prompt, in batch order.

  Args:
    model_name: the Claude model to use (e.g. 'claude-haiku-4-5').
    batch: the string inputs to send to Claude.
    client: the Anthropic client instance.
    inference_args: additional arguments passed to the messages.create call.
      Common args include 'max_tokens', 'system', 'temperature', 'top_p'.
      If 'max_tokens' is absent, 1024 is used. The dict is not mutated.

  Returns:
    A list of Anthropic message responses, one per input string.
  """
  # Work on a copy: popping from the caller's dict would silently drop
  # 'max_tokens' for subsequent batches and retry attempts.
  args = dict(inference_args)
  max_tokens = args.pop('max_tokens', 1024)
  responses = []
  for prompt in batch:
    response = client.messages.create(
        model=model_name,
        max_tokens=max_tokens,
        messages=[{"role": "user", "content": prompt}],
        **args)
    responses.append(response)
  return responses
| 104 | + |
| 105 | + |
def message_from_conversation(
    model_name: str,
    batch: Sequence[list[dict[str, str]]],
    client: 'Anthropic',
    inference_args: dict[str, Any]):
  """Request function that sends multi-turn conversations to Claude.

  Each element in the batch is a list of message dicts with 'role' and
  'content' keys, representing a multi-turn conversation. Each conversation
  is sent as a separate request, and responses are returned in batch order.

  Args:
    model_name: the Claude model to use.
    batch: a sequence of conversations (each a list of message dicts).
    client: the Anthropic client instance.
    inference_args: additional arguments passed to the messages.create call.
      If 'max_tokens' is absent, 1024 is used. The dict is not mutated.

  Returns:
    A list of Anthropic message responses, one per conversation.
  """
  # Work on a copy: popping from the caller's dict would silently drop
  # 'max_tokens' for subsequent batches and retry attempts.
  args = dict(inference_args)
  max_tokens = args.pop('max_tokens', 1024)
  responses = []
  for messages in batch:
    response = client.messages.create(
        model=model_name,
        max_tokens=max_tokens,
        messages=messages,
        **args)
    responses.append(response)
  return responses
| 132 | + |
| 133 | + |
class AnthropicModelHandler(RemoteModelHandler[Any, PredictionResult,
                                               Anthropic]):
  def __init__(
      self,
      model_name: str,
      request_fn: Callable[
          [str, Sequence[Any], Anthropic, dict[str, Any]], Any],
      api_key: Optional[str] = None,
      *,
      min_batch_size: Optional[int] = None,
      max_batch_size: Optional[int] = None,
      max_batch_duration_secs: Optional[int] = None,
      max_batch_weight: Optional[int] = None,
      element_size_fn: Optional[Callable[[Any], int]] = None,
      **kwargs):
    """Implementation of the ModelHandler interface for Anthropic Claude.

    **NOTE:** This API and its implementation are under development and
    do not provide backward compatibility guarantees.

    This handler connects to the Anthropic Messages API to run inference
    using Claude models. It supports text generation from string prompts
    or multi-turn conversations. Transient failures (HTTP 429 and 5xx)
    are retried via the base class's retry machinery.

    Args:
      model_name: the Claude model to send requests to (e.g.
        'claude-sonnet-4-6', 'claude-haiku-4-5').
      request_fn: the function to use to send requests. Should take the
        model name, batch, client, and inference_args and return the
        responses from Claude. Built-in options are message_from_string
        and message_from_conversation.
      api_key: the Anthropic API key. If not provided, the client will
        look for the ANTHROPIC_API_KEY environment variable.
      min_batch_size: optional. the minimum batch size to use when
        batching inputs.
      max_batch_size: optional. the maximum batch size to use when
        batching inputs.
      max_batch_duration_secs: optional. the maximum amount of time to
        buffer a batch before emitting; used in streaming contexts.
      max_batch_weight: optional. the maximum total weight of a batch.
      element_size_fn: optional. a function that returns the size
        (weight) of an element.
    """
    # Only forward batching kwargs the caller actually supplied, so the
    # batching transform's own defaults apply otherwise.
    self._batching_kwargs = {}
    self._env_vars = kwargs.get('env_vars', {})
    if min_batch_size is not None:
      self._batching_kwargs["min_batch_size"] = min_batch_size
    if max_batch_size is not None:
      self._batching_kwargs["max_batch_size"] = max_batch_size
    if max_batch_duration_secs is not None:
      self._batching_kwargs[
          "max_batch_duration_secs"] = max_batch_duration_secs
    if max_batch_weight is not None:
      self._batching_kwargs["max_batch_weight"] = max_batch_weight
    if element_size_fn is not None:
      self._batching_kwargs['element_size_fn'] = element_size_fn

    self.model_name = model_name
    self.request_fn = request_fn
    self.api_key = api_key

    super().__init__(
        namespace='AnthropicModelHandler',
        retry_filter=_retry_on_appropriate_error,
        **kwargs)

  def batch_elements_kwargs(self):
    """Returns the kwargs controlling how inputs are batched."""
    return self._batching_kwargs

  def create_client(self) -> Anthropic:
    """Creates the Anthropic client used to send requests.

    If api_key was provided at construction time, it is used directly.
    Otherwise, the client will fall back to the ANTHROPIC_API_KEY
    environment variable.
    """
    if self.api_key:
      return Anthropic(api_key=self.api_key)
    return Anthropic()

  def request(
      self,
      batch: Sequence[Any],
      model: Anthropic,
      inference_args: Optional[dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """Sends a prediction request to the Anthropic API.

    Args:
      batch: a sequence of inputs to be passed to the request function.
      model: an Anthropic client instance.
      inference_args: additional arguments to send as part of the
        prediction request (e.g. max_tokens, temperature, system).

    Returns:
      An iterable of PredictionResults pairing each input with its
      response.
    """
    # Pass a copy: request functions may pop keys (e.g. 'max_tokens'),
    # and this method can be re-invoked on retry with the same dict, so
    # the caller's inference_args must not be mutated.
    args = {} if inference_args is None else dict(inference_args)
    responses = self.request_fn(self.model_name, batch, model, args)
    return utils._convert_to_result(batch, responses, self.model_name)
0 commit comments