Skip to content

Commit 3de2c1e

Browse files
vertex-sdk-bot and copybara-github
authored and committed
fix: Increase default timeout to 600 seconds for ask_contexts and async_retrieve_contexts in VertexRagServiceClient.
PiperOrigin-RevId: 893099172
1 parent ff5e246 commit 3de2c1e

File tree

4 files changed

+146
-77
lines changed

4 files changed

+146
-77
lines changed

tests/unit/vertex_rag/test_rag_retrieval.py

Lines changed: 42 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -85,7 +85,7 @@ def retrieve_contexts_eq(response, expected_response):
8585

8686

8787
@pytest.mark.usefixtures("google_auth_mock")
88-
class TestRagRetrieval: # pylint: disable=missing-class-docstring
88+
class TestRagRetrieval: # pylint: disable=missing-class-docstring, bad-indentation, unused-variable, unused-argument, redefined-outer-name
8989

9090
def setup_method(self):
9191
importlib.reload(aiplatform.initializer)
@@ -113,6 +113,18 @@ def test_ask_contexts_rag_resources_success(self):
113113
)
114114
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
115115

116+
@pytest.mark.usefixtures("ask_contexts_mock")
117+
def test_ask_contexts_with_timeout(self, ask_contexts_mock):
118+
rag.ask_contexts(
119+
rag_resources=[tc.TEST_RAG_RESOURCE],
120+
text=tc.TEST_QUERY_TEXT,
121+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
122+
timeout=300,
123+
)
124+
ask_contexts_mock.assert_called_once()
125+
_, kwargs = ask_contexts_mock.call_args
126+
assert kwargs["timeout"] == 300
127+
116128
@pytest.mark.usefixtures("ask_contexts_mock")
117129
def test_ask_contexts_multiple_rag_resources_success(self):
118130
response = rag.ask_contexts(
@@ -123,8 +135,9 @@ def test_ask_contexts_multiple_rag_resources_success(self):
123135
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
124136

125137
@pytest.mark.asyncio
126-
@pytest.mark.usefixtures("async_retrieve_contexts_mock")
127-
async def test_async_retrieve_contexts_rag_resources_success(self):
138+
async def test_async_retrieve_contexts_rag_resources_success(
139+
self, async_retrieve_contexts_mock
140+
):
128141
response = await rag.async_retrieve_contexts(
129142
rag_resources=[tc.TEST_RAG_RESOURCE],
130143
text=tc.TEST_QUERY_TEXT,
@@ -133,8 +146,23 @@ async def test_async_retrieve_contexts_rag_resources_success(self):
133146
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
134147

135148
@pytest.mark.asyncio
136-
@pytest.mark.usefixtures("async_retrieve_contexts_mock")
137-
async def test_async_retrieve_contexts_multiple_rag_resources_success(self):
149+
async def test_async_retrieve_contexts_with_timeout(
150+
self, async_retrieve_contexts_mock
151+
):
152+
await rag.async_retrieve_contexts(
153+
rag_resources=[tc.TEST_RAG_RESOURCE],
154+
text=tc.TEST_QUERY_TEXT,
155+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
156+
timeout=300,
157+
)
158+
async_retrieve_contexts_mock.assert_called_once()
159+
_, kwargs = async_retrieve_contexts_mock.call_args
160+
assert kwargs["timeout"] == 300
161+
162+
@pytest.mark.asyncio
163+
async def test_async_retrieve_contexts_multiple_rag_resources_success(
164+
self, async_retrieve_contexts_mock
165+
):
138166
response = await rag.async_retrieve_contexts(
139167
rag_resources=[tc.TEST_RAG_RESOURCE, tc.TEST_RAG_RESOURCE],
140168
text=tc.TEST_QUERY_TEXT,
@@ -177,7 +205,7 @@ def test_retrieval_query_failure(self):
177205
text=tc.TEST_QUERY_TEXT,
178206
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
179207
)
180-
e.match("Failed in retrieving contexts due to")
208+
e.match("Failed in retrieving contexts due to")
181209

182210
def test_retrieval_query_invalid_name(self):
183211
with pytest.raises(ValueError) as e:
@@ -186,7 +214,7 @@ def test_retrieval_query_invalid_name(self):
186214
text=tc.TEST_QUERY_TEXT,
187215
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
188216
)
189-
e.match("Invalid RagCorpus name")
217+
e.match("Invalid RagCorpus name")
190218

191219
def test_retrieval_query_multiple_rag_resources(self):
192220
with pytest.raises(ValueError) as e:
@@ -195,7 +223,7 @@ def test_retrieval_query_multiple_rag_resources(self):
195223
text=tc.TEST_QUERY_TEXT,
196224
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
197225
)
198-
e.match("Currently only support 1 RagResource")
226+
e.match("Currently only support 1 RagResource")
199227

200228
def test_retrieval_query_similarity_multiple_rag_resources(self):
201229
with pytest.raises(ValueError) as e:
@@ -204,7 +232,7 @@ def test_retrieval_query_similarity_multiple_rag_resources(self):
204232
text=tc.TEST_QUERY_TEXT,
205233
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
206234
)
207-
e.match("Currently only support 1 RagResource")
235+
e.match("Currently only support 1 RagResource")
208236

209237
def test_retrieval_query_invalid_config_filter(self):
210238
with pytest.raises(ValueError) as e:
@@ -213,8 +241,8 @@ def test_retrieval_query_invalid_config_filter(self):
213241
text=tc.TEST_QUERY_TEXT,
214242
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_ERROR_CONFIG,
215243
)
216-
e.match(
217-
"Only one of vector_distance_threshold or"
218-
" vector_similarity_threshold can be specified at a time"
219-
" in rag_retrieval_config."
220-
)
244+
e.match(
245+
"Only one of vector_distance_threshold or"
246+
" vector_similarity_threshold can be specified at a time"
247+
" in rag_retrieval_config."
248+
)

tests/unit/vertex_rag/test_rag_retrieval_preview.py

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def retrieve_contexts_eq(response, expected_response):
8787

8888

8989
@pytest.mark.usefixtures("google_auth_mock")
90-
class TestRagRetrieval: # pylint: disable=missing-class-docstring
90+
class TestRagRetrieval: # pylint: disable=missing-class-docstring, bad-indentation, unused-variable, unused-argument, redefined-outer-name
9191

9292
def setup_method(self):
9393
importlib.reload(aiplatform.initializer)
@@ -118,6 +118,18 @@ def test_ask_contexts_rag_resources_success(self):
118118
)
119119
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
120120

121+
@pytest.mark.usefixtures("ask_contexts_mock")
122+
def test_ask_contexts_with_timeout(self, ask_contexts_mock):
123+
rag.ask_contexts(
124+
rag_resources=[tc.TEST_RAG_RESOURCE],
125+
text=tc.TEST_QUERY_TEXT,
126+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG_ALPHA,
127+
timeout=300,
128+
)
129+
ask_contexts_mock.assert_called_once()
130+
args, kwargs = ask_contexts_mock.call_args
131+
assert kwargs["timeout"] == 300
132+
121133
@pytest.mark.usefixtures("ask_contexts_mock")
122134
def test_ask_contexts_multiple_rag_resources_success(self):
123135
response = rag.ask_contexts(
@@ -138,8 +150,9 @@ def test_ask_contexts_multiple_rag_corpora_success(self):
138150
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
139151

140152
@pytest.mark.asyncio
141-
@pytest.mark.usefixtures("async_retrieve_contexts_mock")
142-
async def test_async_retrieve_contexts_rag_resources_success(self):
153+
async def test_async_retrieve_contexts_rag_resources_success(
154+
self, async_retrieve_contexts_mock
155+
):
143156
response = await rag.async_retrieve_contexts(
144157
rag_resources=[tc.TEST_RAG_RESOURCE],
145158
text=tc.TEST_QUERY_TEXT,
@@ -148,8 +161,23 @@ async def test_async_retrieve_contexts_rag_resources_success(self):
148161
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
149162

150163
@pytest.mark.asyncio
151-
@pytest.mark.usefixtures("async_retrieve_contexts_mock")
152-
async def test_async_retrieve_contexts_multiple_rag_resources_success(self):
164+
async def test_async_retrieve_contexts_with_timeout(
165+
self, async_retrieve_contexts_mock
166+
):
167+
await rag.async_retrieve_contexts(
168+
rag_resources=[tc.TEST_RAG_RESOURCE],
169+
text=tc.TEST_QUERY_TEXT,
170+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG_ALPHA,
171+
timeout=300,
172+
)
173+
async_retrieve_contexts_mock.assert_called_once()
174+
args, kwargs = async_retrieve_contexts_mock.call_args
175+
assert kwargs["timeout"] == 300
176+
177+
@pytest.mark.asyncio
178+
async def test_async_retrieve_contexts_multiple_rag_resources_success(
179+
self, async_retrieve_contexts_mock
180+
):
153181
response = await rag.async_retrieve_contexts(
154182
rag_resources=[tc.TEST_RAG_RESOURCE, tc.TEST_RAG_RESOURCE],
155183
text=tc.TEST_QUERY_TEXT,
@@ -158,8 +186,9 @@ async def test_async_retrieve_contexts_multiple_rag_resources_success(self):
158186
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
159187

160188
@pytest.mark.asyncio
161-
@pytest.mark.usefixtures("async_retrieve_contexts_mock")
162-
async def test_async_retrieve_contexts_multiple_rag_corpora_success(self):
189+
async def test_async_retrieve_contexts_multiple_rag_corpora_success(
190+
self, async_retrieve_contexts_mock
191+
):
163192
with pytest.warns(DeprecationWarning):
164193
response = await rag.async_retrieve_contexts(
165194
rag_corpora=[tc.TEST_RAG_CORPUS_ID, tc.TEST_RAG_CORPUS_ID],
@@ -262,7 +291,7 @@ def test_retrieval_query_failure(self):
262291
similarity_top_k=2,
263292
vector_distance_threshold=0.5,
264293
)
265-
e.match("Failed in retrieving contexts due to")
294+
e.match("Failed in retrieving contexts due to")
266295

267296
@pytest.mark.usefixtures("rag_client_mock_exception")
268297
def test_retrieval_query_config_failure(self):
@@ -272,7 +301,7 @@ def test_retrieval_query_config_failure(self):
272301
text=tc.TEST_QUERY_TEXT,
273302
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
274303
)
275-
e.match("Failed in retrieving contexts due to")
304+
e.match("Failed in retrieving contexts due to")
276305

277306
def test_retrieval_query_invalid_name(self):
278307
with pytest.raises(ValueError) as e:
@@ -282,7 +311,7 @@ def test_retrieval_query_invalid_name(self):
282311
similarity_top_k=2,
283312
vector_distance_threshold=0.5,
284313
)
285-
e.match("Invalid RagCorpus name")
314+
e.match("Invalid RagCorpus name")
286315

287316
def test_retrieval_query_invalid_name_config(self):
288317
with pytest.raises(ValueError) as e:
@@ -291,7 +320,7 @@ def test_retrieval_query_invalid_name_config(self):
291320
text=tc.TEST_QUERY_TEXT,
292321
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
293322
)
294-
e.match("Invalid RagCorpus name")
323+
e.match("Invalid RagCorpus name")
295324

296325
def test_retrieval_query_multiple_rag_corpora(self):
297326
with pytest.raises(ValueError) as e:
@@ -304,7 +333,7 @@ def test_retrieval_query_multiple_rag_corpora(self):
304333
similarity_top_k=2,
305334
vector_distance_threshold=0.5,
306335
)
307-
e.match("Currently only support 1 RagCorpus")
336+
e.match("Currently only support 1 RagCorpus")
308337

309338
def test_retrieval_query_multiple_rag_corpora_config(self):
310339
with pytest.raises(ValueError) as e:
@@ -316,7 +345,7 @@ def test_retrieval_query_multiple_rag_corpora_config(self):
316345
text=tc.TEST_QUERY_TEXT,
317346
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
318347
)
319-
e.match("Currently only support 1 RagCorpus")
348+
e.match("Currently only support 1 RagCorpus")
320349

321350
def test_retrieval_query_multiple_rag_resources(self):
322351
with pytest.raises(ValueError) as e:
@@ -329,7 +358,7 @@ def test_retrieval_query_multiple_rag_resources(self):
329358
similarity_top_k=2,
330359
vector_distance_threshold=0.5,
331360
)
332-
e.match("Currently only support 1 RagResource")
361+
e.match("Currently only support 1 RagResource")
333362

334363
def test_retrieval_query_multiple_rag_resources_config(self):
335364
with pytest.raises(ValueError) as e:
@@ -341,7 +370,7 @@ def test_retrieval_query_multiple_rag_resources_config(self):
341370
text=tc.TEST_QUERY_TEXT,
342371
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
343372
)
344-
e.match("Currently only support 1 RagResource")
373+
e.match("Currently only support 1 RagResource")
345374

346375
def test_retrieval_query_multiple_rag_resources_similarity_config(self):
347376
with pytest.raises(ValueError) as e:
@@ -353,7 +382,7 @@ def test_retrieval_query_multiple_rag_resources_similarity_config(self):
353382
text=tc.TEST_QUERY_TEXT,
354383
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
355384
)
356-
e.match("Currently only support 1 RagResource")
385+
e.match("Currently only support 1 RagResource")
357386

358387
def test_retrieval_query_invalid_config_filter(self):
359388
with pytest.raises(ValueError) as e:
@@ -362,8 +391,8 @@ def test_retrieval_query_invalid_config_filter(self):
362391
text=tc.TEST_QUERY_TEXT,
363392
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_ERROR_CONFIG,
364393
)
365-
e.match(
366-
"Only one of vector_distance_threshold or"
367-
" vector_similarity_threshold can be specified at a time"
368-
" in rag_retrieval_config."
369-
)
394+
e.match(
395+
"Only one of vector_distance_threshold or"
396+
" vector_similarity_threshold can be specified at a time"
397+
" in rag_retrieval_config."
398+
)

vertexai/preview/rag/rag_retrieval.py

Lines changed: 30 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -290,6 +290,7 @@ async def async_retrieve_contexts(
290290
vector_distance_threshold: Optional[float] = None,
291291
vector_search_alpha: Optional[float] = None,
292292
rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None,
293+
timeout: int = 600,
293294
) -> aiplatform_v1beta1.RetrieveContextsResponse:
294295
"""Retrieve top k relevant docs/chunks asynchronously.
295296
@@ -316,22 +317,23 @@ async def async_retrieve_contexts(
316317
Args:
317318
text: Required. The query in text format to get relevant contexts.
318319
rag_resources: Optional. A list of RagResource. It can be used to specify
319-
corpus only or ragfiles. Currently only support one corpus or multiple
320-
files from one corpus. In the future we may open up multiple corpora
321-
support.
320+
corpus only or ragfiles. Currently only support one corpus or multiple
321+
files from one corpus. In the future we may open up multiple corpora
322+
support.
322323
rag_corpora: Optional. Deprecated. Please use rag_resources instead. A
323-
list of RagCorpora resource names. Format:
324-
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
325-
Currently only support one corpus. In the future we may open up multiple
326-
corpora support.
324+
list of RagCorpora resource names. Format:
325+
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
326+
Currently only support one corpus. In the future we may open up multiple
327+
corpora support.
327328
similarity_top_k: Optional. Deprecated. Please use
328-
rag_retrieval_config.top_k instead.
329+
rag_retrieval_config.top_k instead.
329330
vector_distance_threshold: Optional. Deprecated. Please use
330-
rag_retrieval_config.filter.vector_distance_threshold instead.
331+
rag_retrieval_config.filter.vector_distance_threshold instead.
331332
vector_search_alpha: Optional. Deprecated. Please use
332-
rag_retrieval_config.hybrid_search.alpha instead.
333+
rag_retrieval_config.hybrid_search.alpha instead.
333334
rag_retrieval_config: Optional. The config containing the retrieval
334-
parameters, including top_k, vector_distance_threshold, and alpha.
335+
parameters, including top_k, vector_distance_threshold, and alpha.
336+
timeout: Optional. The timeout for the request in seconds. Default is 600.
335337
336338
Returns:
337339
RetrieveContextsResponse.
@@ -523,7 +525,9 @@ async def async_retrieve_contexts(
523525
tools=[tool],
524526
)
525527
try:
526-
response_lro = await client.async_retrieve_contexts(request=request)
528+
response_lro = await client.async_retrieve_contexts(
529+
request=request, timeout=timeout
530+
)
527531
response = await response_lro.result()
528532
except Exception as e:
529533
raise RuntimeError(
@@ -541,6 +545,7 @@ def ask_contexts(
541545
vector_distance_threshold: Optional[float] = None,
542546
vector_search_alpha: Optional[float] = None,
543547
rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None,
548+
timeout: int = 600,
544549
) -> aiplatform_v1beta1.AskContextsResponse:
545550
"""Ask questions on top k relevant docs/chunks.
546551
@@ -567,22 +572,23 @@ def ask_contexts(
567572
Args:
568573
text: Required. The query in text format to get relevant contexts.
569574
rag_resources: Optional. A list of RagResource. It can be used to specify
570-
corpus only or ragfiles. Currently only support one corpus or multiple
571-
files from one corpus. In the future we may open up multiple corpora
572-
support.
575+
corpus only or ragfiles. Currently only support one corpus or multiple
576+
files from one corpus. In the future we may open up multiple corpora
577+
support.
573578
rag_corpora: Optional. Deprecated. Please use rag_resources instead. A
574-
list of RagCorpora resource names. Format:
575-
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
576-
Currently only support one corpus. In the future we may open up multiple
577-
corpora support.
579+
list of RagCorpora resource names. Format:
580+
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
581+
Currently only support one corpus. In the future we may open up multiple
582+
corpora support.
578583
similarity_top_k: Optional. Deprecated. Please use
579-
rag_retrieval_config.top_k instead.
584+
rag_retrieval_config.top_k instead.
580585
vector_distance_threshold: Optional. Deprecated. Please use
581-
rag_retrieval_config.filter.vector_distance_threshold instead.
586+
rag_retrieval_config.filter.vector_distance_threshold instead.
582587
vector_search_alpha: Optional. Deprecated. Please use
583-
rag_retrieval_config.hybrid_search.alpha instead.
588+
rag_retrieval_config.hybrid_search.alpha instead.
584589
rag_retrieval_config: Optional. The config containing the retrieval
585-
parameters, including top_k, vector_distance_threshold, and alpha.
590+
parameters, including top_k, vector_distance_threshold, and alpha.
591+
timeout: Optional. The timeout for the request in seconds. Default is 600.
586592
587593
Returns:
588594
AskContextsResponse.
@@ -774,7 +780,7 @@ def ask_contexts(
774780
tools=[tool],
775781
)
776782
try:
777-
response = client.ask_contexts(request=request)
783+
response = client.ask_contexts(request=request, timeout=timeout)
778784
except Exception as e:
779785
raise RuntimeError("Failed in asking contexts due to: ", e) from e
780786

0 commit comments

Comments
 (0)