From 44bdde2916535208150df25d6fc2092294bcef0e Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 21 May 2025 22:34:47 -0500 Subject: [PATCH 1/5] refactor status parsing a little bit again --- mlos_bench/mlos_bench/environments/status.py | 37 +++++++++++++------ mlos_bench/mlos_bench/storage/sql/common.py | 2 +- .../mlos_bench/storage/sql/experiment.py | 6 +-- .../tests/environments/test_status.py | 8 ++-- 4 files changed, 34 insertions(+), 19 deletions(-) diff --git a/mlos_bench/mlos_bench/environments/status.py b/mlos_bench/mlos_bench/environments/status.py index 6d76d7206c..d49c4a9e0f 100644 --- a/mlos_bench/mlos_bench/environments/status.py +++ b/mlos_bench/mlos_bench/environments/status.py @@ -24,21 +24,36 @@ class Status(enum.Enum): TIMED_OUT = 7 @staticmethod - def from_str(status_str: Any) -> "Status": - """Convert a string to a Status enum.""" - if not isinstance(status_str, str): - _LOG.warning("Expected type %s for status: %s", type(status_str), status_str) - status_str = str(status_str) - if status_str.isdigit(): + def parse(status: Any) -> "Status": + """Convert the input to a Status enum. + + Parameters + ---------- + status : Any + The status to parse. This can be a string (or string convertible), + int, or Status enum. + + Returns + ------- + Status + The corresponding Status enum value or else UNKNOWN if the input is not + recognized. + """ + if isinstance(status, Status): + return status + if not isinstance(status, str): + _LOG.warning("Expected type %s for status: %s", type(status), status) + status = str(status) + if status.isdigit(): try: - return Status(int(status_str)) + return Status(int(status)) except ValueError: - _LOG.warning("Unknown status: %d", int(status_str)) + _LOG.warning("Unknown status: %d", int(status)) try: - status_str = status_str.upper().strip() - return Status[status_str] + status = status.upper().strip() + return Status[status] except KeyError: - _LOG.warning("Unknown status: %s", status_str) + _LOG.warning("Unknown status: %s", status) return Status.UNKNOWN def is_good(self) -> bool: diff --git a/mlos_bench/mlos_bench/storage/sql/common.py b/mlos_bench/mlos_bench/storage/sql/common.py index 97eb270c9d..032cf9259d 100644 --- a/mlos_bench/mlos_bench/storage/sql/common.py +++ b/mlos_bench/mlos_bench/storage/sql/common.py @@ -95,7 +95,7 @@ def get_trials( config_id=trial.config_id, ts_start=utcify_timestamp(trial.ts_start, origin="utc"), ts_end=utcify_nullable_timestamp(trial.ts_end, origin="utc"), - status=Status.from_str(trial.status), + status=Status.parse(trial.status), trial_runner_id=trial.trial_runner_id, ) for trial in trials.fetchall() diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py index acc2a497b4..0e380e3e13 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment.py @@ -188,7 +188,7 @@ def load( status: list[Status] = [] for trial in cur_trials.fetchall(): - stat = Status.from_str(trial.status) + stat = Status.parse(trial.status) status.append(stat) trial_ids.append(trial.trial_id) configs.append( @@ -272,7 +272,7 @@ def get_trial_by_id( config_id=trial.config_id, trial_runner_id=trial.trial_runner_id, opt_targets=self._opt_targets, - status=Status.from_str(trial.status), + status=Status.parse(trial.status), restoring=True, config=config, ) @@ -330,7 +330,7 @@ def pending_trials( config_id=trial.config_id, trial_runner_id=trial.trial_runner_id, opt_targets=self._opt_targets, - status=Status.from_str(trial.status), + status=Status.parse(trial.status), restoring=True, config=config, ) diff --git a/mlos_bench/mlos_bench/tests/environments/test_status.py b/mlos_bench/mlos_bench/tests/environments/test_status.py index 3c0a9bccf3..8123f2b852 100644 --- a/mlos_bench/mlos_bench/tests/environments/test_status.py +++ b/mlos_bench/mlos_bench/tests/environments/test_status.py @@ -51,16 +51,16 @@ def test_status_from_str_valid(input_str: str, expected_status: Status) -> None: Expected Status enum value. """ assert ( - Status.from_str(input_str) == expected_status + Status.parse(input_str) == expected_status ), f"Expected {expected_status} for input: {input_str}" # Check lowercase representation assert ( - Status.from_str(input_str.lower()) == expected_status + Status.parse(input_str.lower()) == expected_status ), f"Expected {expected_status} for input: {input_str.lower()}" if input_str.isdigit(): # Also test the numeric representation assert ( - Status.from_str(int(input_str)) == expected_status + Status.parse(int(input_str)) == expected_status ), f"Expected {expected_status} for input: {int(input_str)}" @@ -83,7 +83,7 @@ def test_status_from_str_invalid(invalid_input: Any) -> None: input. """ assert ( - Status.from_str(invalid_input) == Status.UNKNOWN + Status.parse(invalid_input) == Status.UNKNOWN ), f"Expected Status.UNKNOWN for invalid input: {invalid_input}" From 607fffdc53ef9e930b6a50814645ea79b508d250 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 21 May 2025 22:36:27 -0500 Subject: [PATCH 2/5] extra test too --- mlos_bench/mlos_bench/tests/environments/test_status.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlos_bench/mlos_bench/tests/environments/test_status.py b/mlos_bench/mlos_bench/tests/environments/test_status.py index 8123f2b852..785275825c 100644 --- a/mlos_bench/mlos_bench/tests/environments/test_status.py +++ b/mlos_bench/mlos_bench/tests/environments/test_status.py @@ -57,6 +57,9 @@ def test_status_from_str_valid(input_str: str, expected_status: Status) -> None: assert ( Status.parse(input_str.lower()) == expected_status ), f"Expected {expected_status} for input: {input_str.lower()}" + assert ( + Status.parse(expected_status) == expected_status + ), f"Expected {expected_status} for input: {expected_status}" if input_str.isdigit(): # Also test the numeric representation assert ( From aaf0842e6f5cc9b3edbce958145becbb82196749 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 May 2025 04:50:54 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mlos_bench/mlos_bench/environments/status.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlos_bench/mlos_bench/environments/status.py b/mlos_bench/mlos_bench/environments/status.py index d49c4a9e0f..8bec0a22c4 100644 --- a/mlos_bench/mlos_bench/environments/status.py +++ b/mlos_bench/mlos_bench/environments/status.py @@ -25,7 +25,8 @@ class Status(enum.Enum): @staticmethod def parse(status: Any) -> "Status": - """Convert the input to a Status enum. + """ + Convert the input to a Status enum. Parameters ---------- From 8a4aac29423e7996771ddb926fe47f14fbbaa89f Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Thu, 22 May 2025 13:19:45 -0500 Subject: [PATCH 4/5] comments --- mlos_bench/mlos_bench/environments/status.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/mlos_bench/mlos_bench/environments/status.py b/mlos_bench/mlos_bench/environments/status.py index 8bec0a22c4..6343d3e854 100644 --- a/mlos_bench/mlos_bench/environments/status.py +++ b/mlos_bench/mlos_bench/environments/status.py @@ -129,4 +129,15 @@ def is_timed_out(self) -> bool: Status.TIMED_OUT, } ) -"""The set of completed statuses.""" +""" +The set of completed statuses. + +Includes all statuses that indicate the trial or experiment has finished, either +successfully or not. +This set is used to determine if a trial or experiment has reached a final state. +This includes: +- :py:data:`.Status.SUCCEEDED`: The trial or experiment completed successfully. +- :py:data:`.Status.CANCELED`: The trial or experiment was canceled. +- :py:data:`.Status.FAILED`: The trial or experiment failed. +- :py:data:`.Status.TIMED_OUT`: The trial or experiment timed out. +""" From f5cb4689bde3a435dd8ec05b6111efe528a3735a Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Thu, 22 May 2025 17:26:54 -0500 Subject: [PATCH 5/5] doc tweaks --- mlos_bench/mlos_bench/environments/status.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlos_bench/mlos_bench/environments/status.py b/mlos_bench/mlos_bench/environments/status.py index 6343d3e854..aa3b3e99c1 100644 --- a/mlos_bench/mlos_bench/environments/status.py +++ b/mlos_bench/mlos_bench/environments/status.py @@ -136,8 +136,8 @@ def is_timed_out(self) -> bool: successfully or not. This set is used to determine if a trial or experiment has reached a final state. This includes: -- :py:data:`.Status.SUCCEEDED`: The trial or experiment completed successfully. -- :py:data:`.Status.CANCELED`: The trial or experiment was canceled. -- :py:data:`.Status.FAILED`: The trial or experiment failed. -- :py:data:`.Status.TIMED_OUT`: The trial or experiment timed out. +- :py:attr:`.Status.SUCCEEDED`: The trial or experiment completed successfully. +- :py:attr:`.Status.CANCELED`: The trial or experiment was canceled. +- :py:attr:`.Status.FAILED`: The trial or experiment failed. +- :py:attr:`.Status.TIMED_OUT`: The trial or experiment timed out. """