@@ -87,7 +87,7 @@ def retrieve_contexts_eq(response, expected_response):
8787
8888
8989@pytest .mark .usefixtures ("google_auth_mock" )
90- class TestRagRetrieval : # pylint: disable=missing-class-docstring
90+ class TestRagRetrieval : # pylint: disable=missing-class-docstring, bad-indentation, unused-variable, unused-argument, redefined-outer-name
9191
9292 def setup_method (self ):
9393 importlib .reload (aiplatform .initializer )
@@ -118,6 +118,18 @@ def test_ask_contexts_rag_resources_success(self):
118118 )
119119 retrieve_contexts_eq (response , tc .TEST_RETRIEVAL_RESPONSE )
120120
121+ @pytest .mark .usefixtures ("ask_contexts_mock" )
122+ def test_ask_contexts_with_timeout (self , ask_contexts_mock ):
123+ rag .ask_contexts (
124+ rag_resources = [tc .TEST_RAG_RESOURCE ],
125+ text = tc .TEST_QUERY_TEXT ,
126+ rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_CONFIG_ALPHA ,
127+ timeout = 300 ,
128+ )
129+ ask_contexts_mock .assert_called_once ()
130+ args , kwargs = ask_contexts_mock .call_args
131+ assert kwargs ["timeout" ] == 300
132+
121133 @pytest .mark .usefixtures ("ask_contexts_mock" )
122134 def test_ask_contexts_multiple_rag_resources_success (self ):
123135 response = rag .ask_contexts (
@@ -138,8 +150,9 @@ def test_ask_contexts_multiple_rag_corpora_success(self):
138150 retrieve_contexts_eq (response , tc .TEST_RETRIEVAL_RESPONSE )
139151
140152 @pytest .mark .asyncio
141- @pytest .mark .usefixtures ("async_retrieve_contexts_mock" )
142- async def test_async_retrieve_contexts_rag_resources_success (self ):
153+ async def test_async_retrieve_contexts_rag_resources_success (
154+ self , async_retrieve_contexts_mock
155+ ):
143156 response = await rag .async_retrieve_contexts (
144157 rag_resources = [tc .TEST_RAG_RESOURCE ],
145158 text = tc .TEST_QUERY_TEXT ,
@@ -148,8 +161,23 @@ async def test_async_retrieve_contexts_rag_resources_success(self):
148161 retrieve_contexts_eq (response , tc .TEST_RETRIEVAL_RESPONSE )
149162
150163 @pytest .mark .asyncio
151- @pytest .mark .usefixtures ("async_retrieve_contexts_mock" )
152- async def test_async_retrieve_contexts_multiple_rag_resources_success (self ):
164+ async def test_async_retrieve_contexts_with_timeout (
165+ self , async_retrieve_contexts_mock
166+ ):
167+ await rag .async_retrieve_contexts (
168+ rag_resources = [tc .TEST_RAG_RESOURCE ],
169+ text = tc .TEST_QUERY_TEXT ,
170+ rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_CONFIG_ALPHA ,
171+ timeout = 300 ,
172+ )
173+ async_retrieve_contexts_mock .assert_called_once ()
174+ args , kwargs = async_retrieve_contexts_mock .call_args
175+ assert kwargs ["timeout" ] == 300
176+
177+ @pytest .mark .asyncio
178+ async def test_async_retrieve_contexts_multiple_rag_resources_success (
179+ self , async_retrieve_contexts_mock
180+ ):
153181 response = await rag .async_retrieve_contexts (
154182 rag_resources = [tc .TEST_RAG_RESOURCE , tc .TEST_RAG_RESOURCE ],
155183 text = tc .TEST_QUERY_TEXT ,
@@ -158,8 +186,9 @@ async def test_async_retrieve_contexts_multiple_rag_resources_success(self):
158186 retrieve_contexts_eq (response , tc .TEST_RETRIEVAL_RESPONSE )
159187
160188 @pytest .mark .asyncio
161- @pytest .mark .usefixtures ("async_retrieve_contexts_mock" )
162- async def test_async_retrieve_contexts_multiple_rag_corpora_success (self ):
189+ async def test_async_retrieve_contexts_multiple_rag_corpora_success (
190+ self , async_retrieve_contexts_mock
191+ ):
163192 with pytest .warns (DeprecationWarning ):
164193 response = await rag .async_retrieve_contexts (
165194 rag_corpora = [tc .TEST_RAG_CORPUS_ID , tc .TEST_RAG_CORPUS_ID ],
@@ -262,7 +291,7 @@ def test_retrieval_query_failure(self):
262291 similarity_top_k = 2 ,
263292 vector_distance_threshold = 0.5 ,
264293 )
265- e .match ("Failed in retrieving contexts due to" )
294+ e .match ("Failed in retrieving contexts due to" )
266295
267296 @pytest .mark .usefixtures ("rag_client_mock_exception" )
268297 def test_retrieval_query_config_failure (self ):
@@ -272,7 +301,7 @@ def test_retrieval_query_config_failure(self):
272301 text = tc .TEST_QUERY_TEXT ,
273302 rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_CONFIG ,
274303 )
275- e .match ("Failed in retrieving contexts due to" )
304+ e .match ("Failed in retrieving contexts due to" )
276305
277306 def test_retrieval_query_invalid_name (self ):
278307 with pytest .raises (ValueError ) as e :
@@ -282,7 +311,7 @@ def test_retrieval_query_invalid_name(self):
282311 similarity_top_k = 2 ,
283312 vector_distance_threshold = 0.5 ,
284313 )
285- e .match ("Invalid RagCorpus name" )
314+ e .match ("Invalid RagCorpus name" )
286315
287316 def test_retrieval_query_invalid_name_config (self ):
288317 with pytest .raises (ValueError ) as e :
@@ -291,7 +320,7 @@ def test_retrieval_query_invalid_name_config(self):
291320 text = tc .TEST_QUERY_TEXT ,
292321 rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_CONFIG ,
293322 )
294- e .match ("Invalid RagCorpus name" )
323+ e .match ("Invalid RagCorpus name" )
295324
296325 def test_retrieval_query_multiple_rag_corpora (self ):
297326 with pytest .raises (ValueError ) as e :
@@ -304,7 +333,7 @@ def test_retrieval_query_multiple_rag_corpora(self):
304333 similarity_top_k = 2 ,
305334 vector_distance_threshold = 0.5 ,
306335 )
307- e .match ("Currently only support 1 RagCorpus" )
336+ e .match ("Currently only support 1 RagCorpus" )
308337
309338 def test_retrieval_query_multiple_rag_corpora_config (self ):
310339 with pytest .raises (ValueError ) as e :
@@ -316,7 +345,7 @@ def test_retrieval_query_multiple_rag_corpora_config(self):
316345 text = tc .TEST_QUERY_TEXT ,
317346 rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_CONFIG ,
318347 )
319- e .match ("Currently only support 1 RagCorpus" )
348+ e .match ("Currently only support 1 RagCorpus" )
320349
321350 def test_retrieval_query_multiple_rag_resources (self ):
322351 with pytest .raises (ValueError ) as e :
@@ -329,7 +358,7 @@ def test_retrieval_query_multiple_rag_resources(self):
329358 similarity_top_k = 2 ,
330359 vector_distance_threshold = 0.5 ,
331360 )
332- e .match ("Currently only support 1 RagResource" )
361+ e .match ("Currently only support 1 RagResource" )
333362
334363 def test_retrieval_query_multiple_rag_resources_config (self ):
335364 with pytest .raises (ValueError ) as e :
@@ -341,7 +370,7 @@ def test_retrieval_query_multiple_rag_resources_config(self):
341370 text = tc .TEST_QUERY_TEXT ,
342371 rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_CONFIG ,
343372 )
344- e .match ("Currently only support 1 RagResource" )
373+ e .match ("Currently only support 1 RagResource" )
345374
346375 def test_retrieval_query_multiple_rag_resources_similarity_config (self ):
347376 with pytest .raises (ValueError ) as e :
@@ -353,7 +382,7 @@ def test_retrieval_query_multiple_rag_resources_similarity_config(self):
353382 text = tc .TEST_QUERY_TEXT ,
354383 rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG ,
355384 )
356- e .match ("Currently only support 1 RagResource" )
385+ e .match ("Currently only support 1 RagResource" )
357386
358387 def test_retrieval_query_invalid_config_filter (self ):
359388 with pytest .raises (ValueError ) as e :
@@ -362,8 +391,8 @@ def test_retrieval_query_invalid_config_filter(self):
362391 text = tc .TEST_QUERY_TEXT ,
363392 rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_ERROR_CONFIG ,
364393 )
365- e .match (
366- "Only one of vector_distance_threshold or"
367- " vector_similarity_threshold can be specified at a time"
368- " in rag_retrieval_config."
369- )
394+ e .match (
395+ "Only one of vector_distance_threshold or"
396+ " vector_similarity_threshold can be specified at a time"
397+ " in rag_retrieval_config."
398+ )
0 commit comments