Skip to content

Commit f27e17f

Browse files
authored
samples: Add Batch Prediction Job with dedicated resources sample (#396)
Manually tested usage with: ```py import os import create_batch_prediction_job_dedicated_resources_sample bpj = create_batch_prediction_job_dedicated_resources_sample.create_batch_prediction_job_dedicated_resources_sample( project=os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT"), location='us-central1', model_resource_name='3512561418744365056', job_display_name='temp_create_batch_prediction_job_test_manual_vinnys', gcs_source='gs://ucaip-samples-test-output/inputs/icn_batch_prediction_input.jsonl', gcs_destination='gs://ucaip-samples-test-output/', ) ```
1 parent 066624b commit f27e17f

2 files changed

Lines changed: 114 additions & 0 deletions

File tree

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Sequence, Union
16+
17+
from google.cloud import aiplatform, aiplatform_v1
18+
19+
20+
# [START aiplatform_sdk_create_batch_prediction_job_dedicated_resources_sample]
21+
def create_batch_prediction_job_dedicated_resources_sample(
22+
project: str,
23+
location: str,
24+
model_resource_name: str,
25+
job_display_name: str,
26+
gcs_source: Union[str, Sequence[str]],
27+
gcs_destination: str,
28+
machine_type: str = "n1-standard-2",
29+
accelerator_count: int = 1,
30+
accelerator_type: Union[str, aiplatform_v1.AcceleratorType] = "NVIDIA_TESLA_K80",
31+
starting_replica_count: int = 1,
32+
max_replica_count: int = 1,
33+
sync: bool = True,
34+
):
35+
aiplatform.init(project=project, location=location)
36+
37+
my_model = aiplatform.Model(model_resource_name)
38+
39+
batch_prediction_job = my_model.batch_predict(
40+
job_display_name=job_display_name,
41+
gcs_source=gcs_source,
42+
gcs_destination_prefix=gcs_destination,
43+
machine_type=machine_type,
44+
accelerator_count=accelerator_count,
45+
accelerator_type=accelerator_type,
46+
starting_replica_count=starting_replica_count,
47+
max_replica_count=max_replica_count,
48+
sync=sync,
49+
)
50+
51+
batch_prediction_job.wait()
52+
53+
print(batch_prediction_job.display_name)
54+
print(batch_prediction_job.resource_name)
55+
print(batch_prediction_job.state)
56+
return batch_prediction_job
57+
58+
59+
# [END aiplatform_sdk_create_batch_prediction_job_dedicated_resources_sample]
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import pytest
17+
18+
import create_batch_prediction_job_dedicated_resources_sample
19+
import test_constants as constants
20+
21+
22+
@pytest.mark.usefixtures("mock_model")
23+
def test_create_batch_prediction_job_sample(
24+
mock_sdk_init, mock_init_model, mock_batch_predict_model
25+
):
26+
27+
create_batch_prediction_job_dedicated_resources_sample.create_batch_prediction_job_dedicated_resources_sample(
28+
project=constants.PROJECT,
29+
location=constants.LOCATION,
30+
model_resource_name=constants.MODEL_NAME,
31+
job_display_name=constants.DISPLAY_NAME,
32+
gcs_source=constants.GCS_SOURCES,
33+
gcs_destination=constants.GCS_DESTINATION,
34+
machine_type=constants.ACCELERATOR_TYPE,
35+
accelerator_count=constants.ACCELERATOR_COUNT,
36+
accelerator_type=constants.ACCELERATOR_TYPE,
37+
starting_replica_count=constants.MIN_REPLICA_COUNT,
38+
max_replica_count=constants.MAX_REPLICA_COUNT,
39+
)
40+
41+
mock_sdk_init.assert_called_once_with(
42+
project=constants.PROJECT, location=constants.LOCATION
43+
)
44+
mock_init_model.assert_called_once_with(constants.MODEL_NAME)
45+
mock_batch_predict_model.assert_called_once_with(
46+
job_display_name=constants.DISPLAY_NAME,
47+
gcs_source=constants.GCS_SOURCES,
48+
gcs_destination_prefix=constants.GCS_DESTINATION,
49+
machine_type=constants.ACCELERATOR_TYPE,
50+
accelerator_count=constants.ACCELERATOR_COUNT,
51+
accelerator_type=constants.ACCELERATOR_TYPE,
52+
starting_replica_count=constants.MIN_REPLICA_COUNT,
53+
max_replica_count=constants.MAX_REPLICA_COUNT,
54+
sync=True,
55+
)

0 commit comments

Comments
 (0)