diff --git a/mlos_bench/mlos_bench/environments/status.py b/mlos_bench/mlos_bench/environments/status.py index 6d76d7206c..aa3b3e99c1 100644 --- a/mlos_bench/mlos_bench/environments/status.py +++ b/mlos_bench/mlos_bench/environments/status.py @@ -24,21 +24,37 @@ class Status(enum.Enum): TIMED_OUT = 7 @staticmethod - def from_str(status_str: Any) -> "Status": - """Convert a string to a Status enum.""" - if not isinstance(status_str, str): - _LOG.warning("Expected type %s for status: %s", type(status_str), status_str) - status_str = str(status_str) - if status_str.isdigit(): + def parse(status: Any) -> "Status": + """ + Convert the input to a Status enum. + + Parameters + ---------- + status : Any + The status to parse. This can be a string (or string convertible), + int, or Status enum. + + Returns + ------- + Status + The corresponding Status enum value or else UNKNOWN if the input is not + recognized. + """ + if isinstance(status, Status): + return status + if not isinstance(status, str): + _LOG.warning("Expected type %s for status: %s", type(status), status) + status = str(status) + if status.isdigit(): try: - return Status(int(status_str)) + return Status(int(status)) except ValueError: - _LOG.warning("Unknown status: %d", int(status_str)) + _LOG.warning("Unknown status: %d", int(status)) try: - status_str = status_str.upper().strip() - return Status[status_str] + status = status.upper().strip() + return Status[status] except KeyError: - _LOG.warning("Unknown status: %s", status_str) + _LOG.warning("Unknown status: %s", status) return Status.UNKNOWN def is_good(self) -> bool: @@ -113,4 +129,15 @@ def is_timed_out(self) -> bool: Status.TIMED_OUT, } ) -"""The set of completed statuses.""" +""" +The set of completed statuses. + +Includes all statuses that indicate the trial or experiment has finished, either +successfully or not. +This set is used to determine if a trial or experiment has reached a final state. +This includes: +- :py:attr:`.Status.SUCCEEDED`: The trial or experiment completed successfully. +- :py:attr:`.Status.CANCELED`: The trial or experiment was canceled. +- :py:attr:`.Status.FAILED`: The trial or experiment failed. +- :py:attr:`.Status.TIMED_OUT`: The trial or experiment timed out. +""" diff --git a/mlos_bench/mlos_bench/storage/sql/common.py b/mlos_bench/mlos_bench/storage/sql/common.py index 97eb270c9d..032cf9259d 100644 --- a/mlos_bench/mlos_bench/storage/sql/common.py +++ b/mlos_bench/mlos_bench/storage/sql/common.py @@ -95,7 +95,7 @@ def get_trials( config_id=trial.config_id, ts_start=utcify_timestamp(trial.ts_start, origin="utc"), ts_end=utcify_nullable_timestamp(trial.ts_end, origin="utc"), - status=Status.from_str(trial.status), + status=Status.parse(trial.status), trial_runner_id=trial.trial_runner_id, ) for trial in trials.fetchall() diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py index acc2a497b4..0e380e3e13 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment.py @@ -188,7 +188,7 @@ def load( status: list[Status] = [] for trial in cur_trials.fetchall(): - stat = Status.from_str(trial.status) + stat = Status.parse(trial.status) status.append(stat) trial_ids.append(trial.trial_id) configs.append( @@ -272,7 +272,7 @@ def get_trial_by_id( config_id=trial.config_id, trial_runner_id=trial.trial_runner_id, opt_targets=self._opt_targets, - status=Status.from_str(trial.status), + status=Status.parse(trial.status), restoring=True, config=config, ) @@ -330,7 +330,7 @@ def pending_trials( config_id=trial.config_id, trial_runner_id=trial.trial_runner_id, opt_targets=self._opt_targets, - status=Status.from_str(trial.status), + status=Status.parse(trial.status), restoring=True, config=config, ) diff --git a/mlos_bench/mlos_bench/tests/environments/test_status.py b/mlos_bench/mlos_bench/tests/environments/test_status.py index 3c0a9bccf3..785275825c 100644 --- a/mlos_bench/mlos_bench/tests/environments/test_status.py +++ b/mlos_bench/mlos_bench/tests/environments/test_status.py @@ -51,16 +51,19 @@ def test_status_from_str_valid(input_str: str, expected_status: Status) -> None: Expected Status enum value. """ assert ( - Status.from_str(input_str) == expected_status + Status.parse(input_str) == expected_status ), f"Expected {expected_status} for input: {input_str}" # Check lowercase representation assert ( - Status.from_str(input_str.lower()) == expected_status + Status.parse(input_str.lower()) == expected_status ), f"Expected {expected_status} for input: {input_str.lower()}" + assert ( + Status.parse(expected_status) == expected_status + ), f"Expected {expected_status} for input: {expected_status}" if input_str.isdigit(): # Also test the numeric representation assert ( - Status.from_str(int(input_str)) == expected_status + Status.parse(int(input_str)) == expected_status ), f"Expected {expected_status} for input: {int(input_str)}" @@ -83,7 +86,7 @@ def test_status_from_str_invalid(invalid_input: Any) -> None: input. """ assert ( - Status.from_str(invalid_input) == Status.UNKNOWN + Status.parse(invalid_input) == Status.UNKNOWN ), f"Expected Status.UNKNOWN for invalid input: {invalid_input}"