|
19 | 19 |
|
20 | 20 | import abc |
21 | 21 | import copy |
22 | | -import sys |
23 | 22 | import time |
24 | 23 |
|
25 | 24 | from google.cloud import storage |
@@ -568,11 +567,9 @@ def create( |
568 | 567 | gapic_batch_prediction_job.output_config = output_config |
569 | 568 |
|
570 | 569 | # Optional Fields |
571 | | - gapic_batch_prediction_job.encryption_spec = ( |
572 | | - initializer.global_config.get_encryption_spec( |
573 | | - encryption_spec_key_name=encryption_spec_key_name, |
574 | | - select_version=select_version, |
575 | | - ) |
| 570 | + gapic_batch_prediction_job.encryption_spec = initializer.global_config.get_encryption_spec( |
| 571 | + encryption_spec_key_name=encryption_spec_key_name, |
| 572 | + select_version=select_version, |
576 | 573 | ) |
577 | 574 |
|
578 | 575 | if model_parameters: |
@@ -604,10 +601,8 @@ def create( |
604 | 601 | gapic_batch_prediction_job.generate_explanation = generate_explanation |
605 | 602 |
|
606 | 603 | if explanation_metadata or explanation_parameters: |
607 | | - gapic_batch_prediction_job.explanation_spec = ( |
608 | | - gca_explanation_v1beta1.ExplanationSpec( |
609 | | - metadata=explanation_metadata, parameters=explanation_parameters |
610 | | - ) |
| 604 | + gapic_batch_prediction_job.explanation_spec = gca_explanation_v1beta1.ExplanationSpec( |
| 605 | + metadata=explanation_metadata, parameters=explanation_parameters |
611 | 606 | ) |
612 | 607 |
|
613 | 608 | # TODO (b/174502913): Support private feature once released |
@@ -1088,23 +1083,19 @@ def from_local_script( |
1088 | 1083 | "should be set using aiplatform.init(staging_bucket='gs://my-bucket')" |
1089 | 1084 | ) |
1090 | 1085 |
|
1091 | | - worker_pool_specs = ( |
1092 | | - worker_spec_utils._DistributedTrainingSpec.chief_worker_pool( |
1093 | | - replica_count=replica_count, |
1094 | | - machine_type=machine_type, |
1095 | | - accelerator_count=accelerator_count, |
1096 | | - accelerator_type=accelerator_type, |
1097 | | - ).pool_specs |
1098 | | - ) |
| 1086 | + worker_pool_specs = worker_spec_utils._DistributedTrainingSpec.chief_worker_pool( |
| 1087 | + replica_count=replica_count, |
| 1088 | + machine_type=machine_type, |
| 1089 | + accelerator_count=accelerator_count, |
| 1090 | + accelerator_type=accelerator_type, |
| 1091 | + ).pool_specs |
1099 | 1092 |
|
1100 | 1093 | python_packager = source_utils._TrainingScriptPythonPackager( |
1101 | 1094 | script_path=script_path, requirements=requirements |
1102 | 1095 | ) |
1103 | 1096 |
|
1104 | 1097 | package_gcs_uri = python_packager.package_and_copy_to_gcs( |
1105 | | - gcs_staging_dir=staging_bucket, |
1106 | | - project=project, |
1107 | | - credentials=credentials, |
| 1098 | + gcs_staging_dir=staging_bucket, project=project, credentials=credentials, |
1108 | 1099 | ) |
1109 | 1100 |
|
1110 | 1101 | for spec in worker_pool_specs: |
@@ -1429,18 +1420,16 @@ def __init__( |
1429 | 1420 | ], |
1430 | 1421 | ) |
1431 | 1422 |
|
1432 | | - self._gca_resource = ( |
1433 | | - gca_hyperparameter_tuning_job_compat.HyperparameterTuningJob( |
1434 | | - display_name=display_name, |
1435 | | - study_spec=study_spec, |
1436 | | - max_trial_count=max_trial_count, |
1437 | | - parallel_trial_count=parallel_trial_count, |
1438 | | - max_failed_trial_count=max_failed_trial_count, |
1439 | | - trial_job_spec=copy.deepcopy(custom_job.job_spec), |
1440 | | - encryption_spec=initializer.global_config.get_encryption_spec( |
1441 | | - encryption_spec_key_name=encryption_spec_key_name |
1442 | | - ), |
1443 | | - ) |
| 1423 | + self._gca_resource = gca_hyperparameter_tuning_job_compat.HyperparameterTuningJob( |
| 1424 | + display_name=display_name, |
| 1425 | + study_spec=study_spec, |
| 1426 | + max_trial_count=max_trial_count, |
| 1427 | + parallel_trial_count=parallel_trial_count, |
| 1428 | + max_failed_trial_count=max_failed_trial_count, |
| 1429 | + trial_job_spec=copy.deepcopy(custom_job.job_spec), |
| 1430 | + encryption_spec=initializer.global_config.get_encryption_spec( |
| 1431 | + encryption_spec_key_name=encryption_spec_key_name |
| 1432 | + ), |
1444 | 1433 | ) |
1445 | 1434 |
|
1446 | 1435 | @base.optional_sync() |
@@ -1499,11 +1488,9 @@ def run( |
1499 | 1488 |
|
1500 | 1489 | if timeout or restart_job_on_worker_restart: |
1501 | 1490 | duration = duration_pb2.Duration(seconds=timeout) if timeout else None |
1502 | | - self._gca_resource.trial_job_spec.scheduling = ( |
1503 | | - gca_custom_job_compat.Scheduling( |
1504 | | - timeout=duration, |
1505 | | - restart_job_on_worker_restart=restart_job_on_worker_restart, |
1506 | | - ) |
| 1491 | + self._gca_resource.trial_job_spec.scheduling = gca_custom_job_compat.Scheduling( |
| 1492 | + timeout=duration, |
| 1493 | + restart_job_on_worker_restart=restart_job_on_worker_restart, |
1507 | 1494 | ) |
1508 | 1495 |
|
1509 | 1496 | if tensorboard: |
|
0 commit comments