Skip to content

Commit 74de814

Browse files
authored
Merge pull request #463 from ji-yaqi/dev
feat: add pipeline client to vertex
2 parents e96e1b3 + 68568dd commit 74de814

10 files changed

Lines changed: 903 additions & 7 deletions

File tree

google/cloud/aiplatform/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from google.cloud.aiplatform.metadata import metadata
3131
from google.cloud.aiplatform.models import Endpoint
3232
from google.cloud.aiplatform.models import Model
33+
from google.cloud.aiplatform.pipeline_jobs import PipelineJob
3334
from google.cloud.aiplatform.jobs import (
3435
BatchPredictionJob,
3536
CustomJob,
@@ -85,6 +86,7 @@
8586
"ImageDataset",
8687
"HyperparameterTuningJob",
8788
"Model",
89+
"PipelineJob",
8890
"TabularDataset",
8991
"TextDataset",
9092
"TimeSeriesDataset",

google/cloud/aiplatform/compat/types/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
model_evaluation_slice as model_evaluation_slice_v1beta1,
4545
model_service as model_service_v1beta1,
4646
operation as operation_v1beta1,
47+
pipeline_job as pipeline_job_v1beta1,
4748
pipeline_service as pipeline_service_v1beta1,
4849
pipeline_state as pipeline_state_v1beta1,
4950
prediction_service as prediction_service_v1beta1,
@@ -158,6 +159,7 @@
158159
model_evaluation_slice_v1beta1,
159160
model_service_v1beta1,
160161
operation_v1beta1,
162+
pipeline_job_v1beta1,
161163
pipeline_service_v1beta1,
162164
pipeline_state_v1beta1,
163165
prediction_service_v1beta1,

google/cloud/aiplatform/jobs.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,7 @@
2020
import abc
2121
import copy
2222
import datetime
23-
import sys
2423
import time
25-
import logging
2624

2725
from google.cloud import storage
2826
from google.cloud import bigquery
@@ -61,7 +59,6 @@
6159
study as gca_study_compat,
6260
)
6361

64-
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
6562
_LOGGER = base.Logger(__name__)
6663

6764
_JOB_COMPLETE_STATES = (
Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2021 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import datetime
19+
import time
20+
import re
21+
from typing import Any, Optional, Dict
22+
23+
from google.auth import credentials as auth_credentials
24+
from google.cloud.aiplatform import base
25+
from google.cloud.aiplatform import initializer
26+
from google.cloud.aiplatform import utils
27+
from google.cloud.aiplatform.utils import json_utils
28+
from google.cloud.aiplatform.utils import pipeline_utils
29+
from google.protobuf import json_format
30+
31+
from google.cloud.aiplatform.compat.types import (
32+
pipeline_job_v1beta1 as gca_pipeline_job_v1beta1,
33+
pipeline_state_v1beta1 as gca_pipeline_state_v1beta1,
34+
)
35+
36+
_LOGGER = base.Logger(__name__)
37+
38+
_PIPELINE_COMPLETE_STATES = set(
39+
[
40+
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED,
41+
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_FAILED,
42+
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_CANCELLED,
43+
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_PAUSED,
44+
]
45+
)
46+
47+
# Vertex AI Pipelines service API job name relative name prefix pattern.
48+
_JOB_NAME_PATTERN = "{parent}/pipelineJobs/{job_id}"
49+
50+
# Pattern for valid names used as a Vertex resource name.
51+
_VALID_NAME_PATTERN = re.compile("^[a-z][-a-z0-9]{0,127}$")
52+
53+
54+
def _get_current_time() -> datetime.datetime:
55+
"""Gets the current timestamp."""
56+
return datetime.datetime.now()
57+
58+
59+
def _set_enable_caching_value(
60+
pipeline_spec: Dict[str, Any], enable_caching: bool
61+
) -> None:
62+
"""Sets pipeline tasks caching options.
63+
64+
Args:
65+
pipeline_spec (Dict[str, Any]):
66+
Required. The dictionary of pipeline spec.
67+
enable_caching (bool):
68+
Required. Whether to enable caching.
69+
"""
70+
for component in [pipeline_spec["root"]] + list(
71+
pipeline_spec["components"].values()
72+
):
73+
if "dag" in component:
74+
for task in component["dag"]["tasks"].values():
75+
task["cachingOptions"] = {"enableCache": enable_caching}
76+
77+
78+
class PipelineJob(base.VertexAiResourceNounWithFutureManager):
79+
80+
client_class = utils.PipelineJobClientWithOverride
81+
_is_client_prediction_client = False
82+
83+
_resource_noun = "pipelineJobs"
84+
_delete_method = "delete_pipeline_job"
85+
_getter_method = "get_pipeline_job"
86+
_list_method = "list_pipeline_jobs"
87+
88+
def __init__(
89+
self,
90+
display_name: str,
91+
template_path: str,
92+
job_id: Optional[str] = None,
93+
pipeline_root: Optional[str] = None,
94+
parameter_values: Optional[Dict[str, Any]] = None,
95+
enable_caching: Optional[bool] = True,
96+
encryption_spec_key_name: Optional[str] = None,
97+
labels: Optional[Dict[str, str]] = None,
98+
credentials: Optional[auth_credentials.Credentials] = None,
99+
project: Optional[str] = None,
100+
location: Optional[str] = None,
101+
):
102+
"""Retrieves a PipelineJob resource and instantiates its
103+
representation.
104+
105+
Args:
106+
display_name (str):
107+
Required. The user-defined name of this Pipeline.
108+
template_path (str):
109+
Required. The path of PipelineJob JSON file. It can be a local path or a
110+
Google Cloud Storage URI. Example: "gs://project.name"
111+
job_id (str):
112+
Optional. The unique ID of the job run.
113+
If not specified, pipeline name + timestamp will be used.
114+
pipeline_root (str):
115+
Optional. The root of the pipeline outputs. Default to be staging bucket.
116+
parameter_values (Dict[str, Any]):
117+
Optional. The mapping from runtime parameter names to its values that
118+
control the pipeline run.
119+
enable_caching (bool):
120+
Optional. Whether to turn on caching for the run. Defaults to True.
121+
encryption_spec_key_name (str):
122+
Optional. The Cloud KMS resource identifier of the customer
123+
managed encryption key used to protect the job. Has the
124+
form:
125+
``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
126+
The key needs to be in the same region as where the compute
127+
resource is created.
128+
129+
If this is set, then all
130+
resources created by the BatchPredictionJob will
131+
be encrypted with the provided encryption key.
132+
133+
Overrides encryption_spec_key_name set in aiplatform.init.
134+
labels (Dict[str,str]):
135+
Optional. The user defined metadata to organize PipelineJob.
136+
credentials (auth_credentials.Credentials):
137+
Optional. Custom credentials to use to create this batch prediction
138+
job. Overrides credentials set in aiplatform.init.
139+
project (str),
140+
Optional. Project to retrieve PipelineJob from. If not set,
141+
project set in aiplatform.init will be used.
142+
location (str),
143+
Optional. Location to create PipelineJob. If not set,
144+
location set in aiplatform.init will be used.
145+
146+
Raises:
147+
ValueError: If job_id or labels have incorrect format.
148+
"""
149+
utils.validate_display_name(display_name)
150+
151+
if labels:
152+
for k, v in labels.items():
153+
if not isinstance(k, str) or not isinstance(v, str):
154+
raise ValueError(
155+
"Expect labels to be a mapping of string key value pairs. "
156+
'Got "{}".'.format(labels)
157+
)
158+
159+
super().__init__(project=project, location=location, credentials=credentials)
160+
161+
self._parent = initializer.global_config.common_location_path(
162+
project=project, location=location
163+
)
164+
pipeline_job = json_utils.load_json(
165+
template_path, self.project, self.credentials
166+
)
167+
pipeline_root = (
168+
pipeline_root
169+
or pipeline_job["runtimeConfig"].get("gcsOutputDirectory")
170+
or initializer.global_config.staging_bucket
171+
)
172+
173+
pipeline_name = pipeline_job["pipelineSpec"]["pipelineInfo"]["name"]
174+
job_id = job_id or "{pipeline_name}-{timestamp}".format(
175+
pipeline_name=re.sub("[^-0-9a-z]+", "-", pipeline_name.lower())
176+
.lstrip("-")
177+
.rstrip("-"),
178+
timestamp=_get_current_time().strftime("%Y%m%d%H%M%S"),
179+
)
180+
if not _VALID_NAME_PATTERN.match(job_id):
181+
raise ValueError(
182+
"Generated job ID: {} is illegal as a Vertex pipelines job ID. "
183+
"Expecting an ID following the regex pattern "
184+
'"[a-z][-a-z0-9]{{0,127}}"'.format(job_id)
185+
)
186+
job_name = _JOB_NAME_PATTERN.format(parent=self._parent, job_id=job_id)
187+
188+
builder = pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
189+
pipeline_job
190+
)
191+
builder.update_pipeline_root(pipeline_root)
192+
builder.update_runtime_parameters(parameter_values)
193+
runtime_config_dict = builder.build()
194+
runtime_config = gca_pipeline_job_v1beta1.PipelineJob.RuntimeConfig()._pb
195+
json_format.ParseDict(runtime_config_dict, runtime_config)
196+
197+
_set_enable_caching_value(pipeline_job["pipelineSpec"], enable_caching)
198+
199+
self._gca_resource = gca_pipeline_job_v1beta1.PipelineJob(
200+
display_name=display_name,
201+
name=job_name,
202+
pipeline_spec=pipeline_job["pipelineSpec"],
203+
labels=labels,
204+
runtime_config=runtime_config,
205+
encryption_spec=initializer.global_config.get_encryption_spec(
206+
encryption_spec_key_name=encryption_spec_key_name
207+
),
208+
)
209+
210+
@base.optional_sync()
211+
def run(
212+
self,
213+
service_account: Optional[str] = None,
214+
network: Optional[str] = None,
215+
sync: Optional[bool] = True,
216+
) -> None:
217+
"""Run this configured PipelineJob.
218+
219+
Args:
220+
service_account (str):
221+
Optional. Specifies the service account for workload run-as account.
222+
Users submitting jobs must have act-as permission on this run-as account.
223+
network (str):
224+
Optional. The full name of the Compute Engine network to which the job
225+
should be peered. For example, projects/12345/global/networks/myVPC.
226+
Private services access must already be configured for the network.
227+
If left unspecified, the job is not peered with any network.
228+
sync (bool):
229+
Optional. Whether to execute this method synchronously. If False, this method will unblock and it will be executed in a concurrent Future.
230+
"""
231+
if service_account:
232+
self._gca_resource.pipeline_spec.service_account = service_account
233+
234+
if network:
235+
self._gca_resource.pipeline_spec.network = network
236+
237+
_LOGGER.log_create_with_lro(self.__class__)
238+
239+
self._gca_resource = self.api_client.create_pipeline_job(
240+
parent=self._parent, pipeline_job=self._gca_resource
241+
)
242+
243+
_LOGGER.log_create_complete_with_getter(
244+
self.__class__, self._gca_resource, "pipeline_job"
245+
)
246+
247+
_LOGGER.info("View Pipeline Job:\n%s" % self._dashboard_uri())
248+
249+
self._block_until_complete()
250+
251+
@property
252+
def pipeline_spec(self):
253+
return self._gca_resource.pipeline_spec
254+
255+
@property
256+
def state(self) -> Optional[gca_pipeline_state_v1beta1.PipelineState]:
257+
"""Current pipeline state."""
258+
if not self._has_run:
259+
raise RuntimeError("Job has not run. No state available.")
260+
261+
self._sync_gca_resource()
262+
return self._gca_resource.state
263+
264+
@property
265+
def _has_run(self) -> bool:
266+
"""Helper property to check if this pipeline job has been run."""
267+
return bool(self._gca_resource.name)
268+
269+
@property
270+
def has_failed(self) -> bool:
271+
"""Returns True if pipeline has failed.
272+
273+
False otherwise.
274+
"""
275+
return (
276+
self.state == gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_FAILED
277+
)
278+
279+
def _dashboard_uri(self) -> str:
280+
"""Helper method to compose the dashboard uri where pipeline can be
281+
viewed."""
282+
fields = utils.extract_fields_from_resource_name(self.resource_name)
283+
url = f"https://console.cloud.google.com/vertex-ai/locations/{fields.location}/pipelines/runs/{fields.id}?project={fields.project}"
284+
return url
285+
286+
def _sync_gca_resource(self):
287+
"""Helper method to sync the local gca_source against the service."""
288+
self._gca_resource = self.api_client.get_pipeline_job(name=self.resource_name)
289+
290+
def _block_until_complete(self):
291+
"""Helper method to block and check on job until complete."""
292+
# Used these numbers so failures surface fast
293+
wait = 5 # start at five seconds
294+
log_wait = 5
295+
max_wait = 60 * 5 # 5 minute wait
296+
multiplier = 2 # scale wait by 2 every iteration
297+
298+
previous_time = time.time()
299+
while self.state not in _PIPELINE_COMPLETE_STATES:
300+
current_time = time.time()
301+
if current_time - previous_time >= log_wait:
302+
_LOGGER.info(
303+
"%s %s current state:\n%s"
304+
% (
305+
self.__class__.__name__,
306+
self._gca_resource.name,
307+
self._gca_resource.state,
308+
)
309+
)
310+
log_wait = min(log_wait * multiplier, max_wait)
311+
previous_time = current_time
312+
time.sleep(wait)

google/cloud/aiplatform/utils/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,14 @@ class PipelineClientWithOverride(ClientWithOverride):
459459
)
460460

461461

462+
class PipelineJobClientWithOverride(ClientWithOverride):
463+
_is_temporary = True
464+
_default_version = compat.V1BETA1
465+
_version_map = (
466+
(compat.V1BETA1, pipeline_service_client_v1beta1.PipelineServiceClient),
467+
)
468+
469+
462470
class PredictionClientWithOverride(ClientWithOverride):
463471
_is_temporary = False
464472
_default_version = compat.DEFAULT_VERSION
@@ -491,6 +499,7 @@ class TensorboardClientWithOverride(ClientWithOverride):
491499
JobClientWithOverride,
492500
ModelClientWithOverride,
493501
PipelineClientWithOverride,
502+
PipelineJobClientWithOverride,
494503
PredictionClientWithOverride,
495504
MetadataClientWithOverride,
496505
TensorboardClientWithOverride,

0 commit comments

Comments
 (0)