Skip to content

Commit d9675fd

Browse files
vertex-sdk-bot and copybara-github
authored and committed
fix: allow setting max_wait_duration to 0 for indefinite waiting with DWS
PiperOrigin-RevId: 907837984
1 parent b803f3f commit d9675fd

3 files changed

Lines changed: 248 additions & 5 deletions

File tree

google/cloud/aiplatform/jobs.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2604,12 +2604,12 @@ def submit(
26042604
or restart_job_on_worker_restart
26052605
or disable_retries
26062606
or scheduling_strategy
2607-
or max_wait_duration
2607+
or max_wait_duration is not None # 0 is a valid value
26082608
):
26092609
timeout = duration_pb2.Duration(seconds=timeout) if timeout else None
26102610
max_wait_duration = (
26112611
duration_pb2.Duration(seconds=max_wait_duration)
2612-
if max_wait_duration
2612+
if max_wait_duration is not None
26132613
else None
26142614
)
26152615
self._gca_resource.job_spec.scheduling = gca_custom_job_compat.Scheduling(
@@ -3133,13 +3133,13 @@ def _run(
31333133
timeout
31343134
or restart_job_on_worker_restart
31353135
or disable_retries
3136-
or max_wait_duration
3136+
or max_wait_duration is not None # 0 is a valid value
31373137
or scheduling_strategy
31383138
):
31393139
timeout = duration_pb2.Duration(seconds=timeout) if timeout else None
31403140
max_wait_duration = (
31413141
duration_pb2.Duration(seconds=max_wait_duration)
3142-
if max_wait_duration
3142+
if max_wait_duration is not None
31433143
else None
31443144
)
31453145
self._gca_resource.trial_job_spec.scheduling = (

tests/unit/aiplatform/test_custom_job.py

Lines changed: 96 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
#
1616

17+
import datetime
1718
import pytest
1819
import logging
1920

@@ -730,6 +731,101 @@ def test_submit_custom_job(self, create_custom_job_mock, get_custom_job_mock):
730731
)
731732
assert job.network == _TEST_NETWORK
732733

734+
def test_submit_custom_job_with_zero_max_wait_duration(
735+
self, create_custom_job_mock, get_custom_job_mock
736+
):
737+
738+
aiplatform.init(
739+
project=_TEST_PROJECT,
740+
location=_TEST_LOCATION,
741+
staging_bucket=_TEST_STAGING_BUCKET,
742+
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
743+
)
744+
745+
job = aiplatform.CustomJob(
746+
display_name=_TEST_DISPLAY_NAME,
747+
worker_pool_specs=_TEST_WORKER_POOL_SPEC,
748+
base_output_dir=_TEST_BASE_OUTPUT_DIR,
749+
labels=_TEST_LABELS,
750+
)
751+
752+
job.submit(
753+
service_account=_TEST_SERVICE_ACCOUNT,
754+
network=_TEST_NETWORK,
755+
timeout=_TEST_TIMEOUT,
756+
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
757+
create_request_timeout=None,
758+
disable_retries=_TEST_DISABLE_RETRIES,
759+
max_wait_duration=0,
760+
)
761+
762+
job.wait_for_resource_creation()
763+
764+
assert job.resource_name == _TEST_CUSTOM_JOB_NAME
765+
766+
job.wait()
767+
768+
expected_custom_job = _get_custom_job_proto()
769+
expected_custom_job.job_spec.scheduling.max_wait_duration = datetime.timedelta(
770+
seconds=0
771+
)
772+
773+
create_custom_job_mock.assert_called_once_with(
774+
parent=_TEST_PARENT,
775+
custom_job=expected_custom_job,
776+
timeout=None,
777+
)
778+
assert (
779+
job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_PENDING
780+
)
781+
782+
def test_submit_custom_job_with_default_max_wait_duration(
783+
self, create_custom_job_mock, get_custom_job_mock
784+
):
785+
786+
aiplatform.init(
787+
project=_TEST_PROJECT,
788+
location=_TEST_LOCATION,
789+
staging_bucket=_TEST_STAGING_BUCKET,
790+
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
791+
)
792+
793+
job = aiplatform.CustomJob(
794+
display_name=_TEST_DISPLAY_NAME,
795+
worker_pool_specs=_TEST_WORKER_POOL_SPEC,
796+
base_output_dir=_TEST_BASE_OUTPUT_DIR,
797+
labels=_TEST_LABELS,
798+
)
799+
800+
job.submit(
801+
service_account=_TEST_SERVICE_ACCOUNT,
802+
network=_TEST_NETWORK,
803+
timeout=_TEST_TIMEOUT,
804+
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
805+
create_request_timeout=None,
806+
disable_retries=_TEST_DISABLE_RETRIES,
807+
)
808+
809+
job.wait_for_resource_creation()
810+
811+
assert job.resource_name == _TEST_CUSTOM_JOB_NAME
812+
813+
job.wait()
814+
815+
expected_custom_job = _get_custom_job_proto()
816+
expected_custom_job.job_spec.scheduling.max_wait_duration = None
817+
818+
create_custom_job_mock.assert_called_once_with(
819+
parent=_TEST_PARENT,
820+
custom_job=expected_custom_job,
821+
timeout=None,
822+
)
823+
824+
assert "max_wait_duration" not in expected_custom_job.job_spec.scheduling
825+
assert (
826+
job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_PENDING
827+
)
828+
733829
@pytest.mark.usefixtures(
734830
"get_experiment_run_mock", "get_tensorboard_run_artifact_not_found_mock"
735831
)

tests/unit/aiplatform/test_hyperparameter_tuning_job.py

Lines changed: 148 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,9 +14,10 @@
1414
# limitations under the License.
1515
#
1616

17+
import copy
18+
import datetime
1719
import pytest
1820

19-
import copy
2021
from importlib import reload
2122
from unittest import mock
2223
from unittest.mock import patch
@@ -523,6 +524,152 @@ def test_create_hyperparameter_tuning_job(
523524
assert job.network == _TEST_NETWORK
524525
assert job.trials == []
525526

527+
def test_create_hyperparameter_tuning_job_with_zero_max_wait_duration(
528+
self,
529+
create_hyperparameter_tuning_job_mock,
530+
get_hyperparameter_tuning_job_mock,
531+
):
532+
533+
aiplatform.init(
534+
project=_TEST_PROJECT,
535+
location=_TEST_LOCATION,
536+
staging_bucket=_TEST_STAGING_BUCKET,
537+
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
538+
)
539+
540+
custom_job = aiplatform.CustomJob(
541+
display_name=test_constants.TrainingJobConstants._TEST_DISPLAY_NAME,
542+
worker_pool_specs=test_constants.TrainingJobConstants._TEST_WORKER_POOL_SPEC,
543+
base_output_dir=test_constants.TrainingJobConstants._TEST_BASE_OUTPUT_DIR,
544+
)
545+
546+
job = aiplatform.HyperparameterTuningJob(
547+
display_name=_TEST_DISPLAY_NAME,
548+
custom_job=custom_job,
549+
metric_spec={_TEST_METRIC_SPEC_KEY: _TEST_METRIC_SPEC_VALUE},
550+
parameter_spec={
551+
"lr": hpt.DoubleParameterSpec(min=0.001, max=0.1, scale="log"),
552+
"units": hpt.IntegerParameterSpec(min=4, max=1028, scale="linear"),
553+
"activation": hpt.CategoricalParameterSpec(
554+
values=["relu", "sigmoid", "elu", "selu", "tanh"]
555+
),
556+
"batch_size": hpt.DiscreteParameterSpec(
557+
values=[4, 8, 16, 32, 64],
558+
scale="linear",
559+
conditional_parameter_spec={
560+
"decay": _TEST_CONDITIONAL_PARAMETER_DECAY,
561+
"learning_rate": _TEST_CONDITIONAL_PARAMETER_LR,
562+
},
563+
),
564+
},
565+
parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT,
566+
max_trial_count=_TEST_MAX_TRIAL_COUNT,
567+
max_failed_trial_count=_TEST_MAX_FAILED_TRIAL_COUNT,
568+
search_algorithm=_TEST_SEARCH_ALGORITHM,
569+
measurement_selection=_TEST_MEASUREMENT_SELECTION,
570+
labels=_TEST_LABELS,
571+
)
572+
573+
job.run(
574+
service_account=_TEST_SERVICE_ACCOUNT,
575+
network=_TEST_NETWORK,
576+
timeout=_TEST_TIMEOUT,
577+
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
578+
sync=True,
579+
create_request_timeout=None,
580+
disable_retries=_TEST_DISABLE_RETRIES,
581+
max_wait_duration=0,
582+
)
583+
584+
job.wait()
585+
586+
expected_hyperparameter_tuning_job = _get_hyperparameter_tuning_job_proto()
587+
expected_hyperparameter_tuning_job.trial_job_spec.scheduling.max_wait_duration = datetime.timedelta(
588+
seconds=0
589+
)
590+
591+
create_hyperparameter_tuning_job_mock.assert_called_once_with(
592+
parent=_TEST_PARENT,
593+
hyperparameter_tuning_job=expected_hyperparameter_tuning_job,
594+
timeout=None,
595+
)
596+
assert job.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED
597+
598+
def test_create_hyperparameter_tuning_job_with_default_max_wait_duration(
599+
self,
600+
create_hyperparameter_tuning_job_mock,
601+
get_hyperparameter_tuning_job_mock,
602+
):
603+
604+
aiplatform.init(
605+
project=_TEST_PROJECT,
606+
location=_TEST_LOCATION,
607+
staging_bucket=_TEST_STAGING_BUCKET,
608+
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
609+
)
610+
611+
custom_job = aiplatform.CustomJob(
612+
display_name=test_constants.TrainingJobConstants._TEST_DISPLAY_NAME,
613+
worker_pool_specs=test_constants.TrainingJobConstants._TEST_WORKER_POOL_SPEC,
614+
base_output_dir=test_constants.TrainingJobConstants._TEST_BASE_OUTPUT_DIR,
615+
)
616+
617+
job = aiplatform.HyperparameterTuningJob(
618+
display_name=_TEST_DISPLAY_NAME,
619+
custom_job=custom_job,
620+
metric_spec={_TEST_METRIC_SPEC_KEY: _TEST_METRIC_SPEC_VALUE},
621+
parameter_spec={
622+
"lr": hpt.DoubleParameterSpec(min=0.001, max=0.1, scale="log"),
623+
"units": hpt.IntegerParameterSpec(min=4, max=1028, scale="linear"),
624+
"activation": hpt.CategoricalParameterSpec(
625+
values=["relu", "sigmoid", "elu", "selu", "tanh"]
626+
),
627+
"batch_size": hpt.DiscreteParameterSpec(
628+
values=[4, 8, 16, 32, 64],
629+
scale="linear",
630+
conditional_parameter_spec={
631+
"decay": _TEST_CONDITIONAL_PARAMETER_DECAY,
632+
"learning_rate": _TEST_CONDITIONAL_PARAMETER_LR,
633+
},
634+
),
635+
},
636+
parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT,
637+
max_trial_count=_TEST_MAX_TRIAL_COUNT,
638+
max_failed_trial_count=_TEST_MAX_FAILED_TRIAL_COUNT,
639+
search_algorithm=_TEST_SEARCH_ALGORITHM,
640+
measurement_selection=_TEST_MEASUREMENT_SELECTION,
641+
labels=_TEST_LABELS,
642+
)
643+
644+
job.run(
645+
service_account=_TEST_SERVICE_ACCOUNT,
646+
network=_TEST_NETWORK,
647+
timeout=_TEST_TIMEOUT,
648+
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
649+
sync=True,
650+
create_request_timeout=None,
651+
disable_retries=_TEST_DISABLE_RETRIES,
652+
)
653+
654+
job.wait()
655+
656+
expected_hyperparameter_tuning_job = _get_hyperparameter_tuning_job_proto()
657+
expected_hyperparameter_tuning_job.trial_job_spec.scheduling.max_wait_duration = (
658+
None
659+
)
660+
661+
create_hyperparameter_tuning_job_mock.assert_called_once_with(
662+
parent=_TEST_PARENT,
663+
hyperparameter_tuning_job=expected_hyperparameter_tuning_job,
664+
timeout=None,
665+
)
666+
667+
assert (
668+
"max_wait_duration"
669+
not in expected_hyperparameter_tuning_job.trial_job_spec.scheduling
670+
)
671+
assert job.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED
672+
526673
@pytest.mark.parametrize("sync", [True, False])
527674
def test_create_hyperparameter_tuning_job_with_timeout(
528675
self,

0 commit comments

Comments
 (0)