Skip to content

Commit 22ec383

Browse files
committed
fix: job state name print
1 parent 6332d33 commit 22ec383

8 files changed

Lines changed: 169 additions & 5 deletions

File tree

google/cloud/aiplatform/jobs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def _log_job_state(self):
222222
% (
223223
self.__class__.__name__,
224224
self._gca_resource.name,
225-
self._gca_resource.state,
225+
self._gca_resource.state.name,
226226
)
227227
)
228228

@@ -1490,7 +1490,7 @@ def iter_outputs(
14901490
if self.state != gca_job_state.JobState.JOB_STATE_SUCCEEDED:
14911491
raise RuntimeError(
14921492
f"Cannot read outputs until BatchPredictionJob has succeeded, "
1493-
f"current state: {self._gca_resource.state}"
1493+
f"current state: {self._gca_resource.state.name}"
14941494
)
14951495

14961496
output_info = self._gca_resource.output_info

google/cloud/aiplatform/pipeline_jobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ def _block_until_complete(self):
786786
% (
787787
self.__class__.__name__,
788788
self._gca_resource.name,
789-
self._gca_resource.state,
789+
self._gca_resource.state.name,
790790
)
791791
)
792792
log_wait = min(log_wait * multiplier, max_wait)

google/cloud/aiplatform/schedules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def _block_until_complete(self) -> None:
223223
% (
224224
self.__class__.__name__,
225225
self._gca_resource.name,
226-
self._gca_resource.state,
226+
self._gca_resource.state.name,
227227
)
228228
)
229229
log_wait = min(log_wait * multiplier, max_wait)

google/cloud/aiplatform/training_jobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,7 @@ def _block_until_complete(self):
974974
% (
975975
self.__class__.__name__,
976976
self._gca_resource.name,
977-
self._gca_resource.state,
977+
self._gca_resource.state.name,
978978
)
979979
)
980980
log_wait = min(log_wait * _WAIT_TIME_MULTIPLIER, _MAX_WAIT_TIME)

tests/unit/aiplatform/test_jobs.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,21 @@ def test_cancel_mock_job(self, fake_job_cancel_mock):
380380

381381
fake_job_cancel_mock.assert_called_once_with(name=_TEST_JOB_RESOURCE_NAME)
382382

383+
@pytest.mark.usefixtures("fake_job_getter_mock")
384+
def test_log_job_state_uses_symbolic_name(self):
385+
"""_log_job_state must log the enum name, not the integer value (regression for Python 3.11+)."""
386+
fake_job = self.FakeJob(job_name=_TEST_JOB_RESOURCE_NAME)
387+
fake_job._gca_resource = mock.Mock()
388+
fake_job._gca_resource.name = _TEST_JOB_RESOURCE_NAME
389+
fake_job._gca_resource.state = gca_job_state_compat.JobState.JOB_STATE_RUNNING
390+
391+
with mock.patch.object(jobs._LOGGER, "info") as mock_info:
392+
fake_job._log_job_state()
393+
394+
logged_msg = mock_info.call_args[0][0]
395+
assert "JOB_STATE_RUNNING" in logged_msg
396+
assert "current state:\n3" not in logged_msg
397+
383398

384399
@pytest.fixture
385400
def get_batch_prediction_job_mock():
@@ -695,6 +710,21 @@ def test_batch_prediction_iter_dirs_while_running(self):
695710
)
696711
bp.iter_outputs()
697712

713+
@pytest.mark.usefixtures("get_batch_prediction_job_running_bq_output_mock")
714+
def test_batch_prediction_iter_dirs_while_running_error_uses_symbolic_state_name(
715+
self,
716+
):
717+
"""RuntimeError message must use symbolic state name, not integer (regression for Python 3.11+)."""
718+
with pytest.raises(RuntimeError) as exc_info:
719+
bp = jobs.BatchPredictionJob(
720+
batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME
721+
)
722+
bp.iter_outputs()
723+
724+
error_msg = str(exc_info.value)
725+
assert "JOB_STATE_RUNNING" in error_msg
726+
assert "current state: 3" not in error_msg
727+
698728
@pytest.mark.usefixtures("get_batch_prediction_job_empty_output_mock")
699729
def test_batch_prediction_iter_dirs_invalid_output_info(self):
700730
"""

tests/unit/aiplatform/test_pipeline_job_schedules.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
)
4444
from google.cloud.aiplatform import (
4545
pipeline_job_schedules,
46+
schedules as aiplatform_schedules,
4647
)
4748
from google.cloud.aiplatform.preview.pipelinejob import (
4849
pipeline_jobs as preview_pipeline_jobs,
@@ -434,6 +435,47 @@ def setup_method(self):
434435
def teardown_method(self):
435436
initializer.global_pool.shutdown(wait=True)
436437

438+
def test_block_until_complete_logs_symbolic_state_name(self):
439+
"""State log must use symbolic enum name, not a bare integer (regression for Python 3.11+)."""
440+
state_sequence = [
441+
gca_schedule.Schedule.State.ACTIVE, # first loop check
442+
gca_schedule.Schedule.State.COMPLETED, # second check exits loop
443+
]
444+
state_index = [0]
445+
446+
def get_state():
447+
s = state_sequence[state_index[0]]
448+
state_index[0] = min(state_index[0] + 1, len(state_sequence) - 1)
449+
return s
450+
451+
mock_schedule = mock.Mock()
452+
type(mock_schedule).state = mock.PropertyMock(side_effect=get_state)
453+
454+
active_gca = gca_schedule.Schedule(
455+
name=_TEST_PIPELINE_JOB_SCHEDULE_NAME,
456+
state=gca_schedule.Schedule.State.ACTIVE,
457+
)
458+
mock_schedule._gca_resource = active_gca
459+
460+
logged_messages = []
461+
462+
# time.time: first call sets previous_time=0; second gives 10 → triggers log (10 >= 5)
463+
time_vals = iter([0.0, 10.0, 20.0])
464+
with mock.patch("google.cloud.aiplatform.schedules.time.time", side_effect=time_vals), \
465+
mock.patch("google.cloud.aiplatform.schedules.time.sleep"), \
466+
mock.patch.object(
467+
aiplatform_schedules._LOGGER, "info",
468+
side_effect=lambda msg, *a, **kw: logged_messages.append(msg)
469+
):
470+
aiplatform_schedules._Schedule._block_until_complete(mock_schedule)
471+
472+
state_log = next(
473+
(m for m in logged_messages if "current state" in m), None
474+
)
475+
assert state_log is not None, "No 'current state' log message found"
476+
assert "ACTIVE" in state_log
477+
assert "current state:\n1" not in state_log
478+
437479
@pytest.mark.parametrize(
438480
"job_spec",
439481
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],

tests/unit/aiplatform/test_pipeline_jobs.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,45 @@ def setup_method(self):
708708
def teardown_method(self):
709709
initializer.global_pool.shutdown(wait=True)
710710

711+
@mock.patch.object(pipeline_jobs, "_JOB_WAIT_TIME", 0)
712+
@mock.patch.object(pipeline_jobs, "_LOG_WAIT_TIME", 0)
713+
def test_block_until_complete_logs_symbolic_state_name(
714+
self,
715+
mock_pipeline_service_create,
716+
mock_pipeline_service_get,
717+
mock_pipeline_bucket_exists,
718+
):
719+
"""State log must use symbolic enum name, not a bare integer (regression for Python 3.11+)."""
720+
aiplatform.init(
721+
project=_TEST_PROJECT,
722+
staging_bucket=_TEST_GCS_BUCKET_NAME,
723+
location=_TEST_LOCATION,
724+
credentials=_TEST_CREDENTIALS,
725+
)
726+
727+
logged_messages = []
728+
729+
with patch.object(storage.Blob, "download_as_bytes") as mock_load, \
730+
mock.patch.object(
731+
pipeline_jobs._LOGGER, "info",
732+
side_effect=lambda msg, *a, **kw: logged_messages.append(msg)
733+
):
734+
mock_load.return_value = _TEST_PIPELINE_SPEC_JSON.encode()
735+
736+
job = pipeline_jobs.PipelineJob(
737+
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
738+
template_path=_TEST_TEMPLATE_PATH,
739+
job_id=_TEST_PIPELINE_JOB_ID,
740+
)
741+
job.run(sync=True, create_request_timeout=None)
742+
743+
state_log = next(
744+
(m for m in logged_messages if "current state" in m), None
745+
)
746+
assert state_log is not None, "No 'current state' log message found"
747+
assert "PIPELINE_STATE_RUNNING" in state_log
748+
assert "current state:\n3" not in state_log
749+
711750
@pytest.mark.parametrize(
712751
"job_spec",
713752
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],

tests/unit/aiplatform/test_training_jobs.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,6 +1288,59 @@ def teardown_method(self):
12881288
pathlib.Path(self._local_script_file_name).unlink()
12891289
initializer.global_pool.shutdown(wait=True)
12901290

1291+
def test_block_until_complete_logs_symbolic_state_name(
1292+
self, mock_model_service_get
1293+
):
1294+
"""State log must use symbolic enum name, not a bare integer (regression for Python 3.11+)."""
1295+
aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME)
1296+
1297+
logged_messages = []
1298+
1299+
with mock.patch.object(
1300+
pipeline_service_client.PipelineServiceClient, "create_training_pipeline"
1301+
) as mock_create, mock.patch.object(
1302+
source_utils._TrainingScriptPythonPackager, "package_and_copy_to_gcs"
1303+
) as mock_pkg, mock.patch.object(
1304+
pipeline_service_client.PipelineServiceClient, "get_training_pipeline"
1305+
) as mock_get, mock.patch.object(
1306+
training_jobs, "_LOG_WAIT_TIME", 0
1307+
), mock.patch.object(
1308+
training_jobs, "_JOB_WAIT_TIME", 0
1309+
), mock.patch.object(
1310+
training_jobs._LOGGER, "info", side_effect=lambda msg, *a, **kw: logged_messages.append(msg)
1311+
):
1312+
mock_pkg.return_value = _TEST_OUTPUT_PYTHON_PACKAGE_PATH
1313+
mock_create.return_value = gca_training_pipeline.TrainingPipeline(
1314+
name=_TEST_PIPELINE_RESOURCE_NAME,
1315+
state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED,
1316+
model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME),
1317+
)
1318+
_running = gca_training_pipeline.TrainingPipeline(
1319+
name=_TEST_PIPELINE_RESOURCE_NAME,
1320+
state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING,
1321+
training_task_inputs={},
1322+
)
1323+
_succeeded = gca_training_pipeline.TrainingPipeline(
1324+
name=_TEST_PIPELINE_RESOURCE_NAME,
1325+
state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED,
1326+
training_task_inputs={},
1327+
model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME),
1328+
)
1329+
mock_get.side_effect = [_running, _running] + [_succeeded] * 8
1330+
job = training_jobs.CustomTrainingJob(
1331+
display_name=_TEST_DISPLAY_NAME,
1332+
script_path=self._local_script_file_name,
1333+
container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
1334+
)
1335+
job.run(base_output_dir=_TEST_BASE_OUTPUT_DIR, sync=True)
1336+
1337+
state_log = next(
1338+
(m for m in logged_messages if "current state" in m), None
1339+
)
1340+
assert state_log is not None, "No 'current state' log message found"
1341+
assert "PIPELINE_STATE_RUNNING" in state_log
1342+
assert "current state:\n3" not in state_log
1343+
12911344
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
12921345
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
12931346
@pytest.mark.parametrize("sync", [True, False])

0 commit comments

Comments
 (0)