diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index c7b10aa61a..50c820c143 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -122,9 +122,9 @@ async def generate_content_async( if llm_request.config: if not llm_request.config.http_options: llm_request.config.http_options = types.HttpOptions() - if not llm_request.config.http_options.headers: - llm_request.config.http_options.headers = {} - llm_request.config.http_options.headers.update(self._tracking_headers) + llm_request.config.http_options.headers = self._merge_tracking_headers( + llm_request.config.http_options.headers + ) if stream: responses = await self.api_client.aio.models.generate_content_stream( @@ -336,6 +336,23 @@ async def _preprocess_request(self, llm_request: LlmRequest) -> None: llm_request.config.system_instruction = None await self._adapt_computer_use_tool(llm_request) + def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]: + """Merge tracking headers to the given headers.""" + headers = headers or {} + for key, tracking_header_value in self._tracking_headers.items(): + custom_value = headers.get(key, None) + if not custom_value: + headers[key] = tracking_header_value + continue + + # Merge tracking headers with existing headers and avoid duplicates. + value_parts = tracking_header_value.split(' ') + for custom_value_part in custom_value.split(' '): + if custom_value_part not in value_parts: + value_parts.append(custom_value_part) + headers[key] = ' '.join(value_parts) + return headers + def _build_function_declaration_log( func_decl: types.FunctionDeclaration, diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index 4e99c5a567..03d18ec6d7 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -403,7 +403,7 @@ async def mock_coro(): for key, value in config_arg.http_options.headers.items(): if key in gemini_llm._tracking_headers: - assert value == gemini_llm._tracking_headers[key] + assert value == gemini_llm._tracking_headers[key] + " custom" else: assert value == custom_headers[key]