Skip to content

Commit 64f2232

Browse files
xuanyang15copybara-github
authored andcommitted
fix: keep existing header values while merging tracking headers for llm_request.config.http_options in Gemini.generate_content_async
PiperOrigin-RevId: 788135222
1 parent 646eb42 commit 64f2232

2 files changed

Lines changed: 81 additions & 204 deletions

File tree

src/google/adk/models/google_llm.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,15 @@ async def generate_content_async(
116116
)
117117
logger.debug(_build_request_log(llm_request))
118118

119-
# add tracking headers to custom headers given it will override the headers
120-
# set in the api client constructor
121-
if llm_request.config and llm_request.config.http_options:
122-
if not llm_request.config.http_options.headers:
123-
llm_request.config.http_options.headers = {}
124-
llm_request.config.http_options.headers.update(self._tracking_headers)
119+
# Always add tracking headers to custom headers given it will override
120+
# the headers set in the api client constructor to avoid tracking headers
121+
# being dropped if user provides custom headers or overrides the api client.
122+
if llm_request.config:
123+
if not llm_request.config.http_options:
124+
llm_request.config.http_options = types.HttpOptions()
125+
llm_request.config.http_options.headers = self._merge_tracking_headers(
126+
llm_request.config.http_options.headers
127+
)
125128

126129
if stream:
127130
responses = await self.api_client.aio.models.generate_content_stream(
@@ -333,6 +336,23 @@ async def _preprocess_request(self, llm_request: LlmRequest) -> None:
333336
llm_request.config.system_instruction = None
334337
await self._adapt_computer_use_tool(llm_request)
335338

339+
def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]:
340+
"""Merge tracking headers to the given headers."""
341+
headers = headers or {}
342+
for key, tracking_header_value in self._tracking_headers.items():
343+
custom_value = headers.get(key, None)
344+
if not custom_value:
345+
headers[key] = tracking_header_value
346+
continue
347+
348+
# Merge tracking headers with existing headers and avoid duplicates.
349+
value_parts = tracking_header_value.split(' ')
350+
for custom_value_part in custom_value.split(' '):
351+
if custom_value_part not in value_parts:
352+
value_parts.append(custom_value_part)
353+
headers[key] = ' '.join(value_parts)
354+
return headers
355+
336356

337357
def _build_function_declaration_log(
338358
func_decl: types.FunctionDeclaration,

0 commit comments

Comments
 (0)