1 change: 1 addition & 0 deletions src/google/adk_community/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.

from . import memory
from . import plugins
from . import sessions
from . import version
__version__ = version.__version__
21 changes: 21 additions & 0 deletions src/google/adk_community/plugins/__init__.py
@@ -0,0 +1,21 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Community plugins for Google ADK."""

from .llm_resilience_plugin import LlmResiliencePlugin

__all__ = [
"LlmResiliencePlugin",
]
352 changes: 352 additions & 0 deletions src/google/adk_community/plugins/llm_resilience_plugin.py
@@ -0,0 +1,352 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""LlmResiliencePlugin - retry with exponential backoff and model fallbacks."""

from __future__ import annotations

import asyncio
import logging
import random
from typing import Iterable
from typing import Optional
from typing import TYPE_CHECKING

if TYPE_CHECKING:
  from google.adk.agents.invocation_context import InvocationContext

try:
  import httpx
except Exception:  # pragma: no cover - httpx might not be installed in all envs
  httpx = None  # type: ignore

from google.genai import types

from google.adk.agents.callback_context import CallbackContext
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.models.registry import LLMRegistry
from google.adk.plugins.base_plugin import BasePlugin

logger = logging.getLogger("google_adk_community." + __name__)


def _extract_status_code(err: Exception) -> Optional[int]:
Review comment: This extraction logic is quite robust, but it might be worth adding support for gRPC error codes if the ADK uses gRPC for some providers. gRPC errors usually have a .code() method returning an enum. For now, this covers httpx and generic status_code attributes, which is a great start.
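A minimal sketch of that idea, assuming grpcio is installed (the retryable status set is illustrative; only RpcErrors backed by a Call expose .code(), hence the defensive try):

try:
  import grpc
except ImportError:  # pragma: no cover - gRPC is optional, like httpx above
  grpc = None

_RETRYABLE_GRPC = {"UNAVAILABLE", "RESOURCE_EXHAUSTED", "DEADLINE_EXCEEDED"}

def _is_transient_grpc_error(err: Exception) -> bool:
  """Best-effort check for retryable gRPC errors."""
  if grpc is None or not isinstance(err, grpc.RpcError):
    return False
  try:
    return err.code().name in _RETRYABLE_GRPC
  except Exception:  # bare RpcError instances may not implement .code()
    return False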

"""Best-effort extraction of HTTP status codes from common client libraries."""
status = getattr(err, "status_code", None)
if isinstance(status, int):
return status
# httpx specific
if httpx is not None:
if isinstance(err, httpx.HTTPStatusError):
try:
return int(err.response.status_code)
except (AttributeError, ValueError, TypeError):
return None
# Fallback: look for nested response
resp = getattr(err, "response", None)
if resp is not None:
code = getattr(resp, "status_code", None)
if isinstance(code, int):
return code
return None


def _is_transient_error(err: Exception) -> bool:
"""Check if an error is transient and should trigger retry."""
# Retry on common transient classes and HTTP status codes
transient_http = {429, 500, 502, 503, 504}
status = _extract_status_code(err)
if status is not None and status in transient_http:
return True

# httpx transient
if httpx is not None and isinstance(
err, (httpx.ReadTimeout, httpx.ConnectError, httpx.RemoteProtocolError)
):
return True

# asyncio timeouts and cancellations often warrant retry/fallback at callsite
if isinstance(err, (asyncio.TimeoutError,)):
return True

return False


class LlmResiliencePlugin(BasePlugin):
"""A plugin that adds retry with exponential backoff and model fallbacks.

Behavior:
- Intercepts model errors via on_model_error_callback
- Retries the same model up to max_retries with exponential backoff + jitter
- If still failing and fallback_models configured, tries them in order
- Returns the first successful LlmResponse or None to propagate the error

Notes:
- Live (bidirectional) mode errors are not intercepted by BaseLlmFlow's error
handler; this plugin currently targets generate_content_async flow.
- In SSE streaming mode, the plugin returns a single final LlmResponse.

Example:
>>> from google.adk.runners import Runner
>>> from google.adk_community.plugins import LlmResiliencePlugin
>>>
>>> runner = Runner(
... app_name="my_app",
... agent=my_agent,
... plugins=[
... LlmResiliencePlugin(
... max_retries=3,
... backoff_initial=1.0,
... fallback_models=["gemini-1.5-flash"],
... )
... ],
... )
"""

  def __init__(
      self,
      *,
      name: str = "llm_resilience_plugin",
      max_retries: int = 3,
      backoff_initial: float = 1.0,
      backoff_multiplier: float = 2.0,
      max_backoff: float = 10.0,
      jitter: float = 0.2,
      retry_on_exceptions: Optional[tuple[type[BaseException], ...]] = None,
      fallback_models: Optional[Iterable[str]] = None,
  ) -> None:
    """Initialize the LlmResiliencePlugin.

    Args:
      name: Plugin name identifier.
      max_retries: Maximum number of retry attempts on the same model.
      backoff_initial: Initial backoff delay in seconds.
      backoff_multiplier: Multiplier for exponential backoff.
      max_backoff: Maximum backoff delay in seconds.
      jitter: Jitter factor (0.0 to 1.0) adding randomness to each backoff.
      retry_on_exceptions: Optional tuple of exception types to retry on.
        If None, uses the built-in transient-error detection.
      fallback_models: Optional list of model names to try if the primary
        model keeps failing.
    """
    super().__init__(name)
    if max_retries < 0:
      raise ValueError("max_retries must be >= 0")
    if backoff_initial <= 0:
      raise ValueError("backoff_initial must be > 0")
    if backoff_multiplier < 1.0:
      raise ValueError("backoff_multiplier must be >= 1.0")
    if max_backoff <= 0:
      raise ValueError("max_backoff must be > 0")
    if jitter < 0:
      raise ValueError("jitter must be >= 0")

    self.max_retries = max_retries
    self.backoff_initial = backoff_initial
    self.backoff_multiplier = backoff_multiplier
    self.max_backoff = max_backoff
    self.jitter = jitter
    self.retry_on_exceptions = retry_on_exceptions
    self.fallback_models = list(fallback_models or [])

  async def on_model_error_callback(
      self,
      *,
      callback_context: CallbackContext,
      llm_request: LlmRequest,
      error: Exception,
  ) -> Optional[LlmResponse]:
    """Handle model errors with retry and fallback logic."""
    # Decide whether to handle this error: retry if it matches
    # retry_on_exceptions OR looks like a transient error.
    if self.retry_on_exceptions and isinstance(error, self.retry_on_exceptions):
      # The user explicitly asked to retry on this exception type.
      pass
    elif not _is_transient_error(error):
      # Neither an explicit exception type nor transient, so don't handle it.
      return None

    # Attempt retries on the same model.
    response = await self._retry_same_model(
        callback_context=callback_context, llm_request=llm_request
    )
    if response is not None:
      return response

    # Try the fallback models in order.
    if self.fallback_models:
      response = await self._try_fallbacks(
          callback_context=callback_context, llm_request=llm_request
      )
      if response is not None:
        return response

    # Let the original error propagate if all attempts failed.
    return None

  def _get_invocation_context(
      self, callback_context: CallbackContext | InvocationContext
  ) -> InvocationContext:
    """Extract the InvocationContext from callback_context.

    Accepts both Context (CallbackContext alias) and InvocationContext via
    duck typing.

    Args:
      callback_context: The callback context passed to the plugin.

    Returns:
      The underlying InvocationContext.

    Raises:
      TypeError: If callback_context is not a recognized type.
    """
    # If this looks like an InvocationContext (has agent and run_config),
    # use it directly.
    if hasattr(callback_context, "agent") and hasattr(
        callback_context, "run_config"
    ):
      return callback_context  # type: ignore[return-value]
    # Otherwise expect a Context-like object exposing the private
    # _invocation_context attribute.
    ic = getattr(callback_context, "_invocation_context", None)
    if ic is None:
      raise TypeError(
          "callback_context must be Context or InvocationContext-like"
      )
    return ic

  def _is_sse_streaming(self, invocation_context: InvocationContext) -> bool:
    """Check if SSE streaming mode is enabled.

    Args:
      invocation_context: The invocation context to check.

    Returns:
      True if SSE streaming is enabled, False otherwise.
    """
    streaming_mode = getattr(
        invocation_context.run_config, "streaming_mode", None
    )
    try:
      from google.adk.agents.run_config import StreamingMode

      return streaming_mode == StreamingMode.SSE
    except (ImportError, AttributeError):
      return False

  async def _retry_same_model(
      self,
      *,
      callback_context: CallbackContext | InvocationContext,
      llm_request: LlmRequest,
  ) -> Optional[LlmResponse]:
    invocation_context = self._get_invocation_context(callback_context)
    stream = self._is_sse_streaming(invocation_context)

    agent = invocation_context.agent
    llm = agent.canonical_model

    backoff = self.backoff_initial
    for attempt in range(1, self.max_retries + 1):
      sleep_time = min(self.max_backoff, backoff)
      # Add multiplicative (+/-) jitter.
      if self.jitter > 0:
        jitter_delta = sleep_time * random.uniform(-self.jitter, self.jitter)
        sleep_time = max(0.0, sleep_time + jitter_delta)
      if sleep_time > 0:
        await asyncio.sleep(sleep_time)

      try:
        final_response = await self._call_llm_and_get_final(
            llm=llm, llm_request=llm_request, stream=stream
        )
        logger.info(
            "LLM retry succeeded on attempt %s for agent %s",
            attempt,
            agent.name,
        )
        return final_response
      except Exception as e:  # continue to the next attempt
        logger.warning(
            "LLM retry attempt %s failed: %s", attempt, repr(e), exc_info=False
        )
        backoff *= self.backoff_multiplier

    return None

  async def _try_fallbacks(
      self,
      *,
      callback_context: CallbackContext | InvocationContext,
      llm_request: LlmRequest,
  ) -> Optional[LlmResponse]:
    invocation_context = self._get_invocation_context(callback_context)
    stream = self._is_sse_streaming(invocation_context)

    for model_name in self.fallback_models:
      try:
        fallback_llm = LLMRegistry.new_llm(model_name)
        # Update the request's model hint for provider bridges that honor it.
        llm_request.model = model_name
        final_response = await self._call_llm_and_get_final(
Review comment: Is LLMRegistry.new_llm(model_name) always expected to return an object that has generate_content_async? If model_name is invalid, this might raise an exception. The try...except at line 305 catches it and logs it, which is good.

However, new_llm might return an instance that isn't fully configured if it's just a raw registry lookup. Does it need run_config or other settings from the original llm? In _retry_same_model, you use agent.canonical_model; fallbacks might need similar initialization.
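A hedged sketch of that concern: resolving the fallback defensively before mutating llm_request. LLMRegistry.new_llm and generate_content_async come from this diff; the helper name and checks are illustrative.

def _new_fallback_llm(model_name: str):
  """Resolve a fallback model defensively; raises for unknown names."""
  llm = LLMRegistry.new_llm(model_name)  # may raise for unregistered names
  if not hasattr(llm, "generate_content_async"):
    raise TypeError(
        f"Registry returned an object without generate_content_async "
        f"for {model_name!r}"
    )
  return llm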

            llm=fallback_llm, llm_request=llm_request, stream=stream
        )
        logger.info("LLM fallback succeeded with model '%s'", model_name)
        return final_response
      except Exception as e:
        logger.warning(
            "LLM fallback model '%s' failed: %s",
            model_name,
            repr(e),
            exc_info=False,
        )
        continue
    return None

  async def _call_llm_and_get_final(
      self, *, llm, llm_request: LlmRequest, stream: bool
Review comment: Importing inspect inside a method is generally discouraged unless it's for very rare paths or to avoid circular imports. Since it's used in the main execution path of this method, it should probably be moved to the top of the file.
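As a sketch, the hoisted import would simply join the stdlib block at the top of the module:

import asyncio
import inspect  # hoisted from _call_llm_and_get_final
import logging
import random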

  ) -> LlmResponse:
    """Calls the given llm and returns the final non-partial LlmResponse."""
    import inspect

    final: Optional[LlmResponse] = None
    agen_or_coro = llm.generate_content_async(llm_request, stream=stream)

    # If the provider raised before the first yield, this may be a coroutine;
    # handle both shapes gracefully.
    if inspect.isasyncgen(agen_or_coro) or hasattr(agen_or_coro, "__aiter__"):
Review comment: If agen_or_coro is an async generator and it raises an exception immediately when calling llm.generate_content_async, this code handles it. However, if the error happens during iteration (inside async for resp in agen:), the try...finally will execute aclose(), but the exception will propagate out of _call_llm_and_get_final. In _retry_same_model (line 284), you catch Exception as e and log it, which seems correct for retries.

One thing: if the provider yields some partial results and then fails, final will hold the last successful partial response. In SSE mode the comment says the "last one is non-partial", but on a mid-stream failure final might be partial. Should a partial response be returned, or should the call fail? Currently, whatever was last yielded is returned.
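If only complete responses should count, a guard like this sketch could replace the loop body (assuming LlmResponse exposes a boolean partial field, as the constructor calls later in this file suggest):

async for resp in agen:
  # Drop partial chunks so a mid-stream failure cannot surface a
  # truncated answer as the final response.
  if not getattr(resp, "partial", False):
    final = resp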

      agen = agen_or_coro
      try:
        async for resp in agen:
          # Keep the latest response; in streaming mode the last one is
          # non-partial.
          final = resp
      finally:
        # Ensure an async generator is closed properly.
Review comment: The finally block here calls agen.aclose(). While aclose() is defined for async generators, agen_or_coro might be an AsyncIterable that is NOT an async generator (e.g., a custom class with __aiter__). It's safer to check whether the object is really an async generator before closing it:

if inspect.isasyncgen(agen):
    await agen.aclose()

The surrounding try...except already swallows the failure, but the type: ignore[attr-defined] hint points at the underlying issue: inspect.isasyncgen checks specifically for the async-generator type, and a plain AsyncIterable may not have aclose.
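Wrapped as a small helper, the reviewer's check might look like this sketch (the helper name is illustrative):

import inspect

async def _safe_aclose(aiter) -> None:
  """Close aiter only if it is a true async generator."""
  if inspect.isasyncgen(aiter):
    await aiter.aclose()
  # Plain AsyncIterables have no aclose() contract; nothing to clean up.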

        try:
          await agen.aclose()  # type: ignore[attr-defined]
        except Exception:
          pass
    else:
      # Await the coroutine; some LLMs may return a single response.
      result = await agen_or_coro
      if isinstance(result, LlmResponse):
        final = result
      elif isinstance(result, types.Content):
        final = LlmResponse(content=result, partial=False)
      else:
        # Unknown return type.
        raise TypeError("LLM generate_content_async returned unsupported type")

    if final is None:
      # Edge case: the provider yielded nothing. Return a minimal empty
      # response.
      return LlmResponse(partial=False)
    return final
13 changes: 13 additions & 0 deletions tests/unittests/plugins/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.