diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference_it_test.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference_it_test.py index 11643992c392..8908477dae7f 100644 --- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference_it_test.py +++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference_it_test.py @@ -27,6 +27,7 @@ from apache_beam.io.filesystems import FileSystems from apache_beam.ml.inference.base import RunInference from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.vertex_ai_skip import skip_if_vertex_ai_disabled # pylint: disable=ungrouped-imports try: @@ -53,8 +54,9 @@ _INVOKE_OUTPUT_DIR = "gs://apache-beam-ml/testing/outputs/vertex_invoke" +@skip_if_vertex_ai_disabled +@pytest.mark.vertex_ai_postcommit class VertexAIInference(unittest.TestCase): - @pytest.mark.vertex_ai_postcommit def test_vertex_ai_run_flower_image_classification(self): output_file = '/'.join([_OUTPUT_DIR, str(uuid.uuid4()), 'output.txt']) @@ -73,7 +75,6 @@ def test_vertex_ai_run_flower_image_classification(self): test_pipeline.get_full_options_as_args(**extra_opts)) self.assertEqual(FileSystems().exists(output_file), True) - @pytest.mark.vertex_ai_postcommit @unittest.skipIf( not _INVOKE_ENDPOINT_ID, "Invoke endpoint not configured. Set _INVOKE_ENDPOINT_ID.") diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py index 50507c54e36d..5fc3bfa6d9ae 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py @@ -20,10 +20,13 @@ import unittest import uuid +import pytest + import apache_beam as beam from apache_beam.ml.inference.base import RunInference from apache_beam.ml.transforms import base from apache_beam.ml.transforms.base import MLTransform +from apache_beam.testing.vertex_ai_skip import skip_if_vertex_ai_disabled # pylint: disable=ungrouped-imports # isort: off @@ -58,6 +61,8 @@ model_name: str = "text-embedding-005" +@skip_if_vertex_ai_disabled +@pytest.mark.vertex_ai_postcommit @unittest.skipIf( VertexAITextEmbeddings is None, 'Vertex AI Python SDK is not installed.') class VertexAIEmbeddingsTest(unittest.TestCase): @@ -308,6 +313,8 @@ def _make_text_chunk(input: str) -> Chunk: return Chunk(content=Content(text=input)) +@skip_if_vertex_ai_disabled +@pytest.mark.vertex_ai_postcommit @unittest.skipIf( VertexAIMultiModalEmbeddings is None, 'Vertex AI Python SDK is not installed.') diff --git a/sdks/python/apache_beam/testing/vertex_ai_skip.py b/sdks/python/apache_beam/testing/vertex_ai_skip.py new file mode 100644 index 000000000000..1196b29c720b --- /dev/null +++ b/sdks/python/apache_beam/testing/vertex_ai_skip.py @@ -0,0 +1,38 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Centralized skip for Vertex AI tests when dependencies are missing. + +Test modules use skip_if_vertex_ai_disabled on classes that require +the Vertex AI Python SDK to be installed. +""" + +import pytest + + +def _is_vertex_ai_available() -> bool: + """Return True if Vertex AI client dependencies are importable.""" + try: + import vertexai # pylint: disable=unused-import + except ImportError: + return False + return True + + +skip_if_vertex_ai_disabled = pytest.mark.skipif( + not _is_vertex_ai_available(), + reason='Vertex AI dependencies not available.')