diff --git a/tests/unit/vertex_langchain/test_reasoning_engines.py b/tests/unit/vertex_langchain/test_reasoning_engines.py index 019dc214d9..c0bf35579f 100644 --- a/tests/unit/vertex_langchain/test_reasoning_engines.py +++ b/tests/unit/vertex_langchain/test_reasoning_engines.py @@ -42,6 +42,7 @@ from vertexai.preview import reasoning_engines from vertexai.reasoning_engines import _reasoning_engines from vertexai.reasoning_engines import _utils +from google.iam.v1 import policy_pb2 from google.api import httpbody_pb2 from google.protobuf import field_mask_pb2 from google.protobuf import struct_pb2 @@ -794,6 +795,46 @@ def test_create_reasoning_engine( retry=_TEST_RETRY, ) + def test_get_iam_policy(self): + with mock.patch.object( + base.VertexAiResourceNoun, "_get_gca_resource" + ) as mock_get_gca_resource: + mock_get_gca_resource.return_value = types.ReasoningEngine( + name=_TEST_REASONING_ENGINE_RESOURCE_NAME + ) + reasoning_engine = reasoning_engines.ReasoningEngine( + _TEST_REASONING_ENGINE_RESOURCE_NAME + ) + + test_policy = policy_pb2.Policy(version=1) + with mock.patch.object( + reasoning_engine.api_client, "get_iam_policy" + ) as mock_get_iam_policy: + mock_get_iam_policy.return_value = test_policy + policy = reasoning_engine.get_iam_policy(policy_version=1) + mock_get_iam_policy.assert_called_once() + assert policy == test_policy + + def test_set_iam_policy(self): + with mock.patch.object( + base.VertexAiResourceNoun, "_get_gca_resource" + ) as mock_get_gca_resource: + mock_get_gca_resource.return_value = types.ReasoningEngine( + name=_TEST_REASONING_ENGINE_RESOURCE_NAME + ) + reasoning_engine = reasoning_engines.ReasoningEngine( + _TEST_REASONING_ENGINE_RESOURCE_NAME + ) + + test_policy = policy_pb2.Policy(version=1) + with mock.patch.object( + reasoning_engine.api_client, "set_iam_policy" + ) as mock_set_iam_policy: + mock_set_iam_policy.return_value = test_policy + policy = reasoning_engine.set_iam_policy(test_policy) + mock_set_iam_policy.assert_called_once() + assert policy == test_policy + @pytest.mark.usefixtures("caplog") def test_create_reasoning_engine_warn_resource_name( self, diff --git a/vertexai/reasoning_engines/_reasoning_engines.py b/vertexai/reasoning_engines/_reasoning_engines.py index 322bf2a2d4..e84559edfa 100644 --- a/vertexai/reasoning_engines/_reasoning_engines.py +++ b/vertexai/reasoning_engines/_reasoning_engines.py @@ -44,6 +44,9 @@ from google.cloud.aiplatform_v1beta1 import types as aip_types from google.cloud.aiplatform_v1beta1.types import reasoning_engine_service from vertexai.reasoning_engines import _utils +from google.iam.v1 import iam_policy_pb2 +from google.iam.v1 import options_pb2 +from google.iam.v1 import policy_pb2 from google.protobuf import field_mask_pb2 @@ -114,18 +117,18 @@ def register_operations(self, **kwargs): class ReasoningEngine(base.VertexAiResourceNounWithFutureManager): - """Represents a Vertex AI Reasoning Engine resource.""" + """Represents a Vertex AI Reasoning Engine resource.""" - client_class = aip_utils.ReasoningEngineClientWithOverride - _resource_noun = "reasoning_engine" - _getter_method = "get_reasoning_engine" - _list_method = "list_reasoning_engines" - _delete_method = "delete_reasoning_engine" - _parse_resource_name_method = "parse_reasoning_engine_path" - _format_resource_name_method = "reasoning_engine_path" + client_class = aip_utils.ReasoningEngineClientWithOverride + _resource_noun = "reasoning_engine" + _getter_method = "get_reasoning_engine" + _list_method = "list_reasoning_engines" + _delete_method = "delete_reasoning_engine" + _parse_resource_name_method = "parse_reasoning_engine_path" + _format_resource_name_method = "reasoning_engine_path" - def __init__(self, reasoning_engine_name: str): - """Retrieves a Reasoning Engine resource. + def __init__(self, reasoning_engine_name: str): + """Retrieves a Reasoning Engine resource. Args: reasoning_engine_name (str): @@ -133,24 +136,24 @@ def __init__(self, reasoning_engine_name: str): "projects/123/locations/us-central1/reasoningEngines/456" or "456" when project and location are initialized or passed. """ - super().__init__(resource_name=reasoning_engine_name) - self.execution_api_client = initializer.global_config.create_client( + super().__init__(resource_name=reasoning_engine_name) + self.execution_api_client = initializer.global_config.create_client( client_class=aip_utils.ReasoningEngineExecutionClientWithOverride, ) - self._gca_resource = self._get_gca_resource(resource_name=reasoning_engine_name) - try: - _register_api_methods_or_raise(self) - except Exception as e: - logging.warning("Failed to register API methods: {%s}", e) - self._operation_schemas = None - - @property - def resource_name(self) -> str: - """Fully-qualified resource name.""" - return self._gca_resource.name - - @classmethod - def create( + self._gca_resource = self._get_gca_resource(resource_name=reasoning_engine_name) + try: + _register_api_methods_or_raise(self) + except Exception as e: + logging.warning("Failed to register API methods: {%s}", e) + self._operation_schemas = None + + @property + def resource_name(self) -> str: + """Fully-qualified resource name.""" + return self._gca_resource.name + + @classmethod + def create( cls, reasoning_engine: Union[Queryable, OperationRegistrable], *, @@ -162,7 +165,7 @@ def create( sys_version: Optional[str] = None, extra_packages: Optional[Sequence[str]] = None, ) -> "ReasoningEngine": - """Creates a new ReasoningEngine. + """Creates a new ReasoningEngine. The Reasoning Engine will be an instance of the `reasoning_engine` that was passed in, running remotely on Vertex AI. @@ -244,31 +247,31 @@ def create( IOError: If requirements is a string that corresponds to a nonexistent file. """ - if not sys_version: - sys_version = f"{sys.version_info.major}.{sys.version_info.minor}" - _validate_sys_version_or_raise(sys_version) - reasoning_engine = _validate_reasoning_engine_or_raise(reasoning_engine) - requirements = _validate_requirements_or_raise(requirements) - extra_packages = _validate_extra_packages_or_raise(extra_packages) - - if reasoning_engine_name: - _LOGGER.warning( + if not sys_version: + sys_version = f"{sys.version_info.major}.{sys.version_info.minor}" + _validate_sys_version_or_raise(sys_version) + reasoning_engine = _validate_reasoning_engine_or_raise(reasoning_engine) + requirements = _validate_requirements_or_raise(requirements) + extra_packages = _validate_extra_packages_or_raise(extra_packages) + + if reasoning_engine_name: + _LOGGER.warning( "ReasoningEngine does not support user-defined resource IDs at " f"the moment. Therefore {reasoning_engine_name=} would be " "ignored and a random ID will be generated instead." ) - sdk_resource = cls.__new__(cls) - base.VertexAiResourceNounWithFutureManager.__init__( + sdk_resource = cls.__new__(cls) + base.VertexAiResourceNounWithFutureManager.__init__( sdk_resource, resource_name=reasoning_engine_name, ) - staging_bucket = initializer.global_config.staging_bucket - _validate_staging_bucket_or_raise(staging_bucket) - # Prepares the Reasoning Engine for creation in Vertex AI. - # This involves packaging and uploading the artifacts for - # reasoning_engine, requirements and extra_packages to - # `staging_bucket/gcs_dir_name`. - _prepare( + staging_bucket = initializer.global_config.staging_bucket + _validate_staging_bucket_or_raise(staging_bucket) + # Prepares the Reasoning Engine for creation in Vertex AI. + # This involves packaging and uploading the artifacts for + # reasoning_engine, requirements and extra_packages to + # `staging_bucket/gcs_dir_name`. + _prepare( reasoning_engine=reasoning_engine, requirements=requirements, project=sdk_resource.project, @@ -277,8 +280,8 @@ def create( gcs_dir_name=gcs_dir_name, extra_packages=extra_packages, ) - # Update the package spec. - package_spec = aip_types.ReasoningEngineSpec.PackageSpec( + # Update the package spec. + package_spec = aip_types.ReasoningEngineSpec.PackageSpec( python_version=sys_version, pickle_object_gcs_uri="{}/{}/{}".format( staging_bucket, @@ -286,26 +289,26 @@ def create( _BLOB_FILENAME, ), ) - if extra_packages: - package_spec.dependency_files_gcs_uri = "{}/{}/{}".format( + if extra_packages: + package_spec.dependency_files_gcs_uri = "{}/{}/{}".format( staging_bucket, gcs_dir_name, _EXTRA_PACKAGES_FILE, ) - if requirements: - package_spec.requirements_gcs_uri = "{}/{}/{}".format( + if requirements: + package_spec.requirements_gcs_uri = "{}/{}/{}".format( staging_bucket, gcs_dir_name, _REQUIREMENTS_FILE, ) - reasoning_engine_spec = aip_types.ReasoningEngineSpec( + reasoning_engine_spec = aip_types.ReasoningEngineSpec( package_spec=package_spec, ) - class_methods_spec = _generate_class_methods_spec_or_raise( + class_methods_spec = _generate_class_methods_spec_or_raise( reasoning_engine, _get_registered_operations(reasoning_engine) ) - reasoning_engine_spec.class_methods.extend(class_methods_spec) - operation_future = sdk_resource.api_client.create_reasoning_engine( + reasoning_engine_spec.class_methods.extend(class_methods_spec) + operation_future = sdk_resource.api_client.create_reasoning_engine( parent=initializer.global_config.common_location_path( project=sdk_resource.project, location=sdk_resource.location ), @@ -316,32 +319,32 @@ def create( spec=reasoning_engine_spec, ), ) - _LOGGER.log_create_with_lro(cls, operation_future) - created_resource = operation_future.result() - _LOGGER.log_create_complete( + _LOGGER.log_create_with_lro(cls, operation_future) + created_resource = operation_future.result() + _LOGGER.log_create_complete( cls, created_resource, cls._resource_noun, module_name="vertexai.preview.reasoning_engines", ) - # We use `._get_gca_resource(...)` instead of `created_resource` to - # fully instantiate the attributes of the reasoning engine. - sdk_resource._gca_resource = sdk_resource._get_gca_resource( + # We use `._get_gca_resource(...)` instead of `created_resource` to + # fully instantiate the attributes of the reasoning engine. + sdk_resource._gca_resource = sdk_resource._get_gca_resource( resource_name=created_resource.name ) - sdk_resource.execution_api_client = initializer.global_config.create_client( + sdk_resource.execution_api_client = initializer.global_config.create_client( client_class=aip_utils.ReasoningEngineExecutionClientWithOverride, credentials=sdk_resource.credentials, location_override=sdk_resource.location, ) - try: - _register_api_methods_or_raise(sdk_resource) - except Exception as e: - logging.warning("Failed to register API methods: {%s}", e) - sdk_resource._operation_schemas = None - return sdk_resource + try: + _register_api_methods_or_raise(sdk_resource) + except Exception as e: + logging.warning("Failed to register API methods: {%s}", e) + sdk_resource._operation_schemas = None + return sdk_resource - def update( + def update( self, *, reasoning_engine: Optional[Union[Queryable, OperationRegistrable]] = None, @@ -352,7 +355,7 @@ def update( sys_version: Optional[str] = None, extra_packages: Optional[Sequence[str]] = None, ) -> "ReasoningEngine": - """Updates an existing ReasoningEngine. + """Updates an existing ReasoningEngine. This method updates the configuration of an existing ReasoningEngine running remotely, which is identified by its resource name. @@ -406,12 +409,12 @@ def update( IOError: If requirements is a string that corresponds to a nonexistent file. """ - staging_bucket = initializer.global_config.staging_bucket - _validate_staging_bucket_or_raise(staging_bucket) - historical_operation_schemas = self.operation_schemas() + staging_bucket = initializer.global_config.staging_bucket + _validate_staging_bucket_or_raise(staging_bucket) + historical_operation_schemas = self.operation_schemas() - # Validate the arguments. - if not any( + # Validate the arguments. + if not any( [ reasoning_engine, requirements, @@ -420,25 +423,25 @@ def update( description, ] ): - raise ValueError( + raise ValueError( "At least one of `reasoning_engine`, `requirements`, " "`extra_packages`, `display_name`, or `description` must be " "specified." ) - if sys_version: - _LOGGER.warning("Updated sys_version is not supported.") - if requirements is not None: - requirements = _validate_requirements_or_raise(requirements) - if extra_packages is not None: - extra_packages = _validate_extra_packages_or_raise(extra_packages) - if reasoning_engine is not None: - reasoning_engine = _validate_reasoning_engine_or_raise(reasoning_engine) - - # Prepares the Reasoning Engine for update in Vertex AI. - # This involves packaging and uploading the artifacts for - # reasoning_engine, requirements and extra_packages to - # `staging_bucket/gcs_dir_name`. - _prepare( + if sys_version: + _LOGGER.warning("Updated sys_version is not supported.") + if requirements is not None: + requirements = _validate_requirements_or_raise(requirements) + if extra_packages is not None: + extra_packages = _validate_extra_packages_or_raise(extra_packages) + if reasoning_engine is not None: + reasoning_engine = _validate_reasoning_engine_or_raise(reasoning_engine) + + # Prepares the Reasoning Engine for update in Vertex AI. + # This involves packaging and uploading the artifacts for + # reasoning_engine, requirements and extra_packages to + # `staging_bucket/gcs_dir_name`. + _prepare( reasoning_engine=reasoning_engine, requirements=requirements, project=self.project, @@ -447,7 +450,7 @@ def update( gcs_dir_name=gcs_dir_name, extra_packages=extra_packages, ) - update_request = _generate_update_request_or_raise( + update_request = _generate_update_request_or_raise( resource_name=self.resource_name, staging_bucket=staging_bucket, gcs_dir_name=gcs_dir_name, @@ -457,47 +460,82 @@ def update( display_name=display_name, description=description, ) - operation_future = self.api_client.update_reasoning_engine( + operation_future = self.api_client.update_reasoning_engine( request=update_request ) - _LOGGER.info( + _LOGGER.info( f"Update ReasoningEngine backing LRO: {operation_future.operation.name}" ) - created_resource = operation_future.result() - _LOGGER.info(f"ReasoningEngine updated. Resource name: {created_resource.name}") - self._operation_schemas = None - self.execution_api_client = initializer.global_config.create_client( + created_resource = operation_future.result() + _LOGGER.info(f"ReasoningEngine updated. Resource name: {created_resource.name}") + self._operation_schemas = None + self.execution_api_client = initializer.global_config.create_client( client_class=aip_utils.ReasoningEngineExecutionClientWithOverride, ) - # We use `._get_gca_resource(...)` instead of `created_resource` to - # fully instantiate the attributes of the reasoning engine. - self._gca_resource = self._get_gca_resource(resource_name=self.resource_name) + # We use `._get_gca_resource(...)` instead of `created_resource` to + # fully instantiate the attributes of the reasoning engine. + self._gca_resource = self._get_gca_resource(resource_name=self.resource_name) - if ( + if ( reasoning_engine is None or historical_operation_schemas == self.operation_schemas() ): - # As the API/operations of the reasoning engine are unchanged, we - # can return it here. - return self - - # If the reasoning engine has changed and the historical operation - # schemas are different from the current operation schemas, we need to - # unregister the historical operation schemas and register the current - # operation schemas. - _unregister_api_methods(self, historical_operation_schemas) - try: - _register_api_methods_or_raise(self) - except Exception as e: - logging.warning("Failed to register API methods: {%s}", e) - return self - - def operation_schemas(self) -> Sequence[_utils.JsonDict]: - """Returns the (Open)API schemas for the Reasoning Engine.""" - spec = _utils.to_dict(self._gca_resource.spec) - if not hasattr(self, "_operation_schemas") or self._operation_schemas is None: - self._operation_schemas = spec.get("classMethods", []) - return self._operation_schemas + # As the API/operations of the reasoning engine are unchanged, we + # can return it here. + return self + + # If the reasoning engine has changed and the historical operation + # schemas are different from the current operation schemas, we need to + # unregister the historical operation schemas and register the current + # operation schemas. + _unregister_api_methods(self, historical_operation_schemas) + try: + _register_api_methods_or_raise(self) + except Exception as e: + logging.warning("Failed to register API methods: {%s}", e) + return self + + def operation_schemas(self) -> Sequence[_utils.JsonDict]: + """Returns the (Open)API schemas for the Reasoning Engine.""" + spec = _utils.to_dict(self._gca_resource.spec) + if not hasattr(self, "_operation_schemas") or self._operation_schemas is None: + self._operation_schemas = spec.get("classMethods", []) + return self._operation_schemas + + def get_iam_policy( + self, policy_version: Optional[int] = None + ) -> policy_pb2.Policy: + """Gets the access control policy for this ReasoningEngine. + + Args: + policy_version: Optional. The maximum policy version that will be used + to format the policy. Valid values are 0, 1, 3. + + Returns: + The IAM policy. + """ + request = iam_policy_pb2.GetIamPolicyRequest( + resource=self.resource_name, + options=options_pb2.GetPolicyOptions( + requested_policy_version=policy_version + ), + ) + return self.api_client.get_iam_policy(request=request) + + def set_iam_policy(self, policy: policy_pb2.Policy) -> policy_pb2.Policy: + """Sets the access control policy on this ReasoningEngine. + + Args: + policy: The complete policy to be applied to the resource. + + Returns: + The new IAM policy. + """ + request = iam_policy_pb2.SetIamPolicyRequest( + resource=self.resource_name, + policy=policy, + ) + return self.api_client.set_iam_policy(request=request) def _validate_sys_version_or_raise(sys_version: str) -> None: