Skip to content

Commit 16600e1

Browse files
sararobcopybara-github
authored andcommitted
feat: GenAI client - Add update_corpus and update_config methods to the RAG module
PiperOrigin-RevId: 929902649
1 parent bf32f5e commit 16600e1

5 files changed

Lines changed: 461 additions & 0 deletions

File tree

agentplatform/_genai/rag.py

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,19 @@ def _GetCorpusOperationParameters_to_vertex(
135135
return to_object
136136

137137

138+
def _GetRagConfigOperationParameters_to_vertex(
139+
from_object: Union[dict[str, Any], object],
140+
parent_object: Optional[dict[str, Any]] = None,
141+
) -> dict[str, Any]:
142+
to_object: dict[str, Any] = {}
143+
if getv(from_object, ["operation_name"]) is not None:
144+
setv(
145+
to_object, ["_url", "operation_name"], getv(from_object, ["operation_name"])
146+
)
147+
148+
return to_object
149+
150+
138151
def _GetRagConfigRequestParameters_to_vertex(
139152
from_object: Union[dict[str, Any], object],
140153
parent_object: Optional[dict[str, Any]] = None,
@@ -2300,6 +2313,74 @@ def retrieve_contexts(
23002313
self._api_client._verify_response(return_value)
23012314
return return_value
23022315

2316+
def _get_rag_config_operation(
2317+
self,
2318+
*,
2319+
operation_name: str,
2320+
config: Optional[types.GetRagConfigOperationConfigOrDict] = None,
2321+
) -> types.RagEngineConfigOperation:
2322+
parameter_model = types._GetRagConfigOperationParameters(
2323+
operation_name=operation_name,
2324+
config=config,
2325+
)
2326+
2327+
request_url_dict: Optional[dict[str, str]]
2328+
if not self._api_client.vertexai:
2329+
raise ValueError(
2330+
"This method is only supported in Gemini Enterprise Agent Platform mode, not in Gemini Developer API mode."
2331+
)
2332+
else:
2333+
request_dict = _GetRagConfigOperationParameters_to_vertex(parameter_model)
2334+
request_url_dict = request_dict.get("_url")
2335+
if request_url_dict:
2336+
path = "{operation_name}".format_map(request_url_dict)
2337+
else:
2338+
path = "{operation_name}"
2339+
2340+
query_params = request_dict.get("_query")
2341+
if query_params:
2342+
path = f"{path}?{urlencode(query_params)}"
2343+
# TODO: remove the hack that pops config.
2344+
request_dict.pop("config", None)
2345+
2346+
http_options: Optional[types.HttpOptions] = None
2347+
if (
2348+
parameter_model.config is not None
2349+
and parameter_model.config.http_options is not None
2350+
):
2351+
http_options = parameter_model.config.http_options
2352+
2353+
request_dict = _common.convert_to_dict(request_dict)
2354+
request_dict = _common.encode_unserializable_types(request_dict)
2355+
2356+
response = self._api_client.request("get", path, request_dict, http_options)
2357+
2358+
response_dict = {} if not response.body else json.loads(response.body)
2359+
2360+
return_value = types.RagEngineConfigOperation._from_response(
2361+
response=response_dict,
2362+
kwargs=(
2363+
{
2364+
"config": {
2365+
"response_schema": getattr(
2366+
parameter_model.config, "response_schema", None
2367+
),
2368+
"response_json_schema": getattr(
2369+
parameter_model.config, "response_json_schema", None
2370+
),
2371+
"include_all_fields": getattr(
2372+
parameter_model.config, "include_all_fields", None
2373+
),
2374+
}
2375+
}
2376+
if getattr(parameter_model, "config", None)
2377+
else {}
2378+
),
2379+
)
2380+
2381+
self._api_client._verify_response(return_value)
2382+
return return_value
2383+
23032384
def create_corpus(
23042385
self,
23052386
*,
@@ -2372,6 +2453,67 @@ def delete_file(
23722453

23732454
return None
23742455

2456+
def update_corpus(
2457+
self,
2458+
*,
2459+
name: str,
2460+
rag_corpus: types.RagCorpusOrDict,
2461+
config: Optional[types.UpdateRagCorpusConfigOrDict] = None,
2462+
) -> types.RagCorpus:
2463+
"""
2464+
Updates a Rag Corpus and waits for completion.
2465+
2466+
Args:
2467+
name: The name of the RagCorpus to update, formatted as
2468+
`projects/{project}/locations/{location}/ragCorpora/{corpus_id}`.
2469+
rag_corpus: The RagCorpus to update.
2470+
config: The configuration to use for the RagCorpus update request.
2471+
2472+
Returns:
2473+
The updated RagCorpus.
2474+
"""
2475+
operation = self._update_corpus(name=name, rag_corpus=rag_corpus, config=config)
2476+
2477+
operation = _operations_utils.await_operation(
2478+
operation_name=operation.name,
2479+
get_operation_fn=self._get_corpus_operation,
2480+
)
2481+
2482+
if operation.error:
2483+
raise RuntimeError(f"Failed to update RagCorpus: {operation.error}")
2484+
2485+
return self.get_corpus(name=operation.response.name)
2486+
2487+
def update_config(
2488+
self,
2489+
*,
2490+
updated_config: types.RagEngineConfigOrDict,
2491+
request_config: Optional[types.UpdateRagConfigOrDict] = None,
2492+
) -> types.RagEngineConfig:
2493+
"""
2494+
Updates a RagEngineConfig and waits for completion.
2495+
2496+
Args:
2497+
updated_config: The RagEngineConfig to update.
2498+
request_config: The configuration to use for the RagEngineConfig update request.
2499+
2500+
Returns:
2501+
The updated RagEngineConfig.
2502+
"""
2503+
operation = self._update_config(
2504+
updated_config=updated_config, config=request_config
2505+
)
2506+
2507+
operation = _operations_utils.await_operation(
2508+
operation_name=operation.name,
2509+
get_operation_fn=self._get_rag_config_operation,
2510+
)
2511+
2512+
if operation.error:
2513+
raise RuntimeError(f"Failed to update RagEngineConfig: {operation.error}")
2514+
2515+
return self.get_config()
2516+
23752517

23762518
class AsyncRag(_api_module.BaseModule):
23772519

@@ -3333,6 +3475,76 @@ async def retrieve_contexts(
33333475
self._api_client._verify_response(return_value)
33343476
return return_value
33353477

3478+
async def _get_rag_config_operation(
3479+
self,
3480+
*,
3481+
operation_name: str,
3482+
config: Optional[types.GetRagConfigOperationConfigOrDict] = None,
3483+
) -> types.RagEngineConfigOperation:
3484+
parameter_model = types._GetRagConfigOperationParameters(
3485+
operation_name=operation_name,
3486+
config=config,
3487+
)
3488+
3489+
request_url_dict: Optional[dict[str, str]]
3490+
if not self._api_client.vertexai:
3491+
raise ValueError(
3492+
"This method is only supported in Gemini Enterprise Agent Platform mode, not in Gemini Developer API mode."
3493+
)
3494+
else:
3495+
request_dict = _GetRagConfigOperationParameters_to_vertex(parameter_model)
3496+
request_url_dict = request_dict.get("_url")
3497+
if request_url_dict:
3498+
path = "{operation_name}".format_map(request_url_dict)
3499+
else:
3500+
path = "{operation_name}"
3501+
3502+
query_params = request_dict.get("_query")
3503+
if query_params:
3504+
path = f"{path}?{urlencode(query_params)}"
3505+
# TODO: remove the hack that pops config.
3506+
request_dict.pop("config", None)
3507+
3508+
http_options: Optional[types.HttpOptions] = None
3509+
if (
3510+
parameter_model.config is not None
3511+
and parameter_model.config.http_options is not None
3512+
):
3513+
http_options = parameter_model.config.http_options
3514+
3515+
request_dict = _common.convert_to_dict(request_dict)
3516+
request_dict = _common.encode_unserializable_types(request_dict)
3517+
3518+
response = await self._api_client.async_request(
3519+
"get", path, request_dict, http_options
3520+
)
3521+
3522+
response_dict = {} if not response.body else json.loads(response.body)
3523+
3524+
return_value = types.RagEngineConfigOperation._from_response(
3525+
response=response_dict,
3526+
kwargs=(
3527+
{
3528+
"config": {
3529+
"response_schema": getattr(
3530+
parameter_model.config, "response_schema", None
3531+
),
3532+
"response_json_schema": getattr(
3533+
parameter_model.config, "response_json_schema", None
3534+
),
3535+
"include_all_fields": getattr(
3536+
parameter_model.config, "include_all_fields", None
3537+
),
3538+
}
3539+
}
3540+
if getattr(parameter_model, "config", None)
3541+
else {}
3542+
),
3543+
)
3544+
3545+
self._api_client._verify_response(return_value)
3546+
return return_value
3547+
33363548
async def create_corpus(
33373549
self,
33383550
*,
@@ -3404,3 +3616,66 @@ async def delete_file(
34043616
)
34053617

34063618
return None
3619+
3620+
async def update_corpus(
3621+
self,
3622+
*,
3623+
name: str,
3624+
rag_corpus: types.RagCorpusOrDict,
3625+
config: Optional[types.UpdateRagCorpusConfigOrDict] = None,
3626+
) -> types.RagCorpus:
3627+
"""
3628+
Updates a Rag Corpus and waits for completion asynchronously.
3629+
3630+
Args:
3631+
name: The name of the RagCorpus to update, formatted as
3632+
`projects/{project}/locations/{location}/ragCorpora/{corpus_id}`.
3633+
rag_corpus: The RagCorpus to update.
3634+
config: The configuration to use for the RagCorpus update request.
3635+
3636+
Returns:
3637+
The updated RagCorpus.
3638+
"""
3639+
operation = await self._update_corpus(
3640+
name=name, rag_corpus=rag_corpus, config=config
3641+
)
3642+
3643+
operation = await _operations_utils.await_operation_async(
3644+
operation_name=operation.name,
3645+
get_operation_fn=self._get_corpus_operation,
3646+
)
3647+
3648+
if operation.error:
3649+
raise RuntimeError(f"Failed to update RagCorpus: {operation.error}")
3650+
3651+
return await self.get_corpus(name=operation.response.name)
3652+
3653+
async def update_config(
3654+
self,
3655+
*,
3656+
updated_config: types.RagEngineConfigOrDict,
3657+
request_config: Optional[types.UpdateRagConfigOrDict] = None,
3658+
) -> types.RagEngineConfig:
3659+
"""
3660+
Updates a RagEngineConfig and waits for completion asynchronously.
3661+
3662+
Args:
3663+
updated_config: The RagEngineConfig to update.
3664+
request_config: The configuration to use for the RagEngineConfig update request.
3665+
3666+
Returns:
3667+
The updated RagEngineConfig.
3668+
"""
3669+
operation = await self._update_config(
3670+
updated_config=updated_config, config=request_config
3671+
)
3672+
3673+
operation = await _operations_utils.await_operation_async(
3674+
operation_name=operation.name,
3675+
get_operation_fn=self._get_rag_config_operation,
3676+
)
3677+
3678+
if operation.error:
3679+
raise RuntimeError(f"Failed to update RagEngineConfig: {operation.error}")
3680+
3681+
return await self.get_config()

agentplatform/_genai/types/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@
9494
from .common import _GetEvaluationSetParameters
9595
from .common import _GetMultimodalDatasetOperationParameters
9696
from .common import _GetMultimodalDatasetParameters
97+
from .common import _GetRagConfigOperationParameters
9798
from .common import _GetRagConfigRequestParameters
9899
from .common import _GetRagCorpusRequestParameters
99100
from .common import _GetRagFileRequestParameters
@@ -692,6 +693,9 @@
692693
from .common import GetPromptConfigOrDict
693694
from .common import GetRagConfig
694695
from .common import GetRagConfigDict
696+
from .common import GetRagConfigOperationConfig
697+
from .common import GetRagConfigOperationConfigDict
698+
from .common import GetRagConfigOperationConfigOrDict
695699
from .common import GetRagConfigOrDict
696700
from .common import GetRagCorpusConfig
697701
from .common import GetRagCorpusConfigDict
@@ -1108,6 +1112,9 @@
11081112
from .common import RagEmbeddingModelConfigVertexPredictionEndpointOrDict
11091113
from .common import RagEngineConfig
11101114
from .common import RagEngineConfigDict
1115+
from .common import RagEngineConfigOperation
1116+
from .common import RagEngineConfigOperationDict
1117+
from .common import RagEngineConfigOperationOrDict
11111118
from .common import RagEngineConfigOrDict
11121119
from .common import RagFile
11131120
from .common import RagFileDict
@@ -2731,6 +2738,12 @@
27312738
"RetrieveContextsResponse",
27322739
"RetrieveContextsResponseDict",
27332740
"RetrieveContextsResponseOrDict",
2741+
"GetRagConfigOperationConfig",
2742+
"GetRagConfigOperationConfigDict",
2743+
"GetRagConfigOperationConfigOrDict",
2744+
"RagEngineConfigOperation",
2745+
"RagEngineConfigOperationDict",
2746+
"RagEngineConfigOperationOrDict",
27342747
"GetAgentEngineRuntimeRevisionConfig",
27352748
"GetAgentEngineRuntimeRevisionConfigDict",
27362749
"GetAgentEngineRuntimeRevisionConfigOrDict",
@@ -3334,6 +3347,7 @@
33343347
"_DeleteRagFileRequestParameters",
33353348
"_UpdateRagConfigRequestParameters",
33363349
"_RetrieveRagContextsRequestParameters",
3350+
"_GetRagConfigOperationParameters",
33373351
"_GetAgentEngineRuntimeRevisionRequestParameters",
33383352
"_ListAgentEngineRuntimeRevisionsRequestParameters",
33393353
"_DeleteAgentEngineRuntimeRevisionRequestParameters",

0 commit comments

Comments
 (0)