Skip to content

Commit 0f8420d

Browse files
authored
fix: ark_llm apikey and enable response without caching (#550)
1 parent 7550919 commit 0f8420d

File tree: 2 files changed (+22 lines, −4 lines)

veadk/agent.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ class Agent(LlmAgent):
         example_store (Optional[BaseExampleProvider]): Example store for providing example Q/A.
         enable_shadowchar (bool): Whether to enable shadow character for the agent.
         enable_dynamic_load_skills (bool): Whether to enable dynamic loading of skills.
+        enable_responses_cache (bool): Whether Ark Responses API should reuse
+            `previous_response_id` and caching for multi-turn continuation.
     """

     model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
@@ -118,6 +120,7 @@ class Agent(LlmAgent):
     tracers: list[BaseTracer] = []

     enable_responses: bool = False
+    enable_responses_cache: bool = True

     context_cache_config: Optional[ContextCacheConfig] = None

@@ -194,6 +197,7 @@ def model_post_init(self, __context: Any) -> None:
                 model=f"{self.model_provider}/{self.model_name}",
                 api_key=self.model_api_key,
                 api_base=self.model_api_base,
+                enable_responses_cache=self.enable_responses_cache,
                 **self.model_extra_config,
             )
         else:

veadk/models/ark_llm.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -481,10 +481,17 @@ def _remove_caching(request_data: dict) -> None:
     request_data.pop("caching", None)


-def request_reorganization_by_ark(request_data: Dict) -> Dict:
+def request_reorganization_by_ark(
+    request_data: Dict, enable_responses_cache: bool = True
+) -> Dict:
     # 1. model provider
     request_data = get_model_without_provider(request_data)

+    if not enable_responses_cache:
+        request_data.pop("previous_response_id", None)
+        _remove_caching(request_data)
+        request_data.pop("store", None)
+
     # 2. filtered input
     request_data["input"] = filtered_inputs(
         request_data.get("input"),
@@ -672,7 +679,9 @@ async def aresponses(
     ) -> Union[ArkTypeResponse, AsyncStream[ResponseStreamEvent]]:
         # 1. Get request params
         api_base = kwargs.pop("api_base", DEFAULT_VIDEO_MODEL_API_BASE)
-        api_key = kwargs.pop("api_key", settings.model.api_key)
+        api_key = kwargs.pop("api_key", None)
+        if api_key is None:
+            api_key = settings.model.api_key

         # 2. Call openai responses
         client = AsyncArk(
client = AsyncArk(
@@ -689,6 +698,7 @@ class ArkLlm(Gemini):
     llm_client: ArkLlmClient = Field(default_factory=ArkLlmClient)
     _additional_args: Dict[str, Any] = None
     use_interactions_api: bool = True
+    enable_responses_cache: bool = True

     def __init__(self, **kwargs):
         # adk version check
def __init__(self, **kwargs):
694704
# adk version check
@@ -699,12 +709,14 @@ def __init__(self, **kwargs):
                 "`pip install -U 'google-adk>=1.21.0'`"
             )
         super().__init__(**kwargs)
+        self.enable_responses_cache = kwargs.get("enable_responses_cache", True)
         drop_params = kwargs.pop("drop_params", None)
         self._additional_args = dict(kwargs)
         self._additional_args.pop("llm_client", None)
         self._additional_args.pop("messages", None)
         self._additional_args.pop("tools", None)
         self._additional_args.pop("stream", None)
+        self._additional_args.pop("enable_responses_cache", None)
         if drop_params is not None:
             self._additional_args["drop_params"] = drop_params

@@ -733,7 +745,7 @@ async def generate_content_async(
         # ------------------------------------------------------ #
         # get previous_response_id
         previous_response_id = None
-        if llm_request.previous_interaction_id:
+        if self.enable_responses_cache and llm_request.previous_interaction_id:
            previous_response_id = llm_request.previous_interaction_id
        responses_args = {
            "model": self.model,
@@ -786,7 +798,9 @@ async def generate_content_async(
786798
async def generate_content_via_responses(
787799
self, responses_args: dict, stream: bool = False
788800
):
789-
responses_args = request_reorganization_by_ark(responses_args)
801+
responses_args = request_reorganization_by_ark(
802+
responses_args, enable_responses_cache=self.enable_responses_cache
803+
)
790804
if stream:
791805
responses_args["stream"] = True
792806
async for part in await self.llm_client.aresponses(**responses_args):

0 commit comments

Comments
 (0)