Skip to content

Commit a9a8e6e

Browse files
authored
Merge branch 'master' into dev
2 parents 2f1674e + 572a27c commit a9a8e6e

328 files changed

Lines changed: 10938 additions & 1855 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/.OwlBot.lock.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
docker:
2+
image: gcr.io/repo-automation-bots/owlbot-python:latest
3+
digest: sha256:9d6a2d613e2c04c07ecdb6c287e3931890f6d30266ab5ee4ee412f748dc98341

.github/.OwlBot.yaml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
docker:
16+
image: gcr.io/repo-automation-bots/owlbot-python:latest
17+
18+
deep-remove-regex:
19+
- /owl-bot-staging
20+
21+
deep-copy-regex:
22+
- source: /google/cloud/aiplatform/(v.*)/.*-py/(.*)
23+
dest: /owl-bot-staging/$1/$2
24+
25+
begin-after-commit-hash: 7774246dfb7839067cd64bba0600089b1c91bd85
26+

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# Changelog
22

3+
### [1.0.1](https://www.github.com/googleapis/python-aiplatform/compare/v1.0.0...v1.0.1) (2021-05-21)
4+
5+
6+
### Bug Fixes
7+
8+
* use resource name location when passed full resource name ([#421](https://www.github.com/googleapis/python-aiplatform/issues/421)) ([f40f322](https://www.github.com/googleapis/python-aiplatform/commit/f40f32289e1fbeb93b35e4b66f65d15528a6481c))
9+
310
## [1.0.0](https://www.github.com/googleapis/python-aiplatform/compare/v0.9.0...v1.0.0) (2021-05-19)
411

512

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@
363363
"google.api_core": ("https://googleapis.dev/python/google-api-core/latest/", None,),
364364
"grpc": ("https://grpc.github.io/grpc/python/", None),
365365
"proto-plus": ("https://proto-plus-python.readthedocs.io/en/latest/", None),
366+
"protobuf": ("https://googleapis.dev/python/protobuf/latest/", None),
366367
}
367368

368369

docs/multiprocessing.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
.. note::
22

3-
Because this client uses :mod:`grpcio` library, it is safe to
3+
Because this client uses :mod:`grpc` library, it is safe to
44
share instances across threads. In multiprocessing scenarios, the best
55
practice is to create client instances *after* the invocation of
6-
:func:`os.fork` by :class:`multiprocessing.Pool` or
6+
:func:`os.fork` by :class:`multiprocessing.pool.Pool` or
77
:class:`multiprocessing.Process`.

google/cloud/aiplatform/base.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from google.auth import credentials as auth_credentials
4343
from google.cloud.aiplatform import initializer
4444
from google.cloud.aiplatform import utils
45-
45+
from google.cloud.aiplatform.compat.types import encryption_spec as gca_encryption_spec
4646

4747
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
4848

@@ -563,6 +563,23 @@ def update_time(self) -> datetime.datetime:
563563
self._sync_gca_resource()
564564
return self._gca_resource.update_time
565565

566+
@property
567+
def encryption_spec(self) -> Optional[gca_encryption_spec.EncryptionSpec]:
568+
"""Customer-managed encryption key options for this Vertex AI resource.
569+
570+
If this is set, then all resources created by this Vertex AI resource will
571+
be encrypted with the provided encryption key.
572+
"""
573+
return getattr(self._gca_resource, "encryption_spec")
574+
575+
@property
576+
def labels(self) -> Dict[str, str]:
577+
"""User-defined labels containing metadata about this resource.
578+
579+
Read more about labels at https://goo.gl/xmQnxf
580+
"""
581+
return self._gca_resource.labels
582+
566583
@property
567584
def gca_resource(self) -> proto.Message:
568585
"""The underlying resource proto represenation."""
@@ -813,7 +830,7 @@ def _construct_sdk_resource_from_gapic(
813830
814831
Args:
815832
gapic_resource (proto.Message):
816-
A GAPIC representation of an Vertex AI resource, usually
833+
A GAPIC representation of a Vertex AI resource, usually
817834
retrieved by a get_* or in a list_* API call.
818835
project (str):
819836
Optional. Project to construct SDK object from. If not set,

google/cloud/aiplatform/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
}
3434

3535
API_BASE_PATH = "aiplatform.googleapis.com"
36+
PREDICTION_API_BASE_PATH = API_BASE_PATH
3637

3738
# Batch Prediction
3839
BATCH_PREDICTION_INPUT_STORAGE_FORMATS = (

google/cloud/aiplatform/datasets/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __init__(
6868
Optional location to retrieve dataset from. If not set, location
6969
set in aiplatform.init will be used.
7070
credentials (auth_credentials.Credentials):
71-
Custom credentials to use to upload this model. Overrides
71+
Custom credentials to use to retreive this Dataset. Overrides
7272
credentials set in aiplatform.init.
7373
"""
7474

google/cloud/aiplatform/datasets/tabular_dataset.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,20 +73,28 @@ def column_names(self) -> List[str]:
7373
gcs_source_uris.sort()
7474

7575
# Get the first file in sorted list
76-
return TabularDataset._retrieve_gcs_source_columns(
77-
self.project, gcs_source_uris[0]
76+
return self._retrieve_gcs_source_columns(
77+
project=self.project,
78+
gcs_csv_file_path=gcs_source_uris[0],
79+
credentials=self.credentials,
7880
)
7981
elif bq_source:
8082
bq_table_uri = bq_source.get("uri")
8183
if bq_table_uri:
82-
return TabularDataset._retrieve_bq_source_columns(
83-
self.project, bq_table_uri
84+
return self._retrieve_bq_source_columns(
85+
project=self.project,
86+
bq_table_uri=bq_table_uri,
87+
credentials=self.credentials,
8488
)
8589

8690
raise RuntimeError("No valid CSV or BigQuery datasource found.")
8791

8892
@staticmethod
89-
def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
93+
def _retrieve_gcs_source_columns(
94+
project: str,
95+
gcs_csv_file_path: str,
96+
credentials: Optional[auth_credentials.Credentials] = None,
97+
) -> List[str]:
9098
"""Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage
9199
92100
Example Usage:
@@ -104,7 +112,8 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[s
104112
gcs_csv_file_path (str):
105113
Required. A full path to a CSV files stored on Google Cloud Storage.
106114
Must include "gs://" prefix.
107-
115+
credentials (auth_credentials.Credentials):
116+
Credentials to use to with GCS Client.
108117
Returns:
109118
List[str]
110119
A list of columns names in the CSV file.
@@ -116,7 +125,7 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[s
116125
gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path(
117126
gcs_csv_file_path
118127
)
119-
client = storage.Client(project=project)
128+
client = storage.Client(project=project, credentials=credentials)
120129
bucket = client.bucket(gcs_bucket)
121130
blob = bucket.blob(gcs_blob)
122131

@@ -135,6 +144,7 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[s
135144
line += blob.download_as_bytes(
136145
start=start_index, end=start_index + increment
137146
).decode("utf-8")
147+
138148
first_new_line_index = line.find("\n")
139149
start_index += increment
140150

@@ -156,7 +166,11 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[s
156166
return next(csv_reader)
157167

158168
@staticmethod
159-
def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
169+
def _retrieve_bq_source_columns(
170+
project: str,
171+
bq_table_uri: str,
172+
credentials: Optional[auth_credentials.Credentials] = None,
173+
) -> List[str]:
160174
"""Retrieve the columns from a table on Google BigQuery
161175
162176
Example Usage:
@@ -174,6 +188,8 @@ def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
174188
bq_table_uri (str):
175189
Required. A URI to a BigQuery table.
176190
Can include "bq://" prefix but not required.
191+
credentials (auth_credentials.Credentials):
192+
Credentials to use with BQ Client.
177193
178194
Returns:
179195
List[str]
@@ -185,7 +201,7 @@ def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
185201
if bq_table_uri.startswith(prefix):
186202
bq_table_uri = bq_table_uri[len(prefix) :]
187203

188-
client = bigquery.Client(project=project)
204+
client = bigquery.Client(project=project, credentials=credentials)
189205
table = client.get_table(bq_table_uri)
190206
schema = table.schema
191207
return [schema.name for schema in schema]

google/cloud/aiplatform/initializer.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,19 @@ def init(
9292
if metadata.metadata_service.experiment_name:
9393
logging.info("project/location updated, reset Metadata config.")
9494
metadata.metadata_service.reset()
95+
9596
if project:
9697
self._project = project
9798
if location:
9899
utils.validate_region(location)
99100
self._location = location
101+
if staging_bucket:
102+
self._staging_bucket = staging_bucket
103+
if credentials:
104+
self._credentials = credentials
105+
if encryption_spec_key_name:
106+
self._encryption_spec_key_name = encryption_spec_key_name
107+
100108
if experiment:
101109
metadata.metadata_service.set_experiment(
102110
experiment=experiment, description=experiment_description
@@ -105,12 +113,6 @@ def init(
105113
raise ValueError(
106114
"Experiment name needs to be set in `init` in order to add experiment descriptions."
107115
)
108-
if staging_bucket:
109-
self._staging_bucket = staging_bucket
110-
if credentials:
111-
self._credentials = credentials
112-
if encryption_spec_key_name:
113-
self._encryption_spec_key_name = encryption_spec_key_name
114116

115117
def get_encryption_spec(
116118
self,
@@ -192,7 +194,7 @@ def encryption_spec_key_name(self) -> Optional[str]:
192194
return self._encryption_spec_key_name
193195

194196
def get_client_options(
195-
self, location_override: Optional[str] = None
197+
self, location_override: Optional[str] = None, prediction_client: bool = False
196198
) -> client_options.ClientOptions:
197199
"""Creates GAPIC client_options using location and type.
198200
@@ -201,6 +203,8 @@ def get_client_options(
201203
Set this parameter to get client options for a location different from
202204
location set by initializer. Must be a GCP region supported by AI
203205
Platform (Unified).
206+
prediction_client (str): Optional flag to use a prediction endpoint.
207+
204208
205209
Returns:
206210
clients_options (google.api_core.client_options.ClientOptions):
@@ -218,8 +222,14 @@ def get_client_options(
218222

219223
utils.validate_region(region)
220224

225+
service_base_path = (
226+
constants.PREDICTION_API_BASE_PATH
227+
if prediction_client
228+
else constants.API_BASE_PATH
229+
)
230+
221231
return client_options.ClientOptions(
222-
api_endpoint=f"{region}-{constants.API_BASE_PATH}"
232+
api_endpoint=f"{region}-{service_base_path}"
223233
)
224234

225235
def common_location_path(
@@ -257,7 +267,7 @@ def create_client(
257267
258268
Args:
259269
client_class (utils.VertexAiServiceClientWithOverride):
260-
(Required) An Vertex AI Service Client with optional overrides.
270+
(Required) A Vertex AI Service Client with optional overrides.
261271
credentials (auth_credentials.Credentials):
262272
Custom auth credentials. If not provided will use the current config.
263273
location_override (str): Optional location override.
@@ -276,7 +286,8 @@ def create_client(
276286
kwargs = {
277287
"credentials": credentials or self.credentials,
278288
"client_options": self.get_client_options(
279-
location_override=location_override
289+
location_override=location_override,
290+
prediction_client=prediction_client,
280291
),
281292
"client_info": client_info,
282293
}

0 commit comments

Comments
 (0)