|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import logging |
15 | 16 | import os |
16 | 17 | import sys |
17 | 18 | from typing import Optional |
@@ -2263,3 +2264,96 @@ async def __aexit__(self, *args): |
2263 | 2264 | # Verify the final speech_config is still None |
2264 | 2265 | assert config_arg.speech_config is None |
2265 | 2266 | assert isinstance(connection, GeminiLlmConnection) |
| 2267 | + |
| 2268 | + |
| 2269 | +@pytest.mark.asyncio |
| 2270 | +@pytest.mark.parametrize( |
| 2271 | + "log_level,should_call", |
| 2272 | + [ |
| 2273 | + (logging.WARNING, False), |
| 2274 | + (logging.INFO, False), |
| 2275 | + (logging.DEBUG, True), |
| 2276 | + ], |
| 2277 | +) |
| 2278 | +async def test_generate_content_async_skips_response_log_build_above_debug( |
| 2279 | + gemini_llm, |
| 2280 | + llm_request, |
| 2281 | + generate_content_response, |
| 2282 | + log_level, |
| 2283 | + should_call, |
| 2284 | +): |
| 2285 | + gemini_logger = logging.getLogger("google_adk.google.adk.models.google_llm") |
| 2286 | + original_level = gemini_logger.level |
| 2287 | + gemini_logger.setLevel(log_level) |
| 2288 | + try: |
| 2289 | + with mock.patch( |
| 2290 | + "google.adk.models.google_llm._build_response_log", |
| 2291 | + return_value="log", |
| 2292 | + ) as mock_build: |
| 2293 | + with mock.patch.object(gemini_llm, "api_client") as mock_client: |
| 2294 | + |
| 2295 | + async def mock_coro(): |
| 2296 | + return generate_content_response |
| 2297 | + |
| 2298 | + mock_client.aio.models.generate_content.return_value = mock_coro() |
| 2299 | + |
| 2300 | + async for _ in gemini_llm.generate_content_async( |
| 2301 | + llm_request, stream=False |
| 2302 | + ): |
| 2303 | + pass |
| 2304 | + |
| 2305 | + assert mock_build.called is should_call |
| 2306 | + finally: |
| 2307 | + gemini_logger.setLevel(original_level) |
| 2308 | + |
| 2309 | + |
| 2310 | +@pytest.mark.asyncio |
| 2311 | +@pytest.mark.parametrize( |
| 2312 | + "log_level,should_call", |
| 2313 | + [ |
| 2314 | + (logging.WARNING, False), |
| 2315 | + (logging.INFO, False), |
| 2316 | + (logging.DEBUG, True), |
| 2317 | + ], |
| 2318 | +) |
| 2319 | +async def test_generate_content_async_stream_skips_response_log_build_above_debug( |
| 2320 | + gemini_llm, llm_request, log_level, should_call |
| 2321 | +): |
| 2322 | + mock_responses = [ |
| 2323 | + types.GenerateContentResponse( |
| 2324 | + candidates=[ |
| 2325 | + types.Candidate( |
| 2326 | + content=Content( |
| 2327 | + role="model", parts=[Part.from_text(text="hi")] |
| 2328 | + ), |
| 2329 | + finish_reason=types.FinishReason.STOP, |
| 2330 | + ) |
| 2331 | + ] |
| 2332 | + ), |
| 2333 | + ] |
| 2334 | + |
| 2335 | + gemini_logger = logging.getLogger("google_adk.google.adk.models.google_llm") |
| 2336 | + original_level = gemini_logger.level |
| 2337 | + gemini_logger.setLevel(log_level) |
| 2338 | + try: |
| 2339 | + with mock.patch( |
| 2340 | + "google.adk.models.google_llm._build_response_log", |
| 2341 | + return_value="log", |
| 2342 | + ) as mock_build: |
| 2343 | + with mock.patch.object(gemini_llm, "api_client") as mock_client: |
| 2344 | + |
| 2345 | + async def mock_coro(): |
| 2346 | + return MockAsyncIterator(mock_responses) |
| 2347 | + |
| 2348 | + mock_client.aio.models.generate_content_stream.return_value = ( |
| 2349 | + mock_coro() |
| 2350 | + ) |
| 2351 | + |
| 2352 | + async for _ in gemini_llm.generate_content_async( |
| 2353 | + llm_request, stream=True |
| 2354 | + ): |
| 2355 | + pass |
| 2356 | + |
| 2357 | + assert mock_build.called is should_call |
| 2358 | + finally: |
| 2359 | + gemini_logger.setLevel(original_level) |
0 commit comments