1010import pytest
1111
1212from mlos_bench .environments .base_environment import Environment
13+ from mlos_bench .environments .status import Status
1314from mlos_bench .tunables .tunable_groups import TunableGroups
1415from mlos_bench .tunables .tunable_types import TunableValue
1516
1617
1718def check_env_success (
1819 env : Environment ,
1920 tunable_groups : TunableGroups ,
20- expected_results : dict [str , TunableValue ],
21+ * ,
22+ expected_results : dict [str , TunableValue ] | None ,
2123 expected_telemetry : list [tuple [datetime , str , Any ]],
24+ expected_status_run : set [Status ] | None = None ,
25+ expected_status_next : set [Status ] | None = None ,
2226 global_config : dict | None = None ,
2327) -> None :
2428 """
@@ -34,19 +38,37 @@ def check_env_success(
3438 Expected results of the benchmark.
3539 expected_telemetry : list[tuple[datetime, str, Any]]
3640 Expected telemetry data of the benchmark.
41+ expected_status_run : set[Status]
42+ Expected status right after the trial.
43+ Default is the `SUCCEEDED` value.
44+ expected_status_next : set[Status]
45+ Expected status values for the next trial.
46+ Default is the same set as in `.is_good()`.
3747 global_config : dict
3848 Global params.
3949 """
50+ # pylint: disable=too-many-arguments
51+ if expected_status_run is None :
52+ expected_status_run = {Status .SUCCEEDED }
53+
54+ if expected_status_next is None :
55+ expected_status_next = {
56+ Status .PENDING ,
57+ Status .READY ,
58+ Status .RUNNING ,
59+ Status .SUCCEEDED ,
60+ }
61+
4062 with env as env_context :
4163
4264 assert env_context .setup (tunable_groups , global_config )
4365
4466 (status , _ts , data ) = env_context .run ()
45- assert status . is_succeeded ()
46- assert data == pytest .approx (expected_results , nan_ok = True )
67+ assert status in expected_status_run
68+ assert data == expected_results or data == pytest .approx (expected_results , nan_ok = True )
4769
4870 (status , _ts , telemetry ) = env_context .status ()
49- assert status . is_good ()
71+ assert status in expected_status_next
5072 assert telemetry == pytest .approx (expected_telemetry , nan_ok = True )
5173
5274 env_context .teardown ()
0 commit comments