Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 39 additions & 12 deletions mlos_bench/mlos_bench/environments/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,37 @@ class Status(enum.Enum):
TIMED_OUT = 7

@staticmethod
def from_str(status_str: Any) -> "Status":
"""Convert a string to a Status enum."""
if not isinstance(status_str, str):
_LOG.warning("Expected type %s for status: %s", type(status_str), status_str)
status_str = str(status_str)
if status_str.isdigit():
def parse(status: Any) -> "Status":
"""
Convert the input to a Status enum.

Parameters
----------
status : Any
The status to parse. This can be a string (or string convertible),
int, or Status enum.

Returns
-------
Status
The corresponding Status enum value or else UNKNOWN if the input is not
recognized.
"""
if isinstance(status, Status):
return status
if not isinstance(status, str):
Comment thread
bpkroth marked this conversation as resolved.
_LOG.warning("Expected type %s for status: %s", type(status), status)
status = str(status)
if status.isdigit():
try:
return Status(int(status_str))
return Status(int(status))
except ValueError:
_LOG.warning("Unknown status: %d", int(status_str))
_LOG.warning("Unknown status: %d", int(status))
try:
status_str = status_str.upper().strip()
return Status[status_str]
status = status.upper().strip()
return Status[status]
except KeyError:
_LOG.warning("Unknown status: %s", status_str)
_LOG.warning("Unknown status: %s", status)
return Status.UNKNOWN

def is_good(self) -> bool:
Expand Down Expand Up @@ -113,4 +129,15 @@ def is_timed_out(self) -> bool:
Status.TIMED_OUT,
}
)
"""The set of completed statuses."""
"""
The set of completed statuses.

Includes all statuses that indicate the trial or experiment has finished, either
successfully or not.
This set is used to determine if a trial or experiment has reached a final state.
This includes:
- :py:attr:`.Status.SUCCEEDED`: The trial or experiment completed successfully.
- :py:attr:`.Status.CANCELED`: The trial or experiment was canceled.
- :py:attr:`.Status.FAILED`: The trial or experiment failed.
- :py:attr:`.Status.TIMED_OUT`: The trial or experiment timed out.
"""
2 changes: 1 addition & 1 deletion mlos_bench/mlos_bench/storage/sql/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def get_trials(
config_id=trial.config_id,
ts_start=utcify_timestamp(trial.ts_start, origin="utc"),
ts_end=utcify_nullable_timestamp(trial.ts_end, origin="utc"),
status=Status.from_str(trial.status),
status=Status.parse(trial.status),
trial_runner_id=trial.trial_runner_id,
)
for trial in trials.fetchall()
Expand Down
6 changes: 3 additions & 3 deletions mlos_bench/mlos_bench/storage/sql/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def load(
status: list[Status] = []

for trial in cur_trials.fetchall():
stat = Status.from_str(trial.status)
stat = Status.parse(trial.status)
status.append(stat)
trial_ids.append(trial.trial_id)
configs.append(
Expand Down Expand Up @@ -272,7 +272,7 @@ def get_trial_by_id(
config_id=trial.config_id,
trial_runner_id=trial.trial_runner_id,
opt_targets=self._opt_targets,
status=Status.from_str(trial.status),
status=Status.parse(trial.status),
restoring=True,
config=config,
)
Expand Down Expand Up @@ -330,7 +330,7 @@ def pending_trials(
config_id=trial.config_id,
trial_runner_id=trial.trial_runner_id,
opt_targets=self._opt_targets,
status=Status.from_str(trial.status),
status=Status.parse(trial.status),
restoring=True,
config=config,
)
Expand Down
11 changes: 7 additions & 4 deletions mlos_bench/mlos_bench/tests/environments/test_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,19 @@ def test_status_from_str_valid(input_str: str, expected_status: Status) -> None:
Expected Status enum value.
"""
assert (
Status.from_str(input_str) == expected_status
Status.parse(input_str) == expected_status
), f"Expected {expected_status} for input: {input_str}"
# Check lowercase representation
assert (
Status.from_str(input_str.lower()) == expected_status
Status.parse(input_str.lower()) == expected_status
), f"Expected {expected_status} for input: {input_str.lower()}"
assert (
Status.parse(expected_status) == expected_status
), f"Expected {expected_status} for input: {expected_status}"
Comment thread
motus marked this conversation as resolved.
if input_str.isdigit():
# Also test the numeric representation
assert (
Status.from_str(int(input_str)) == expected_status
Status.parse(int(input_str)) == expected_status
), f"Expected {expected_status} for input: {int(input_str)}"


Expand All @@ -83,7 +86,7 @@ def test_status_from_str_invalid(invalid_input: Any) -> None:
input.
"""
assert (
Status.from_str(invalid_input) == Status.UNKNOWN
Status.parse(invalid_input) == Status.UNKNOWN
), f"Expected Status.UNKNOWN for invalid input: {invalid_input}"


Expand Down
Loading