Skip to content

Commit 0af0c68

Browse files
committed
Address comments and fix lint
1 parent 645f1ba commit 0af0c68

4 files changed

Lines changed: 168 additions & 160 deletions

File tree

google/cloud/aiplatform/pipeline_jobs.py

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,29 +15,23 @@
1515
# limitations under the License.
1616
#
1717

18+
import datetime
1819
import time
19-
from typing import Any, Optional, Dict, List
20-
2120
import re
22-
import sys
21+
from typing import Any, Optional, Dict
2322

2423
from google.auth import credentials as auth_credentials
25-
2624
from google.cloud.aiplatform import base
27-
from google.cloud.aiplatform import compat
2825
from google.cloud.aiplatform import initializer
2926
from google.cloud.aiplatform import utils
3027
from google.cloud.aiplatform.utils import json_utils
3128
from google.cloud.aiplatform.utils import pipeline_utils
3229

33-
from google.cloud.aiplatform.compat.services import pipeline_service_client
3430
from google.cloud.aiplatform.compat.types import (
3531
pipeline_job_v1beta1 as gca_pipeline_job_v1beta1,
3632
pipeline_state_v1beta1 as gca_pipeline_state_v1beta1,
3733
)
3834

39-
from google.rpc import code_pb2
40-
4135
_LOGGER = base.Logger(__name__)
4236

4337
_PIPELINE_COMPLETE_STATES = set(
@@ -56,16 +50,21 @@
5650
_VALID_NAME_PATTERN = re.compile("^[a-z][-a-z0-9]{0,127}$")
5751

5852

53+
def _get_current_time() -> datetime.datetime:
54+
"""Gets the current timestamp."""
55+
return datetime.datetime.now()
56+
57+
5958
def _set_enable_caching_value(
6059
pipeline_spec: Dict[str, Any], enable_caching: bool
6160
) -> None:
6261
"""Sets pipeline tasks caching options.
6362
6463
Args:
6564
pipeline_spec (Dict[str, Any]):
66-
The dictionary of pipeline spec.
65+
Required. The dictionary of pipeline spec.
6766
enable_caching (bool):
68-
Whether to enable caching.
67+
Required. Whether to enable caching.
6968
"""
7069
for component in [pipeline_spec["root"]] + list(
7170
pipeline_spec["components"].values()
@@ -82,8 +81,6 @@ class PipelineJob(base.VertexAiResourceNounWithFutureManager):
8281

8382
_resource_noun = "pipelineJobs"
8483
_getter_method = "get_pipeline_job"
85-
_list_method = "list_pipeline_jobs"
86-
_cancel_method = "cancel_pipeline_job"
8784
_delete_method = "delete_pipeline_job"
8885

8986
def __init__(
@@ -93,7 +90,7 @@ def __init__(
9390
job_id: Optional[str] = None,
9491
pipeline_root: Optional[str] = None,
9592
parameter_values: Optional[Dict[str, Any]] = None,
96-
enable_caching: bool = True,
93+
enable_caching: Optional[bool] = True,
9794
encryption_spec_key_name: Optional[str] = None,
9895
labels: Optional[Dict[str, str]] = None,
9996
credentials: Optional[auth_credentials.Credentials] = None,
@@ -118,7 +115,7 @@ def __init__(
118115
Optional. The mapping from runtime parameter names to its values that
119116
control the pipeline run.
120117
enable_caching (bool):
121-
Required. Whether to turn on caching for the run. Defaults to True.
118+
Optional. Whether to turn on caching for the run. Defaults to True.
122119
encryption_spec_key_name (str):
123120
Optional. The Cloud KMS resource identifier of the customer
124121
managed encryption key used to protect the job. Has the
@@ -145,10 +142,18 @@ def __init__(
145142
location set in aiplatform.init will be used.
146143
147144
Raises:
148-
ValueError: If inputs are formatted wrong.
145+
ValueError: If job_id or labels have incorrect format.
149146
"""
150147
utils.validate_display_name(display_name)
151148

149+
if labels:
150+
for k, v in labels.items():
151+
if not isinstance(k, str) or not isinstance(v, str):
152+
raise ValueError(
153+
"Expect labels to be a mapping of string key value pairs. "
154+
'Got "{}".'.format(labels)
155+
)
156+
152157
super().__init__(project=project, location=location, credentials=credentials)
153158

154159
self._parent = initializer.global_config.common_location_path(
@@ -192,13 +197,6 @@ def __init__(
192197
pipeline_spec["encryptionSpec"] = {"kmsKeyName": encryption_spec_key_name}
193198

194199
if labels:
195-
for k, v in labels.items():
196-
if not isinstance(k, str) or not isinstance(v, str):
197-
raise ValueError(
198-
"Expect labels to be a mapping of string key value pairs. "
199-
'Got "{}".'.format(labels)
200-
)
201-
202200
pipeline_spec["labels"] = labels
203201

204202
self._gca_resource = gca_pipeline_job_v1beta1.PipelineJob(
@@ -215,7 +213,7 @@ def run(
215213
self,
216214
service_account: Optional[str] = None,
217215
network: Optional[str] = None,
218-
sync: bool = True,
216+
sync: Optional[bool] = True,
219217
) -> None:
220218
"""Run this configured PipelineJob.
221219
@@ -229,8 +227,7 @@ def run(
229227
Private services access must already be configured for the network.
230228
If left unspecified, the job is not peered with any network.
231229
sync (bool):
232-
Whether to execute this method synchronously. If False, this method
233-
will unblock and it will be executed in a concurrent Future.
230+
Optional. Whether to execute this method synchronously. If False, this method will unblock and it will be executed in a concurrent Future.
234231
"""
235232
if service_account:
236233
self._gca_resource.pipeline_spec.service_account = service_account
@@ -268,7 +265,7 @@ def state(self) -> Optional[gca_pipeline_state_v1beta1.PipelineState]:
268265
@property
269266
def _has_run(self) -> bool:
270267
"""Helper property to check if this pipeline job has been run."""
271-
return self._gca_resource is not None
268+
return bool(self._gca_resource.name)
272269

273270
@property
274271
def has_failed(self) -> bool:

google/cloud/aiplatform/utils/json_utils.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,17 @@
2121
from google.auth import credentials as auth_credentials
2222
from google.cloud import storage
2323

24+
2425
def load_json(path: str,
25-
project: Optional[str] = None,
26-
credentials: Optional[auth_credentials.Credentials] = None) -> Dict[str, Any]:
26+
project: Optional[str] = None,
27+
credentials: Optional[auth_credentials.Credentials] = None
28+
) -> Dict[str, Any]:
2729
"""Loads data from a JSON document.
2830
2931
Args:
3032
path (str):
31-
Required. The path of the JSON document in Google Cloud Storage or local.
33+
Required. The path of the JSON document in Google Cloud Storage or
34+
local.
3235
project (str):
3336
Optional. Project to initiate the Storage client with.
3437
credentials (auth_credentials.Credentials):
@@ -38,13 +41,15 @@ def load_json(path: str,
3841
A Dict object representing the JSON document.
3942
"""
4043
if path.startswith('gs://'):
41-
return _load_json_from_gs_uri(path, project,credentials)
44+
return _load_json_from_gs_uri(path, project, credentials)
4245
else:
43-
return _load_json_from_local_file(path)
46+
return _load_json_from_local_file(path)
4447

4548

46-
def _load_json_from_gs_uri(uri: str, project: Optional[str] = None,
47-
credentials: Optional[auth_credentials.Credentials] = None) -> Dict[str, Any]:
49+
def _load_json_from_gs_uri(uri: str,
50+
project: Optional[str] = None,
51+
credentials: Optional[auth_credentials.Credentials]
52+
= None) -> Dict[str, Any]:
4853
"""Loads data from a JSON document referenced by a GCS URI.
4954
5055
Args:
@@ -74,4 +79,4 @@ def _load_json_from_local_file(file_path: str) -> Dict[str, Any]:
7479
A Dict object representing the JSON document.
7580
"""
7681
with open(file_path) as f:
77-
return json.load(f)
82+
return json.load(f)

0 commit comments

Comments
 (0)