Skip to content

Commit e6fe30c

Browse files
committed
fix(files): wait for File API uploads to reach ACTIVE state before generation
When a file is uploaded via the File API and immediately used in a generate_content request, the file may still be in PROCESSING state, causing the API to reject the request. This was the root cause of #864 where video understanding failed for File API uploads but worked for YouTube URLs. Add _ensure_file_active() which polls the file state until ACTIVE, and _process_contents_for_generation() which applies this check to all File objects in the content before each generate_content call. Also add FileProcessingError for clear error messaging when a file fails to become ACTIVE. Fixes #864
1 parent 8ec977c commit e6fe30c

File tree

3 files changed

+323
-0
lines changed

3 files changed

+323
-0
lines changed

google/genai/errors.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,4 +302,20 @@ class UnknownApiResponseError(ValueError):
302302
"""Raised when the response from the API cannot be parsed as JSON."""
303303
pass
304304

305+
306+
class FileProcessingError(Exception):
307+
"""Error related to file processing in the API.
308+
309+
This exception is raised when a file fails to reach the ACTIVE state
310+
required for using it in content generation requests.
311+
"""
312+
313+
def __init__(
314+
self, message: str, response_json: Optional[dict[str, Any]] = None
315+
) -> None:
316+
self.message = message
317+
self.details = response_json or {}
318+
super().__init__(message)
319+
320+
305321
ExperimentalWarning = _common.ExperimentalWarning

google/genai/models.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import json
1919
import logging
20+
import time
2021
from typing import Any, AsyncIterator, Awaitable, Iterator, Optional, Union
2122
from urllib.parse import urlencode
2223

@@ -4617,6 +4618,86 @@ def _VoiceConfig_to_vertex(
46174618
return to_object
46184619

46194620

4621+
def _ensure_file_active(
4622+
api_client: BaseApiClient,
4623+
file_obj: types.File,
4624+
max_retries: int = 3,
4625+
retry_delay_seconds: int = 5,
4626+
) -> types.File:
4627+
"""Ensure a file object is in ACTIVE state before using it in content generation.
4628+
4629+
Args:
4630+
api_client: The API client to use for requests.
4631+
file_obj: The file object to check.
4632+
max_retries: Maximum number of retries for checking file state.
4633+
retry_delay_seconds: Delay between retries in seconds.
4634+
4635+
Returns:
4636+
The file object, refreshed if necessary.
4637+
4638+
Raises:
4639+
errors.FileProcessingError: If the file fails to become ACTIVE within the retry limit.
4640+
"""
4641+
if hasattr(file_obj, 'name') and file_obj.name and hasattr(file_obj, 'state'):
4642+
if file_obj.state == types.FileState.PROCESSING:
4643+
logger.info(
4644+
f'File {file_obj.name} is in PROCESSING state. Waiting for it to become ACTIVE.'
4645+
)
4646+
for attempt in range(max_retries):
4647+
time.sleep(retry_delay_seconds)
4648+
try:
4649+
file_id = file_obj.name.split('/')[-1]
4650+
response = api_client.request('GET', f'files/{file_id}', {}, None)
4651+
response_dict = {} if not response.body else json.loads(response.body)
4652+
refreshed_file = types.File._from_response(
4653+
response=response_dict, kwargs={}
4654+
)
4655+
logger.info(f'File {file_obj.name} state: {refreshed_file.state}')
4656+
if refreshed_file.state == types.FileState.ACTIVE:
4657+
return refreshed_file
4658+
if refreshed_file.state == types.FileState.FAILED:
4659+
error_msg = 'File processing failed'
4660+
if hasattr(refreshed_file, 'error') and refreshed_file.error:
4661+
error_msg = f'{error_msg}: {refreshed_file.error.message}'
4662+
raise errors.FileProcessingError(error_msg)
4663+
except errors.FileProcessingError:
4664+
raise
4665+
except Exception as e:
4666+
logger.warning(f'Error refreshing file state: {e}')
4667+
logger.warning(
4668+
f'File {file_obj.name} did not become ACTIVE after {max_retries} attempts. '
4669+
'This may cause the content generation to fail.'
4670+
)
4671+
return file_obj
4672+
4673+
4674+
def _process_contents_for_generation(
4675+
api_client: BaseApiClient,
4676+
contents: Union[types.ContentListUnion, types.ContentListUnionDict],
4677+
) -> list[types.Content]:
4678+
"""Process the contents, ensuring all File objects are in the ACTIVE state.
4679+
4680+
Args:
4681+
api_client: The API client to use for requests.
4682+
contents: The contents to process.
4683+
4684+
Returns:
4685+
The processed contents.
4686+
"""
4687+
processed_contents = t.t_contents(contents)
4688+
4689+
def process_file_in_item(item: types.Content) -> types.Content:
4690+
if isinstance(item, types.Content):
4691+
if hasattr(item, 'parts') and item.parts:
4692+
for part in item.parts:
4693+
if hasattr(part, 'file_data') and part.file_data:
4694+
if isinstance(part.file_data, types.File):
4695+
part.file_data = _ensure_file_active(api_client, part.file_data)
4696+
return item
4697+
4698+
return [process_file_in_item(item) for item in processed_contents]
4699+
4700+
46204701
class Models(_api_module.BaseModule):
46214702

46224703
def _generate_content(
@@ -4626,6 +4707,7 @@ def _generate_content(
46264707
contents: Union[types.ContentListUnion, types.ContentListUnionDict],
46274708
config: Optional[types.GenerateContentConfigOrDict] = None,
46284709
) -> types.GenerateContentResponse:
4710+
contents = _process_contents_for_generation(self._api_client, contents)
46294711
parameter_model = types._GenerateContentParameters(
46304712
model=model,
46314713
contents=contents,
@@ -4707,6 +4789,7 @@ def _generate_content_stream(
47074789
contents: Union[types.ContentListUnion, types.ContentListUnionDict],
47084790
config: Optional[types.GenerateContentConfigOrDict] = None,
47094791
) -> Iterator[types.GenerateContentResponse]:
4792+
contents = _process_contents_for_generation(self._api_client, contents)
47104793
parameter_model = types._GenerateContentParameters(
47114794
model=model,
47124795
contents=contents,
@@ -6585,6 +6668,7 @@ async def _generate_content(
65856668
contents: Union[types.ContentListUnion, types.ContentListUnionDict],
65866669
config: Optional[types.GenerateContentConfigOrDict] = None,
65876670
) -> types.GenerateContentResponse:
6671+
contents = _process_contents_for_generation(self._api_client, contents)
65886672
parameter_model = types._GenerateContentParameters(
65896673
model=model,
65906674
contents=contents,
@@ -6666,6 +6750,7 @@ async def _generate_content_stream(
66666750
contents: Union[types.ContentListUnion, types.ContentListUnionDict],
66676751
config: Optional[types.GenerateContentConfigOrDict] = None,
66686752
) -> Awaitable[AsyncIterator[types.GenerateContentResponse]]:
6753+
contents = _process_contents_for_generation(self._api_client, contents)
66696754
parameter_model = types._GenerateContentParameters(
66706755
model=model,
66716756
contents=contents,
Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
#!/usr/bin/env python
2+
# Copyright 2025 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
"""Tests for file state handling in content generation."""
18+
19+
import unittest
20+
from unittest import mock
21+
import time
22+
23+
import pytest
24+
25+
from google.genai import types
26+
from google.genai import errors
27+
from google.genai.models import _ensure_file_active, _process_contents_for_generation
28+
from google.genai.types import FileState
29+
30+
31+
class TestFileStateHandling(unittest.TestCase):
32+
"""Test file state handling functionality."""
33+
34+
def setUp(self):
35+
"""Set up test fixtures."""
36+
self.api_client = mock.MagicMock()
37+
self.file_obj = types.File(
38+
name="files/test123",
39+
display_name="Test File",
40+
mime_type="video/mp4",
41+
uri="https://example.com/files/test123",
42+
state=types.FileState.PROCESSING,
43+
)
44+
45+
@mock.patch("google.genai.files._File_from_mldev")
46+
def test_ensure_file_active_with_processing_file(self, mock_file_from_mldev):
47+
"""Test that _ensure_file_active properly handles a file in PROCESSING state."""
48+
49+
response_mock = mock.MagicMock()
50+
response_mock.json = {
51+
"file": {
52+
"name": "files/test123",
53+
"displayName": "Test File",
54+
"mimeType": "video/mp4",
55+
"state": "ACTIVE",
56+
}
57+
}
58+
self.api_client.call_api.return_value = response_mock
59+
60+
# Set up the mock to return a dict that will create an ACTIVE file
61+
mock_file_from_mldev.return_value = {
62+
"name": "files/test123",
63+
"display_name": "Test File",
64+
"mime_type": "video/mp4",
65+
"state": types.FileState.ACTIVE,
66+
}
67+
68+
result = _ensure_file_active(
69+
self.api_client, self.file_obj, max_retries=1, retry_delay_seconds=0.1
70+
)
71+
72+
self.api_client.call_api.assert_called_once_with(
73+
method="GET",
74+
url="files/test123",
75+
api_client_type="mldev",
76+
)
77+
78+
# Verify the result has ACTIVE state
79+
self.assertEqual(result.state, types.FileState.ACTIVE)
80+
81+
def test_ensure_file_active_with_failed_file(self):
82+
"""Test that _ensure_file_active properly handles a file in FAILED state."""
83+
84+
response_mock = mock.MagicMock()
85+
response_mock.json = {
86+
"file": {
87+
"name": "files/test123",
88+
"displayName": "Test File",
89+
"mimeType": "video/mp4",
90+
"state": "FAILED",
91+
"error": {"message": "File processing failed"},
92+
}
93+
}
94+
95+
# Set up a side effect for call_api that returns the response with FAILED state
96+
def mock_call_api(**kwargs):
97+
# Only return the mock for the expected file API call
98+
if kwargs.get("method") == "GET" and "files/" in kwargs.get("url", ""):
99+
return response_mock
100+
return mock.DEFAULT
101+
102+
self.api_client.call_api.side_effect = mock_call_api
103+
104+
105+
with pytest.raises(errors.FileProcessingError) as excinfo:
106+
_ensure_file_active(
107+
self.api_client, self.file_obj, max_retries=1, retry_delay_seconds=0.1
108+
)
109+
110+
assert "File processing failed" in str(excinfo.value)
111+
112+
def test_ensure_file_active_with_retries_exhausted(self):
113+
"""Test that _ensure_file_active handles a file that remains in PROCESSING state after all retries."""
114+
# Mock the response for file state check - file stays in PROCESSING
115+
response_mock = mock.MagicMock()
116+
response_mock.json = {
117+
"file": {
118+
"name": "files/test123",
119+
"displayName": "Test File",
120+
"mimeType": "video/mp4",
121+
"state": "PROCESSING",
122+
}
123+
}
124+
self.api_client.call_api.return_value = response_mock
125+
126+
# Call the function
127+
result = _ensure_file_active(
128+
self.api_client, self.file_obj, max_retries=2, retry_delay_seconds=0.1
129+
)
130+
131+
# Verify the file state was checked multiple times
132+
self.assertEqual(self.api_client.call_api.call_count, 2)
133+
134+
# Verify the original file is returned
135+
self.assertEqual(result, self.file_obj)
136+
self.assertEqual(result.state, types.FileState.PROCESSING)
137+
138+
def test_ensure_file_active_with_already_active_file(self):
139+
"""Test that _ensure_file_active returns immediately for an already ACTIVE file."""
140+
active_file = types.File(
141+
name="files/active123",
142+
display_name="Active File",
143+
mime_type="video/mp4",
144+
state=types.FileState.ACTIVE,
145+
)
146+
147+
result = _ensure_file_active(
148+
self.api_client, active_file, max_retries=1, retry_delay_seconds=0.1
149+
)
150+
151+
# Verify no API calls were made
152+
self.api_client.call_api.assert_not_called()
153+
154+
# Verify the original file is returned unchanged
155+
self.assertEqual(result, active_file)
156+
self.assertEqual(result.state, types.FileState.ACTIVE)
157+
158+
159+
class TestProcessContentsFunction(unittest.TestCase):
160+
"""Test the _process_contents_for_generation function."""
161+
162+
def setUp(self):
163+
"""Set up test fixtures."""
164+
self.api_client = mock.MagicMock()
165+
self.processing_file = types.File(
166+
name="files/processing123",
167+
display_name="Processing File",
168+
mime_type="video/mp4",
169+
uri="https://example.com/files/processing123",
170+
state=types.FileState.PROCESSING
171+
)
172+
self.active_file = types.File(
173+
name="files/active123",
174+
display_name="Active File",
175+
mime_type="video/mp4",
176+
uri="https://example.com/files/active123",
177+
state=types.FileState.ACTIVE
178+
)
179+
180+
def test_process_contents_with_files(self):
181+
"""Test that _process_contents_for_generation can handle various file scenarios."""
182+
# Scenarios:
183+
# - File directly in content list
184+
# - File in content parts
185+
# - Multiple files in different parts
186+
187+
# Test data
188+
file_in_list = [self.processing_file, "Process this file"]
189+
190+
file_in_parts = types.Content(
191+
role="user",
192+
parts=[types.Part(text="Here's a video:"), self.processing_file]
193+
)
194+
195+
multiple_files = [
196+
types.Content(
197+
role="user",
198+
parts=[types.Part(text="First video:"), self.processing_file]
199+
),
200+
types.Content(
201+
role="user",
202+
parts=[types.Part(text="Second video:"), self.active_file]
203+
)
204+
]
205+
206+
# Mock _ensure_file_active to return the file unchanged
207+
# This allows us to test the function without changing file states
208+
with mock.patch("google.genai.models._ensure_file_active",
209+
side_effect=lambda client, file: file):
210+
211+
# Test all three cases
212+
for test_content in [file_in_list, file_in_parts, multiple_files]:
213+
with mock.patch("google.genai.models.t.t_contents",
214+
return_value=test_content if isinstance(test_content, list) else [test_content]):
215+
# Just verify it runs without errors
216+
result = _process_contents_for_generation(self.api_client, test_content)
217+
# Basic assertion that we got something back
218+
self.assertTrue(result)
219+
220+
221+
if __name__ == "__main__":
222+
unittest.main()

0 commit comments

Comments
 (0)