Skip to content

Commit d368e63

Browse files
xuanyang15copybara-github
authored andcommitted
fix: keep existing header values while merging tracking headers for llm_request.config.http_options in Gemini.generate_content_async
PiperOrigin-RevId: 788135222
1 parent f29ab5d commit d368e63

2 files changed

Lines changed: 82 additions & 204 deletions

File tree

src/google/adk/models/google_llm.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,15 @@ async def generate_content_async(
116116
)
117117
logger.debug(_build_request_log(llm_request))
118118

119-
# add tracking headers to custom headers given it will override the headers
120-
# set in the api client constructor
121-
if llm_request.config and llm_request.config.http_options:
122-
if not llm_request.config.http_options.headers:
123-
llm_request.config.http_options.headers = {}
124-
llm_request.config.http_options.headers.update(self._tracking_headers)
119+
# Always add tracking headers to custom headers given it will override
120+
# the headers set in the api client constructor to avoid tracking headers
121+
# being dropped if user provides custom headers or overrides the api client.
122+
if llm_request.config:
123+
if not llm_request.config.http_options:
124+
llm_request.config.http_options = types.HttpOptions()
125+
llm_request.config.http_options.headers = self._merge_tracking_headers(
126+
llm_request.config.http_options.headers
127+
)
125128

126129
if stream:
127130
responses = await self.api_client.aio.models.generate_content_stream(
@@ -333,6 +336,21 @@ async def _preprocess_request(self, llm_request: LlmRequest) -> None:
333336
llm_request.config.system_instruction = None
334337
await self._adapt_computer_use_tool(llm_request)
335338

339+
def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]:
340+
"""Merge tracking headers to the given headers."""
341+
headers = headers or {}
342+
for key, tracking_header_value in self._tracking_headers.items():
343+
current_value = headers.get(key, None)
344+
if not current_value:
345+
headers[key] = tracking_header_value
346+
else:
347+
# Merge tracking headers with existing headers and avoid duplicates.
348+
value_parts = set(current_value.split(' ')) | set(
349+
tracking_header_value.split(' ')
350+
)
351+
headers[key] = ' '.join(sorted(value_parts))
352+
return headers
353+
336354

337355
def _build_function_declaration_log(
338356
func_decl: types.FunctionDeclaration,

0 commit comments

Comments
 (0)