Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sdks/python/apache_beam/io/requestresponse.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from apache_beam.coders import coders
from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler
from apache_beam.metrics import Metrics
from apache_beam.ml.inference.vertex_ai_inference import MSEC_TO_SEC
from apache_beam.transforms.util import BatchElements
from apache_beam.utils import retry

Expand All @@ -58,6 +57,8 @@
# for cache record.
DEFAULT_CACHE_ENTRY_TTL_SEC = 24 * 60 * 60

MSEC_TO_SEC = 1000

_LOGGER = logging.getLogger(__name__)

__all__ = [
Expand Down
76 changes: 16 additions & 60 deletions sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,19 @@
#

import logging
import time
from collections.abc import Iterable
from collections.abc import Mapping
from collections.abc import Sequence
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Mapping
from typing import Optional
from typing import Sequence

from google.api_core.exceptions import ServerError
from google.api_core.exceptions import TooManyRequests
from google.cloud import aiplatform

from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler
from apache_beam.metrics.metric import Metrics
from apache_beam.ml.inference import utils
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.utils import retry

MSEC_TO_SEC = 1000
from apache_beam.ml.inference.base import RemoteModelHandler

LOGGER = logging.getLogger("VertexAIModelHandlerJSON")

Expand All @@ -59,9 +52,9 @@ def _retry_on_appropriate_gcp_error(exception):
return isinstance(exception, (TooManyRequests, ServerError))


class VertexAIModelHandlerJSON(ModelHandler[Any,
PredictionResult,
aiplatform.Endpoint]):
class VertexAIModelHandlerJSON(RemoteModelHandler[Any,
PredictionResult,
aiplatform.Endpoint]):
def __init__(
self,
endpoint_id: str,
Expand Down Expand Up @@ -139,14 +132,10 @@ def __init__(
_ = self._retrieve_endpoint(
self.endpoint_name, self.location, self.is_private)

# Configure AdaptiveThrottler and throttling metrics for client-side
# throttling behavior.
# See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
# for more details.
self.throttled_secs = Metrics.counter(
VertexAIModelHandlerJSON, "cumulativeThrottlingSeconds")
self.throttler = AdaptiveThrottler(
window_ms=1, bucket_ms=1, overload_ratio=2)
super().__init__(
namespace='VertexAIModelHandlerJSON',
retry_filter=_retry_on_appropriate_gcp_error,
**kwargs)

def _retrieve_endpoint(
self, endpoint_id: str, location: str,
Expand Down Expand Up @@ -183,7 +172,7 @@ def _retrieve_endpoint(

return endpoint

def load_model(self) -> aiplatform.Endpoint:
def create_client(self) -> aiplatform.Endpoint:
"""Loads the Endpoint object used to build and send prediction requests to
Vertex AI.
"""
Expand All @@ -193,39 +182,11 @@ def load_model(self) -> aiplatform.Endpoint:
self.endpoint_name, self.location, self.is_private)
return ep

@retry.with_exponential_backoff(
num_retries=5, retry_filter=_retry_on_appropriate_gcp_error)
def get_request(
self,
batch: Sequence[Any],
model: aiplatform.Endpoint,
throttle_delay_secs: int,
inference_args: Optional[Dict[str, Any]]):
while self.throttler.throttle_request(time.time() * MSEC_TO_SEC):
LOGGER.info(
"Delaying request for %d seconds due to previous failures",
throttle_delay_secs)
time.sleep(throttle_delay_secs)
self.throttled_secs.inc(throttle_delay_secs)

try:
req_time = time.time()
prediction = model.predict(
instances=list(batch), parameters=inference_args)
self.throttler.successful_request(req_time * MSEC_TO_SEC)
return prediction
except TooManyRequests as e:
LOGGER.warning("request was limited by the service with code %i", e.code)
raise
except Exception as e:
LOGGER.error("unexpected exception raised as part of request, got %s", e)
raise

def run_inference(
def request(
self,
batch: Sequence[Any],
model: aiplatform.Endpoint,
inference_args: Optional[Dict[str, Any]] = None
inference_args: Optional[dict[str, Any]] = None
) -> Iterable[PredictionResult]:
""" Sends a prediction request to a Vertex AI endpoint containing a batch
of inputs and matches that input with the prediction response from
Expand All @@ -242,16 +203,11 @@ def run_inference(
Returns:
An iterable of Predictions.
"""

# Endpoint.predict returns a Prediction type with the prediction values
# along with model metadata
prediction = self.get_request(
batch, model, throttle_delay_secs=5, inference_args=inference_args)

prediction = model.predict(instances=list(batch), parameters=inference_args)
return utils._convert_to_result(
batch, prediction.predictions, prediction.deployed_model_id)

def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
pass

def batch_elements_kwargs(self) -> Mapping[str, Any]:
Expand Down
Loading