Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions veadk/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ class Agent(LlmAgent):
example_store (Optional[BaseExampleProvider]): Example store for providing example Q/A.
enable_shadowchar (bool): Whether to enable shadow character for the agent.
enable_dynamic_load_skills (bool): Whether to enable dynamic loading of skills.
enable_responses_cache (bool): Whether the Ark Responses API should reuse
`previous_response_id` and response caching for multi-turn continuation.
"""

model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
Expand Down Expand Up @@ -118,6 +120,7 @@ class Agent(LlmAgent):
tracers: list[BaseTracer] = []

enable_responses: bool = False
enable_responses_cache: bool = True

context_cache_config: Optional[ContextCacheConfig] = None

Expand Down Expand Up @@ -194,6 +197,7 @@ def model_post_init(self, __context: Any) -> None:
model=f"{self.model_provider}/{self.model_name}",
api_key=self.model_api_key,
api_base=self.model_api_base,
enable_responses_cache=self.enable_responses_cache,
**self.model_extra_config,
)
else:
Expand Down
22 changes: 18 additions & 4 deletions veadk/models/ark_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,10 +481,17 @@ def _remove_caching(request_data: dict) -> None:
request_data.pop("caching", None)


def request_reorganization_by_ark(request_data: Dict) -> Dict:
def request_reorganization_by_ark(
request_data: Dict, enable_responses_cache: bool = True
) -> Dict:
# 1. model provider
request_data = get_model_without_provider(request_data)

if not enable_responses_cache:
request_data.pop("previous_response_id", None)
_remove_caching(request_data)
request_data.pop("store", None)

# 2. filtered input
request_data["input"] = filtered_inputs(
request_data.get("input"),
Expand Down Expand Up @@ -672,7 +679,9 @@ async def aresponses(
) -> Union[ArkTypeResponse, AsyncStream[ResponseStreamEvent]]:
# 1. Get request params
api_base = kwargs.pop("api_base", DEFAULT_VIDEO_MODEL_API_BASE)
api_key = kwargs.pop("api_key", settings.model.api_key)
api_key = kwargs.pop("api_key", None)
if api_key is None:
api_key = settings.model.api_key

# 2. Call openai responses
client = AsyncArk(
Expand All @@ -689,6 +698,7 @@ class ArkLlm(Gemini):
llm_client: ArkLlmClient = Field(default_factory=ArkLlmClient)
_additional_args: Dict[str, Any] = None
use_interactions_api: bool = True
enable_responses_cache: bool = True

def __init__(self, **kwargs):
# adk version check
Expand All @@ -699,12 +709,14 @@ def __init__(self, **kwargs):
"`pip install -U 'google-adk>=1.21.0'`"
)
super().__init__(**kwargs)
self.enable_responses_cache = kwargs.get("enable_responses_cache", True)
drop_params = kwargs.pop("drop_params", None)
self._additional_args = dict(kwargs)
self._additional_args.pop("llm_client", None)
self._additional_args.pop("messages", None)
self._additional_args.pop("tools", None)
self._additional_args.pop("stream", None)
self._additional_args.pop("enable_responses_cache", None)
if drop_params is not None:
self._additional_args["drop_params"] = drop_params

Expand Down Expand Up @@ -733,7 +745,7 @@ async def generate_content_async(
# ------------------------------------------------------ #
# get previous_response_id
previous_response_id = None
if llm_request.previous_interaction_id:
if self.enable_responses_cache and llm_request.previous_interaction_id:
previous_response_id = llm_request.previous_interaction_id
responses_args = {
"model": self.model,
Expand Down Expand Up @@ -786,7 +798,9 @@ async def generate_content_async(
async def generate_content_via_responses(
self, responses_args: dict, stream: bool = False
):
responses_args = request_reorganization_by_ark(responses_args)
responses_args = request_reorganization_by_ark(
responses_args, enable_responses_cache=self.enable_responses_cache
)
if stream:
responses_args["stream"] = True
async for part in await self.llm_client.aresponses(**responses_args):
Expand Down
Loading