Skip to content

Commit 4f0fdfe

Browse files
vertex-sdk-bot authored and
copybara-github committed
feat: Add RagMetadata and RagDataSchema management APIs
PiperOrigin-RevId: 888262230
1 parent 9b7dc29 commit 4f0fdfe

File tree

8 files changed

+1095
-19
lines changed

8 files changed

+1095
-19
lines changed

tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py

Lines changed: 79 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -72,8 +72,12 @@
7272
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import (
7373
VertexRagDataServiceClient,
7474
)
75-
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import pagers
76-
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import transports
75+
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import (
76+
pagers,
77+
)
78+
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import (
79+
transports,
80+
)
7781
from google.cloud.aiplatform_v1beta1.types import api_auth
7882
from google.cloud.aiplatform_v1beta1.types import encryption_spec
7983
from google.cloud.aiplatform_v1beta1.types import io
@@ -5294,6 +5298,79 @@ async def test_delete_rag_file_flattened_error_async():
52945298
)
52955299

52965300

5301+
@pytest.mark.parametrize(
5302+
"request_type",
5303+
[
5304+
vertex_rag_data_service.BatchCreateRagDataSchemasRequest,
5305+
dict,
5306+
],
5307+
)
5308+
def test_batch_create_rag_data_schemas(request_type, transport: str = "grpc"):
5309+
client = VertexRagDataServiceClient(
5310+
credentials=ga_credentials.AnonymousCredentials(),
5311+
transport=transport,
5312+
)
5313+
5314+
# Everything is optional in proto3 as far as the runtime is concerned,
5315+
# and we are mocking out the actual API, so just send an empty request.
5316+
request = request_type()
5317+
5318+
# Mock the actual call within the gRPC stub, and fake the request.
5319+
with mock.patch.object(
5320+
type(client.transport.batch_create_rag_data_schemas), "__call__"
5321+
) as call:
5322+
# Designate an appropriate return value for the call.
5323+
call.return_value = operations_pb2.Operation(name="operations/spam")
5324+
response = client.batch_create_rag_data_schemas(request)
5325+
5326+
# Establish that the underlying gRPC stub method was called.
5327+
assert len(call.mock_calls) == 1
5328+
_, args, _ = call.mock_calls[0]
5329+
request = vertex_rag_data_service.BatchCreateRagDataSchemasRequest()
5330+
assert args[0] == request
5331+
5332+
# Establish that the response is the type that we expect.
5333+
assert isinstance(response, future.Future)
5334+
5335+
5336+
@pytest.mark.parametrize(
5337+
"request_type",
5338+
[
5339+
vertex_rag_data_service.ListRagDataSchemasRequest,
5340+
dict,
5341+
],
5342+
)
5343+
def test_list_rag_data_schemas(request_type, transport: str = "grpc"):
5344+
client = VertexRagDataServiceClient(
5345+
credentials=ga_credentials.AnonymousCredentials(),
5346+
transport=transport,
5347+
)
5348+
5349+
# Everything is optional in proto3 as far as the runtime is concerned,
5350+
# and we are mocking out the actual API, so just send an empty request.
5351+
request = request_type()
5352+
5353+
# Mock the actual call within the gRPC stub, and fake the request.
5354+
with mock.patch.object(
5355+
type(client.transport.list_rag_data_schemas), "__call__"
5356+
) as call:
5357+
# Designate an appropriate return value for the call.
5358+
call.return_value = vertex_rag_data_service.ListRagDataSchemasResponse(
5359+
next_page_token="next_page_token_value",
5360+
)
5361+
response = client.list_rag_data_schemas(request)
5362+
5363+
# Establish that the underlying gRPC stub method was called.
5364+
assert len(call.mock_calls) == 1
5365+
_, args, _ = call.mock_calls[0]
5366+
request = vertex_rag_data_service.ListRagDataSchemasRequest()
5367+
assert args[0] == request
5368+
5369+
# Establish that the response is the type that we expect.
5370+
assert isinstance(response, pagers.ListRagDataSchemasPager)
5371+
assert response.next_page_token == "next_page_token_value"
5372+
5373+
52975374
@pytest.mark.parametrize(
52985375
"request_type",
52995376
[

tests/unit/vertex_rag/conftest.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
#
15-
from unittest.mock import patch
1615
from unittest import mock
1716
from google import auth
1817
from google.api_core import operation as ga_operation
1918
from google.auth import credentials as auth_credentials
20-
from vertexai import rag
21-
from vertexai.preview import rag as rag_preview
2219
from google.cloud.aiplatform_v1 import (
2320
DeleteRagCorpusRequest,
2421
VertexRagDataServiceAsyncClient,
@@ -51,17 +48,19 @@ def google_auth_mock():
5148

5249
@pytest.fixture
5350
def authorized_session_mock():
54-
with patch(
55-
"google.auth.transport.requests.AuthorizedSession"
56-
) as MockAuthorizedSession:
51+
from google.auth.transport import requests
52+
53+
with mock.patch.object(requests, "AuthorizedSession") as MockAuthorizedSession:
5754
mock_auth_session = MockAuthorizedSession(_TEST_CREDENTIALS)
5855
yield mock_auth_session
5956

6057

6158
@pytest.fixture
6259
def rag_data_client_mock():
60+
from vertexai.rag.utils import _gapic_utils
61+
6362
with mock.patch.object(
64-
rag.utils._gapic_utils, "create_rag_data_service_client"
63+
_gapic_utils, "create_rag_data_service_client"
6564
) as rag_data_client_mock:
6665
api_client_mock = mock.Mock(spec=VertexRagDataServiceClient)
6766

@@ -84,8 +83,10 @@ def rag_data_client_mock():
8483

8584
@pytest.fixture
8685
def rag_data_client_preview_mock():
86+
from vertexai.preview.rag.utils import _gapic_utils
87+
8788
with mock.patch.object(
88-
rag_preview.utils._gapic_utils, "create_rag_data_service_client"
89+
_gapic_utils, "create_rag_data_service_client"
8990
) as rag_data_client_mock:
9091
api_client_mock = mock.Mock(spec=VertexRagDataServiceClientPreview)
9192

@@ -108,8 +109,10 @@ def rag_data_client_preview_mock():
108109

109110
@pytest.fixture
110111
def rag_data_client_mock_exception():
112+
from vertexai.rag.utils import _gapic_utils
113+
111114
with mock.patch.object(
112-
rag.utils._gapic_utils, "create_rag_data_service_client"
115+
_gapic_utils, "create_rag_data_service_client"
113116
) as rag_data_client_mock_exception:
114117
api_client_mock = mock.Mock(spec=VertexRagDataServiceClient)
115118
# create_rag_corpus
@@ -138,8 +141,10 @@ def rag_data_client_mock_exception():
138141

139142
@pytest.fixture
140143
def rag_data_client_preview_mock_exception():
144+
from vertexai.preview.rag.utils import _gapic_utils
145+
141146
with mock.patch.object(
142-
rag_preview.utils._gapic_utils, "create_rag_data_service_client"
147+
_gapic_utils, "create_rag_data_service_client"
143148
) as rag_data_client_mock_exception:
144149
api_client_mock = mock.Mock(spec=VertexRagDataServiceClientPreview)
145150
# create_rag_corpus
@@ -172,8 +177,10 @@ def rag_data_client_preview_mock_exception():
172177

173178
@pytest.fixture
174179
def rag_data_async_client_mock_exception():
180+
from vertexai.rag.utils import _gapic_utils
181+
175182
with mock.patch.object(
176-
rag.utils._gapic_utils, "create_rag_data_service_async_client"
183+
_gapic_utils, "create_rag_data_service_async_client"
177184
) as rag_data_async_client_mock_exception:
178185
api_client_mock = mock.Mock(spec=VertexRagDataServiceAsyncClient)
179186
# import_rag_files
@@ -184,8 +191,10 @@ def rag_data_async_client_mock_exception():
184191

185192
@pytest.fixture
186193
def rag_data_async_client_preview_mock_exception():
194+
from vertexai.preview.rag.utils import _gapic_utils
195+
187196
with mock.patch.object(
188-
rag_preview.utils._gapic_utils, "create_rag_data_service_async_client"
197+
_gapic_utils, "create_rag_data_service_async_client"
189198
) as rag_data_async_client_mock_exception:
190199
api_client_mock = mock.Mock(spec=VertexRagDataServiceAsyncClientPreview)
191200
# import_rag_files

tests/unit/vertex_rag/test_rag_constants_preview.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,23 @@
2323
ImportRagFilesRequest,
2424
ImportRagFilesResponse,
2525
JiraSource as GapicJiraSource,
26+
MetadataValue as GapicMetadataValue,
2627
RagContexts,
2728
RagCorpus as GapicRagCorpus,
29+
RagDataSchema as GapicRagDataSchema,
2830
RagEngineConfig as GapicRagEngineConfig,
2931
RagFileChunkingConfig,
3032
RagFileParsingConfig,
3133
RagFileTransformationConfig,
3234
RagFile as GapicRagFile,
3335
RagManagedDbConfig as GapicRagManagedDbConfig,
36+
RagMetadataSchemaDetails as GapicRagMetadataSchemaDetails,
37+
RagMetadata as GapicRagMetadata,
3438
RagVectorDbConfig as GapicRagVectorDbConfig,
3539
RetrieveContextsResponse,
3640
SharePointSources as GapicSharePointSources,
3741
SlackSource as GapicSlackSource,
42+
UserSpecifiedMetadata as GapicUserSpecifiedMetadata,
3843
VertexAiSearchConfig as GapicVertexAiSearchConfig,
3944
)
4045
from google.cloud.aiplatform_v1beta1.types import api_auth
@@ -54,15 +59,19 @@
5459
LlmParserConfig,
5560
LlmRanker,
5661
MemoryCorpus,
62+
MetadataValue,
5763
Pinecone,
5864
RagCorpus,
5965
RagCorpusTypeConfig,
66+
RagDataSchema,
6067
RagEmbeddingModelConfig,
6168
RagEngineConfig,
6269
RagFile,
6370
RagManagedDb,
6471
RagManagedDbConfig,
6572
RagManagedVertexVectorSearch,
73+
RagMetadata,
74+
RagMetadataSchemaDetails,
6675
RagResource,
6776
RagRetrievalConfig,
6877
RagVectorDbConfig,
@@ -76,6 +85,7 @@
7685
SlackChannelsSource,
7786
Spanner,
7887
Unprovisioned,
88+
UserSpecifiedMetadata,
7989
VertexAiSearchConfig,
8090
VertexFeatureStore,
8191
VertexPredictionEndpoint,
@@ -1146,3 +1156,54 @@
11461156
filter=Filter(vector_distance_threshold=0.5),
11471157
ranking=Ranking(llm_ranker=LlmRanker(model_name="test-model-name")),
11481158
)
1159+
1160+
# RagMetadata and RagDataSchema
1161+
TEST_RAG_DATA_SCHEMA_ID = "test-data-schema-id"
1162+
TEST_RAG_DATA_SCHEMA_RESOURCE_NAME = (
1163+
f"{TEST_RAG_CORPUS_RESOURCE_NAME}/ragDataSchemas/{TEST_RAG_DATA_SCHEMA_ID}"
1164+
)
1165+
TEST_RAG_METADATA_ID = "test-metadata-id"
1166+
TEST_RAG_METADATA_RESOURCE_NAME = (
1167+
f"{TEST_RAG_FILE_RESOURCE_NAME}/ragMetadata/{TEST_RAG_METADATA_ID}"
1168+
)
1169+
1170+
TEST_GAPIC_RAG_DATA_SCHEMA = GapicRagDataSchema(
1171+
name=TEST_RAG_DATA_SCHEMA_RESOURCE_NAME,
1172+
key="key1",
1173+
schema_details=GapicRagMetadataSchemaDetails(
1174+
type=GapicRagMetadataSchemaDetails.DataType.STRING,
1175+
search_strategy=GapicRagMetadataSchemaDetails.SearchStrategy(
1176+
search_strategy_type=GapicRagMetadataSchemaDetails.SearchStrategy.SearchStrategyType.EXACT_SEARCH
1177+
),
1178+
granularity=GapicRagMetadataSchemaDetails.Granularity.GRANULARITY_FILE_LEVEL,
1179+
),
1180+
)
1181+
1182+
TEST_RAG_DATA_SCHEMA = RagDataSchema(
1183+
name=TEST_RAG_DATA_SCHEMA_RESOURCE_NAME,
1184+
key="key1",
1185+
schema_details=RagMetadataSchemaDetails(
1186+
type="STRING",
1187+
search_strategy=RagMetadataSchemaDetails.SearchStrategy(
1188+
search_strategy_type="EXACT_SEARCH"
1189+
),
1190+
granularity="GRANULARITY_FILE_LEVEL",
1191+
),
1192+
)
1193+
1194+
TEST_GAPIC_RAG_METADATA = GapicRagMetadata(
1195+
name=TEST_RAG_METADATA_RESOURCE_NAME,
1196+
user_specified_metadata=GapicUserSpecifiedMetadata(
1197+
key="key1",
1198+
value=GapicMetadataValue(str_value="value1"),
1199+
),
1200+
)
1201+
1202+
TEST_RAG_METADATA = RagMetadata(
1203+
name=TEST_RAG_METADATA_RESOURCE_NAME,
1204+
user_specified_metadata=UserSpecifiedMetadata(
1205+
values={
1206+
"key1": MetadataValue(string_value="value1"),
1207+
}
1208+
),
1209+
)

0 commit comments

Comments
 (0)