Skip to content

Commit a6a4877

Browse files
committed
Fix lint and nox
1 parent 0af0c68 commit a6a4877

9 files changed

Lines changed: 113 additions & 162 deletions

File tree

google/cloud/aiplatform/hyperparameter_tuning.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,7 @@ class DoubleParameterSpec(_ParameterSpec):
101101
_parameter_spec_value_key = "double_value_spec"
102102

103103
def __init__(
104-
self,
105-
min: float,
106-
max: float,
107-
scale: str,
104+
self, min: float, max: float, scale: str,
108105
):
109106
"""
110107
Value specification for a parameter in ``DOUBLE`` type.
@@ -138,10 +135,7 @@ class IntegerParameterSpec(_ParameterSpec):
138135
_parameter_spec_value_key = "integer_value_spec"
139136

140137
def __init__(
141-
self,
142-
min: int,
143-
max: int,
144-
scale: str,
138+
self, min: int, max: int, scale: str,
145139
):
146140
"""
147141
Value specification for a parameter in ``INTEGER`` type.
@@ -175,8 +169,7 @@ class CategoricalParameterSpec(_ParameterSpec):
175169
_parameter_spec_value_key = "categorical_value_spec"
176170

177171
def __init__(
178-
self,
179-
values: Sequence[str],
172+
self, values: Sequence[str],
180173
):
181174
"""Value specification for a parameter in ``CATEGORICAL`` type.
182175
@@ -199,9 +192,7 @@ class DiscreteParameterSpec(_ParameterSpec):
199192
_parameter_spec_value_key = "discrete_value_spec"
200193

201194
def __init__(
202-
self,
203-
values: Sequence[float],
204-
scale: str,
195+
self, values: Sequence[float], scale: str,
205196
):
206197
"""Value specification for a parameter in ``DISCRETE`` type.
207198

google/cloud/aiplatform/jobs.py

Lines changed: 25 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
import abc
2121
import copy
22-
import sys
2322
import time
2423

2524
from google.cloud import storage
@@ -568,11 +567,9 @@ def create(
568567
gapic_batch_prediction_job.output_config = output_config
569568

570569
# Optional Fields
571-
gapic_batch_prediction_job.encryption_spec = (
572-
initializer.global_config.get_encryption_spec(
573-
encryption_spec_key_name=encryption_spec_key_name,
574-
select_version=select_version,
575-
)
570+
gapic_batch_prediction_job.encryption_spec = initializer.global_config.get_encryption_spec(
571+
encryption_spec_key_name=encryption_spec_key_name,
572+
select_version=select_version,
576573
)
577574

578575
if model_parameters:
@@ -604,10 +601,8 @@ def create(
604601
gapic_batch_prediction_job.generate_explanation = generate_explanation
605602

606603
if explanation_metadata or explanation_parameters:
607-
gapic_batch_prediction_job.explanation_spec = (
608-
gca_explanation_v1beta1.ExplanationSpec(
609-
metadata=explanation_metadata, parameters=explanation_parameters
610-
)
604+
gapic_batch_prediction_job.explanation_spec = gca_explanation_v1beta1.ExplanationSpec(
605+
metadata=explanation_metadata, parameters=explanation_parameters
611606
)
612607

613608
# TODO (b/174502913): Support private feature once released
@@ -1088,23 +1083,19 @@ def from_local_script(
10881083
"should be set using aiplatform.init(staging_bucket='gs://my-bucket')"
10891084
)
10901085

1091-
worker_pool_specs = (
1092-
worker_spec_utils._DistributedTrainingSpec.chief_worker_pool(
1093-
replica_count=replica_count,
1094-
machine_type=machine_type,
1095-
accelerator_count=accelerator_count,
1096-
accelerator_type=accelerator_type,
1097-
).pool_specs
1098-
)
1086+
worker_pool_specs = worker_spec_utils._DistributedTrainingSpec.chief_worker_pool(
1087+
replica_count=replica_count,
1088+
machine_type=machine_type,
1089+
accelerator_count=accelerator_count,
1090+
accelerator_type=accelerator_type,
1091+
).pool_specs
10991092

11001093
python_packager = source_utils._TrainingScriptPythonPackager(
11011094
script_path=script_path, requirements=requirements
11021095
)
11031096

11041097
package_gcs_uri = python_packager.package_and_copy_to_gcs(
1105-
gcs_staging_dir=staging_bucket,
1106-
project=project,
1107-
credentials=credentials,
1098+
gcs_staging_dir=staging_bucket, project=project, credentials=credentials,
11081099
)
11091100

11101101
for spec in worker_pool_specs:
@@ -1429,18 +1420,16 @@ def __init__(
14291420
],
14301421
)
14311422

1432-
self._gca_resource = (
1433-
gca_hyperparameter_tuning_job_compat.HyperparameterTuningJob(
1434-
display_name=display_name,
1435-
study_spec=study_spec,
1436-
max_trial_count=max_trial_count,
1437-
parallel_trial_count=parallel_trial_count,
1438-
max_failed_trial_count=max_failed_trial_count,
1439-
trial_job_spec=copy.deepcopy(custom_job.job_spec),
1440-
encryption_spec=initializer.global_config.get_encryption_spec(
1441-
encryption_spec_key_name=encryption_spec_key_name
1442-
),
1443-
)
1423+
self._gca_resource = gca_hyperparameter_tuning_job_compat.HyperparameterTuningJob(
1424+
display_name=display_name,
1425+
study_spec=study_spec,
1426+
max_trial_count=max_trial_count,
1427+
parallel_trial_count=parallel_trial_count,
1428+
max_failed_trial_count=max_failed_trial_count,
1429+
trial_job_spec=copy.deepcopy(custom_job.job_spec),
1430+
encryption_spec=initializer.global_config.get_encryption_spec(
1431+
encryption_spec_key_name=encryption_spec_key_name
1432+
),
14441433
)
14451434

14461435
@base.optional_sync()
@@ -1499,11 +1488,9 @@ def run(
14991488

15001489
if timeout or restart_job_on_worker_restart:
15011490
duration = duration_pb2.Duration(seconds=timeout) if timeout else None
1502-
self._gca_resource.trial_job_spec.scheduling = (
1503-
gca_custom_job_compat.Scheduling(
1504-
timeout=duration,
1505-
restart_job_on_worker_restart=restart_job_on_worker_restart,
1506-
)
1491+
self._gca_resource.trial_job_spec.scheduling = gca_custom_job_compat.Scheduling(
1492+
timeout=duration,
1493+
restart_job_on_worker_restart=restart_job_on_worker_restart,
15071494
)
15081495

15091496
if tensorboard:

google/cloud/aiplatform/models.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,7 @@ def _create(
297297

298298
@staticmethod
299299
def _allocate_traffic(
300-
traffic_split: Dict[str, int],
301-
traffic_percentage: int,
300+
traffic_split: Dict[str, int], traffic_percentage: int,
302301
) -> Dict[str, int]:
303302
"""Allocates desired traffic to new deployed model and scales traffic
304303
of older deployed models.
@@ -334,8 +333,7 @@ def _allocate_traffic(
334333

335334
@staticmethod
336335
def _unallocate_traffic(
337-
traffic_split: Dict[str, int],
338-
deployed_model_id: str,
336+
traffic_split: Dict[str, int], deployed_model_id: str,
339337
) -> Dict[str, int]:
340338
"""Sets deployed model id's traffic to 0 and scales the traffic of
341339
other deployed models.
@@ -829,20 +827,16 @@ def _deploy_call(
829827
machine_spec.accelerator_type = accelerator_type
830828
machine_spec.accelerator_count = accelerator_count
831829

832-
deployed_model.dedicated_resources = (
833-
gca_machine_resources.DedicatedResources(
834-
machine_spec=machine_spec,
835-
min_replica_count=min_replica_count,
836-
max_replica_count=max_replica_count,
837-
)
830+
deployed_model.dedicated_resources = gca_machine_resources.DedicatedResources(
831+
machine_spec=machine_spec,
832+
min_replica_count=min_replica_count,
833+
max_replica_count=max_replica_count,
838834
)
839835

840836
else:
841-
deployed_model.automatic_resources = (
842-
gca_machine_resources.AutomaticResources(
843-
min_replica_count=min_replica_count,
844-
max_replica_count=max_replica_count,
845-
)
837+
deployed_model.automatic_resources = gca_machine_resources.AutomaticResources(
838+
min_replica_count=min_replica_count,
839+
max_replica_count=max_replica_count,
846840
)
847841

848842
# Service will throw error if both metadata and parameters are not provided
@@ -2178,8 +2172,8 @@ def export_model(
21782172
)
21792173

21802174
if image_destination:
2181-
output_config.image_destination = (
2182-
gca_io_compat.ContainerRegistryDestination(output_uri=image_destination)
2175+
output_config.image_destination = gca_io_compat.ContainerRegistryDestination(
2176+
output_uri=image_destination
21832177
)
21842178

21852179
_LOGGER.log_action_start_against_resource("Exporting", "model", self)

google/cloud/aiplatform/training_jobs.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -616,9 +616,7 @@ def _get_model(self) -> Optional[models.Model]:
616616
)
617617

618618
return models.Model(
619-
fields.id,
620-
project=fields.project,
621-
location=fields.location,
619+
fields.id, project=fields.project, location=fields.location,
622620
)
623621

624622
def _block_until_complete(self):
@@ -1056,14 +1054,12 @@ def _prepare_and_validate_run(
10561054
model_display_name = model_display_name or self._display_name + "-model"
10571055

10581056
# validates args and will raise
1059-
worker_pool_specs = (
1060-
worker_spec_utils._DistributedTrainingSpec.chief_worker_pool(
1061-
replica_count=replica_count,
1062-
machine_type=machine_type,
1063-
accelerator_count=accelerator_count,
1064-
accelerator_type=accelerator_type,
1065-
).pool_specs
1066-
)
1057+
worker_pool_specs = worker_spec_utils._DistributedTrainingSpec.chief_worker_pool(
1058+
replica_count=replica_count,
1059+
machine_type=machine_type,
1060+
accelerator_count=accelerator_count,
1061+
accelerator_type=accelerator_type,
1062+
).pool_specs
10671063

10681064
managed_model = self._managed_model
10691065
if model_display_name:
@@ -4442,10 +4438,8 @@ def __init__(
44424438
schema.training_job.definition.automl_text_classification
44434439
)
44444440

4445-
training_task_inputs_dict = (
4446-
training_job_inputs.AutoMlTextClassificationInputs(
4447-
multi_label=multi_label
4448-
)
4441+
training_task_inputs_dict = training_job_inputs.AutoMlTextClassificationInputs(
4442+
multi_label=multi_label
44494443
)
44504444
elif prediction_type == "extraction":
44514445
training_task_definition = (

google/cloud/aiplatform/utils/json_utils.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@
2222
from google.cloud import storage
2323

2424

25-
def load_json(path: str,
26-
project: Optional[str] = None,
27-
credentials: Optional[auth_credentials.Credentials] = None
28-
) -> Dict[str, Any]:
25+
def load_json(
26+
path: str,
27+
project: Optional[str] = None,
28+
credentials: Optional[auth_credentials.Credentials] = None,
29+
) -> Dict[str, Any]:
2930
"""Loads data from a JSON document.
3031
3132
Args:
@@ -40,16 +41,17 @@ def load_json(path: str,
4041
Returns:
4142
A Dict object representing the JSON document.
4243
"""
43-
if path.startswith('gs://'):
44+
if path.startswith("gs://"):
4445
return _load_json_from_gs_uri(path, project, credentials)
4546
else:
4647
return _load_json_from_local_file(path)
4748

4849

49-
def _load_json_from_gs_uri(uri: str,
50-
project: Optional[str] = None,
51-
credentials: Optional[auth_credentials.Credentials]
52-
= None) -> Dict[str, Any]:
50+
def _load_json_from_gs_uri(
51+
uri: str,
52+
project: Optional[str] = None,
53+
credentials: Optional[auth_credentials.Credentials] = None,
54+
) -> Dict[str, Any]:
5355
"""Loads data from a JSON document referenced by a GCS URI.
5456
5557
Args:

0 commit comments

Comments
 (0)