Skip to content

Commit 0537b8c

Browse files
committed
Fix pipeline runtime_config layer
1 parent 0bb3549 commit 0537b8c

3 files changed

Lines changed: 30 additions & 36 deletions

File tree

google/cloud/aiplatform/pipeline_jobs.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from google.cloud.aiplatform import utils
2727
from google.cloud.aiplatform.utils import json_utils
2828
from google.cloud.aiplatform.utils import pipeline_utils
29+
from google.protobuf import json_format
2930

3031
from google.cloud.aiplatform.compat.types import (
3132
pipeline_job_v1beta1 as gca_pipeline_job_v1beta1,
@@ -159,11 +160,16 @@ def __init__(
159160
self._parent = initializer.global_config.common_location_path(
160161
project=project, location=location
161162
)
162-
pipeline_root = pipeline_root or initializer.global_config.staging_bucket
163-
pipeline_spec = json_utils.load_json(
163+
pipeline_job = json_utils.load_json(
164164
template_path, self.project, self.credentials
165165
)
166-
pipeline_name = pipeline_spec["pipelineSpec"]["pipelineInfo"]["name"]
166+
pipeline_root = (
167+
pipeline_root
168+
or pipeline_job["runtimeConfig"].get("gcsOutputDirectory")
169+
or initializer.global_config.staging_bucket
170+
)
171+
172+
pipeline_name = pipeline_job["pipelineSpec"]["pipelineInfo"]["name"]
167173
job_id = job_id or "{pipeline_name}-{timestamp}".format(
168174
pipeline_name=re.sub("[^-0-9a-z]+", "-", pipeline_name.lower())
169175
.lstrip("-")
@@ -177,32 +183,22 @@ def __init__(
177183
'"[a-z][-a-z0-9]{{0,127}}"'.format(job_id)
178184
)
179185

180-
job_name = _JOB_NAME_PATTERN.format(parent=self._parent, job_id=job_id)
181-
182-
pipeline_spec["name"] = job_name
183-
pipeline_spec["displayName"] = job_id
184-
185186
builder = pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
186-
pipeline_spec
187+
pipeline_job
187188
)
188189
builder.update_pipeline_root(pipeline_root)
189190
builder.update_runtime_parameters(parameter_values)
191+
runtime_config_dict = builder.build()
192+
runtime_config = gca_pipeline_job_v1beta1.PipelineJob.RuntimeConfig()._pb
193+
json_format.ParseDict(runtime_config_dict, runtime_config)
190194

191-
runtime_config = builder.build()
192-
pipeline_spec["runtimeConfig"] = runtime_config
193-
194-
_set_enable_caching_value(pipeline_spec["pipelineSpec"], enable_caching)
195-
196-
if encryption_spec_key_name is not None:
197-
pipeline_spec["encryptionSpec"] = {"kmsKeyName": encryption_spec_key_name}
198-
199-
if labels:
200-
pipeline_spec["labels"] = labels
195+
_set_enable_caching_value(pipeline_job["pipelineSpec"], enable_caching)
201196

202197
self._gca_resource = gca_pipeline_job_v1beta1.PipelineJob(
203198
display_name=display_name,
204-
pipeline_spec=pipeline_spec,
199+
pipeline_spec=pipeline_job["pipelineSpec"],
205200
labels=labels,
201+
runtime_config=runtime_config,
206202
encryption_spec=initializer.global_config.get_encryption_spec(
207203
encryption_spec_key_name=encryption_spec_key_name
208204
),

google/cloud/aiplatform/utils/pipeline_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def build(self) -> Dict[str, Any]:
103103
"compile time, or when calling the service."
104104
)
105105
return {
106-
"gcsOutputDirectory": self._pipeline_root,
106+
"gcs_output_directory": self._pipeline_root,
107107
"parameters": {
108108
k: self._get_vertex_value(k, v)
109109
for k, v in self._parameter_values.items()

tests/unit/aiplatform/test_pipeline_jobs.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@
2929

3030
from google.cloud.aiplatform import pipeline_jobs
3131
from google.cloud.aiplatform import initializer
32+
from google.protobuf import json_format
3233

3334
from google.cloud.aiplatform_v1beta1.services.pipeline_service import (
3435
client as pipeline_service_client_v1beta1,
3536
)
36-
3737
from google.cloud.aiplatform_v1beta1.types import (
3838
pipeline_job as gca_pipeline_job_v1beta1,
3939
pipeline_state as gca_pipeline_state_v1beta1,
@@ -185,31 +185,29 @@ def test_run_call_pipeline_service_create(
185185
if not sync:
186186
job.wait()
187187

188+
expected_runtime_config_dict = {
189+
"gcs_output_directory": _TEST_GCS_BUCKET_NAME,
190+
"parameters": {"name_param": {"stringValue": "hello"}},
191+
}
192+
runtime_config = gca_pipeline_job_v1beta1.PipelineJob.RuntimeConfig()._pb
193+
json_format.ParseDict(expected_runtime_config_dict, runtime_config)
194+
188195
# Construct expected request
189196
expected_gapic_pipeline_job = gca_pipeline_job_v1beta1.PipelineJob(
190197
display_name=_TEST_PIPELINE_JOB_ID,
191198
pipeline_spec={
192-
"displayName": _TEST_PIPELINE_JOB_ID,
193-
"name": _TEST_PIPELINE_JOB_NAME,
194-
"pipelineSpec": {
195-
"components": {},
196-
"pipelineInfo": _TEST_PIPELINE_JOB_SPEC["pipelineSpec"][
197-
"pipelineInfo"
198-
],
199-
"root": _TEST_PIPELINE_JOB_SPEC["pipelineSpec"]["root"],
200-
},
201-
"runtimeConfig": {
202-
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
203-
"parameters": {"name_param": {"stringValue": "hello"}},
204-
},
199+
"components": {},
200+
"pipelineInfo": _TEST_PIPELINE_JOB_SPEC["pipelineSpec"]["pipelineInfo"],
201+
"root": _TEST_PIPELINE_JOB_SPEC["pipelineSpec"]["root"],
205202
},
203+
runtime_config=runtime_config,
206204
)
207205

208206
mock_pipeline_service_create.assert_called_once_with(
209207
parent=_TEST_PARENT, pipeline_job=expected_gapic_pipeline_job,
210208
)
211209

212-
mock_pipeline_service_get.assert_called_with(name=_TEST_PIPELINE_JOB_NAME,)
210+
mock_pipeline_service_get.assert_called_with(name=_TEST_PIPELINE_JOB_NAME)
213211

214212
assert job._gca_resource == make_pipeline_job(
215213
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED

0 commit comments

Comments
 (0)