Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 135 additions & 1 deletion sdks/python/apache_beam/ml/inference/agent_development_kit.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import asyncio
import logging
import uuid

from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Sequence
Expand All @@ -61,6 +62,7 @@

from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import SubprocessModelHandler

try:
from google.adk import sessions
Expand All @@ -73,6 +75,17 @@
ADK_AVAILABLE = True
except ImportError:
ADK_AVAILABLE = False

BEAM_PLACEHOLDER_MODEL = "beam-placeholder-model"

def _is_beam_placeholder_model(model: Any) -> bool:
return model == BEAM_PLACEHOLDER_MODEL

if not ADK_AVAILABLE:
  # The ADK import failed, so bind the names this module refers to at runtime
  # (for example in ``isinstance(tool, Agent)`` checks) to harmless stand-ins.

  class Agent:
    """Empty stand-in bound to ``Agent`` when ADK could not be imported."""

  class Runner:
    """Empty stand-in bound to ``Runner`` when ADK could not be imported."""

  genai_Content = Any  # type: ignore[assignment, misc]
  genai_Part = Any  # type: ignore[assignment, misc]

Expand Down Expand Up @@ -128,6 +141,7 @@ def __init__(
app_name: str = "beam_inference",
session_service_factory: Optional[Callable[[],
"BaseSessionService"]] = None,
underlying_model_handler: Optional[SubprocessModelHandler] = None,
*,
min_batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
Expand All @@ -146,6 +160,8 @@ def __init__(
self._agent_or_factory = agent
self._app_name = app_name
self._session_service_factory = session_service_factory
self._underlying_model_handler = underlying_model_handler
self._current_port = None

super().__init__(
min_batch_size=min_batch_size,
Expand All @@ -165,11 +181,62 @@ def load_model(self) -> "Runner":
Returns:
A fully initialised :class:`~google.adk.runners.Runner`.
"""
local_model = None
underlying_model = None

if self._underlying_model_handler is not None:
underlying_model = self._underlying_model_handler.load_model()
self._current_port = self._underlying_model_handler.get_port(underlying_model)
model_name = self._underlying_model_handler.get_model_name()

from google.adk.models.lite_llm import LiteLlm
local_model = LiteLlm(
model=model_name,
api_base=f"http://localhost:{self._current_port}/v1"
)

# Resolve agent and inject model
if callable(self._agent_or_factory) and not isinstance(
self._agent_or_factory, Agent):
agent = self._agent_or_factory()
import inspect
sig = inspect.signature(self._agent_or_factory)
params = list(sig.parameters.values())
required_params = [
p for p in params
if p.default is inspect.Parameter.empty and p.kind not in (
inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
]

if len(required_params) == 1:
if local_model is None:
raise ValueError("Agent factory expects 1 argument but no local model was configured.")
agent = self._agent_or_factory(local_model)
elif len(required_params) == 0:
if local_model is not None and len(params) > 0:
agent = self._agent_or_factory(local_model)
else:
agent = self._agent_or_factory()
if local_model is not None:
if not _is_beam_placeholder_model(agent.model) and agent.model is not None:
raise ValueError(
f"Agent model must be BEAM_PLACEHOLDER_MODEL or None when using local model. "
f"Found: {agent.model}")
self._set_agent_model(agent, local_model, is_root=True)
else:
raise ValueError("Agent factory must take 0 or 1 required argument.")
else:
agent = self._agent_or_factory
if local_model is not None:
if not _is_beam_placeholder_model(agent.model) and agent.model is not None:
raise ValueError(
f"Agent model must be BEAM_PLACEHOLDER_MODEL or None when using local model. "
f"Found: {agent.model}")
self._set_agent_model(agent, local_model, is_root=True)

# Validation when local model is NOT used
if local_model is None:
if _is_beam_placeholder_model(agent.model):
raise ValueError("Agent model cannot be BEAM_PLACEHOLDER_MODEL when no local model is configured.")

if self._session_service_factory is not None:
session_service = self._session_service_factory()
Expand All @@ -181,13 +248,33 @@ def load_model(self) -> "Runner":
app_name=self._app_name,
session_service=session_service,
)

if underlying_model is not None:
runner._underlying_model = underlying_model

LOGGER.info(
"Loaded ADK Runner for agent '%s' (app_name='%s')",
agent.name,
self._app_name,
)
return runner

def _set_agent_model(
    self, agent: "Agent", model: Any, is_root: bool = False):
  """Install ``model`` on ``agent`` and on nested agents that request it.

  ``agent`` itself is handled according to ``is_root``. Every agent reached
  from it is handled as non-root. Nested agents are found through
  ``sub_agents``, through tools that expose an ``agent`` attribute, and
  through tools that are themselves ``Agent`` instances.

  Args:
    agent: The agent at which the traversal starts.
    model: The model object to install in place of the placeholder.
    is_root: When True, a ``model`` of ``None`` on ``agent`` is replaced in
      addition to the placeholder. Nested agents only have an explicit
      placeholder replaced; any other model they carry is left untouched.
  """
  # Explicit worklist with identity tracking: an agent reachable by more than
  # one path is processed once, and a cyclic graph cannot recurse forever.
  pending = [(agent, is_root)]
  visited = set()
  while pending:
    current, treat_as_root = pending.pop()
    if current is None or id(current) in visited:
      continue
    visited.add(id(current))

    # Agents reached through tools or sub_agents are not guaranteed to have
    # a ``model`` attribute, so only touch it when it exists.
    if hasattr(current, 'model'):
      current_model = current.model
      if _is_beam_placeholder_model(current_model) or (
          treat_as_root and current_model is None):
        current.model = model

    for tool in getattr(current, 'tools', None) or ():
      if hasattr(tool, 'agent'):
        pending.append((tool.agent, False))
      elif isinstance(tool, Agent):
        pending.append((tool, False))

    # Sub-agents may carry the placeholder too; without this step they would
    # keep the placeholder string as their model.
    for child in getattr(current, 'sub_agents', None) or ():
      pending.append((child, False))

def run_inference(
self,
batch: Sequence[str | genai_Content],
Expand Down Expand Up @@ -219,6 +306,33 @@ def run_inference(
An iterable of :class:`~apache_beam.ml.inference.base.PredictionResult`,
one per input element.
"""
underlying_model = None
if self._underlying_model_handler is not None:
underlying_model = getattr(model, '_underlying_model', None)
if underlying_model is not None:
port = self._underlying_model_handler.get_port(underlying_model)
if port != self._current_port:
LOGGER.info("Local model server port changed to %d, updating agent.", port)
self._update_agent_port(model.agent, port)
self._current_port = port

try:
return self._run_inference_internal(batch, model, inference_args)
except Exception as e:
if self._underlying_model_handler is not None and underlying_model is not None:
LOGGER.warning("Inference failed, triggering local server connectivity check.")
try:
self._underlying_model_handler.check_connectivity(underlying_model)
except Exception as recovery_err:
LOGGER.error("Failed during connectivity check: %s", recovery_err)
raise e

def _run_inference_internal(
self,
batch: Sequence[str | genai_Content],
model: "Runner",
inference_args: Optional[dict[str, Any]] = None,
) -> Iterable[PredictionResult]:
if inference_args is None:
inference_args = {}

Expand Down Expand Up @@ -259,6 +373,26 @@ async def _run_concurrently():

return results

def _update_agent_port(self, agent: "Agent", port: int):
  """Re-point LiteLlm-backed agents reachable from ``agent`` at ``port``.

  Each agent whose ``model`` is a ``LiteLlm`` instance gets a fresh
  ``LiteLlm`` with the same model name and an ``api_base`` of
  ``http://localhost:<port>/v1``. Nested agents are found through
  ``sub_agents``, through tools that expose an ``agent`` attribute, and
  through tools that are themselves ``Agent`` instances.

  Args:
    agent: The agent at which the traversal starts.
    port: The port the local model server is now listening on.
  """
  if not ADK_AVAILABLE:
    # Without ADK no agent can hold a LiteLlm model, so nothing to update.
    return
  from google.adk.models.lite_llm import LiteLlm

  api_base = f"http://localhost:{port}/v1"

  # Explicit worklist with identity tracking: an agent reachable by more than
  # one path is processed once, and a cyclic graph cannot recurse forever.
  pending = [agent]
  visited = set()
  while pending:
    current = pending.pop()
    if current is None or id(current) in visited:
      continue
    visited.add(id(current))

    current_model = getattr(current, 'model', None)
    # NOTE(review): every LiteLlm model found is re-pointed at localhost,
    # including one a nested agent may have been given for some other
    # endpoint -- confirm that is intended.
    if isinstance(current_model, LiteLlm):
      current.model = LiteLlm(model=current_model.model, api_base=api_base)

    for tool in getattr(current, 'tools', None) or ():
      if hasattr(tool, 'agent'):
        pending.append(tool.agent)
      elif isinstance(tool, Agent):
        pending.append(tool)

    # Sub-agents can hold the injected local model as well and must follow
    # the server to its new port.
    for child in getattr(current, 'sub_agents', None) or ():
      pending.append(child)

def share_model_across_processes(self) -> bool:
  """Report whether the loaded model should be shared across processes.

  Returns:
    The answer given by the underlying model handler when one is configured,
    otherwise the answer given by the parent class.
  """
  handler = self._underlying_model_handler
  if handler is None:
    return super().share_model_across_processes()
  return handler.share_model_across_processes()

@staticmethod
async def _invoke_agent(
runner: "Runner",
Expand Down
Loading
Loading