2626from google .cloud .aiplatform import utils
2727from google .cloud .aiplatform .utils import json_utils
2828from google .cloud .aiplatform .utils import pipeline_utils
29+ from google .protobuf import json_format
2930
3031from google .cloud .aiplatform .compat .types import (
3132 pipeline_job_v1beta1 as gca_pipeline_job_v1beta1 ,
@@ -159,11 +160,16 @@ def __init__(
159160 self ._parent = initializer .global_config .common_location_path (
160161 project = project , location = location
161162 )
162- pipeline_root = pipeline_root or initializer .global_config .staging_bucket
163- pipeline_spec = json_utils .load_json (
163+ pipeline_job = json_utils .load_json (
164164 template_path , self .project , self .credentials
165165 )
166- pipeline_name = pipeline_spec ["pipelineSpec" ]["pipelineInfo" ]["name" ]
166+ pipeline_root = (
167+ pipeline_root
168+ or pipeline_job ["runtimeConfig" ].get ("gcsOutputDirectory" )
169+ or initializer .global_config .staging_bucket
170+ )
171+
172+ pipeline_name = pipeline_job ["pipelineSpec" ]["pipelineInfo" ]["name" ]
167173 job_id = job_id or "{pipeline_name}-{timestamp}" .format (
168174 pipeline_name = re .sub ("[^-0-9a-z]+" , "-" , pipeline_name .lower ())
169175 .lstrip ("-" )
@@ -177,32 +183,22 @@ def __init__(
177183 '"[a-z][-a-z0-9]{{0,127}}"' .format (job_id )
178184 )
179185
180- job_name = _JOB_NAME_PATTERN .format (parent = self ._parent , job_id = job_id )
181-
182- pipeline_spec ["name" ] = job_name
183- pipeline_spec ["displayName" ] = job_id
184-
185186 builder = pipeline_utils .PipelineRuntimeConfigBuilder .from_job_spec_json (
186- pipeline_spec
187+ pipeline_job
187188 )
188189 builder .update_pipeline_root (pipeline_root )
189190 builder .update_runtime_parameters (parameter_values )
191+ runtime_config_dict = builder .build ()
192+ runtime_config = gca_pipeline_job_v1beta1 .PipelineJob .RuntimeConfig ()._pb
193+ json_format .ParseDict (runtime_config_dict , runtime_config )
190194
191- runtime_config = builder .build ()
192- pipeline_spec ["runtimeConfig" ] = runtime_config
193-
194- _set_enable_caching_value (pipeline_spec ["pipelineSpec" ], enable_caching )
195-
196- if encryption_spec_key_name is not None :
197- pipeline_spec ["encryptionSpec" ] = {"kmsKeyName" : encryption_spec_key_name }
198-
199- if labels :
200- pipeline_spec ["labels" ] = labels
195+ _set_enable_caching_value (pipeline_job ["pipelineSpec" ], enable_caching )
201196
202197 self ._gca_resource = gca_pipeline_job_v1beta1 .PipelineJob (
203198 display_name = display_name ,
204- pipeline_spec = pipeline_spec ,
199+ pipeline_spec = pipeline_job [ "pipelineSpec" ] ,
205200 labels = labels ,
201+ runtime_config = runtime_config ,
206202 encryption_spec = initializer .global_config .get_encryption_spec (
207203 encryption_spec_key_name = encryption_spec_key_name
208204 ),
0 commit comments