Skip to content

Commit 4dce274

Browse files
committed
add a unit test for the empty csv results case
1 parent ed18ec1 commit 4dce274

3 files changed

Lines changed: 49 additions & 5 deletions

File tree

mlos_bench/mlos_bench/tests/environments/__init__.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,19 @@
1010
import pytest
1111

1212
from mlos_bench.environments.base_environment import Environment
13+
from mlos_bench.environments.status import Status
1314
from mlos_bench.tunables.tunable_groups import TunableGroups
1415
from mlos_bench.tunables.tunable_types import TunableValue
1516

1617

1718
def check_env_success(
1819
env: Environment,
1920
tunable_groups: TunableGroups,
20-
expected_results: dict[str, TunableValue],
21+
*,
22+
expected_results: dict[str, TunableValue] | None,
2123
expected_telemetry: list[tuple[datetime, str, Any]],
24+
expected_status_run: set[Status] | None = None,
25+
expected_status_next: set[Status] | None = None,
2226
global_config: dict | None = None,
2327
) -> None:
2428
"""
@@ -34,19 +38,37 @@ def check_env_success(
3438
Expected results of the benchmark.
3539
expected_telemetry : list[tuple[datetime, str, Any]]
3640
Expected telemetry data of the benchmark.
41+
expected_status_run : set[Status]
42+
Expected status right after the trial.
43+
Default is the `SUCCEEDED` value.
44+
expected_status_next : set[Status]
45+
Expected status values for the next trial.
46+
Default is the same set as in `.is_good()`.
3747
global_config : dict
3848
Global params.
3949
"""
50+
# pylint: disable=too-many-arguments
51+
if expected_status_run is None:
52+
expected_status_run = {Status.SUCCEEDED}
53+
54+
if expected_status_next is None:
55+
expected_status_next = {
56+
Status.PENDING,
57+
Status.READY,
58+
Status.RUNNING,
59+
Status.SUCCEEDED,
60+
}
61+
4062
with env as env_context:
4163

4264
assert env_context.setup(tunable_groups, global_config)
4365

4466
(status, _ts, data) = env_context.run()
45-
assert status.is_succeeded()
46-
assert data == pytest.approx(expected_results, nan_ok=True)
67+
assert status in expected_status_run
68+
assert data == expected_results or data == pytest.approx(expected_results, nan_ok=True)
4769

4870
(status, _ts, telemetry) = env_context.status()
49-
assert status.is_good()
71+
assert status in expected_status_next
5072
assert telemetry == pytest.approx(expected_telemetry, nan_ok=True)
5173

5274
env_context.teardown()

mlos_bench/mlos_bench/tests/environments/local/local_env_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""Unit tests for LocalEnv benchmark environment."""
66
import pytest
77

8+
from mlos_bench.environments.status import Status
89
from mlos_bench.tests.environments import check_env_success
910
from mlos_bench.tests.environments.local import create_local_env
1011
from mlos_bench.tunables.tunable_groups import TunableGroups
@@ -101,3 +102,24 @@ def test_local_env_wide(tunable_groups: TunableGroups) -> None:
101102
},
102103
expected_telemetry=[],
103104
)
105+
106+
107+
def test_local_env_results_empty_file(tunable_groups: TunableGroups) -> None:
108+
"""When the results file is empty, do not crash but mark the trial FAILED."""
109+
local_env = create_local_env(
110+
tunable_groups,
111+
{
112+
"run": [
113+
"echo 'latency,throughput,score' > output.csv",
114+
],
115+
"read_results_file": "output.csv",
116+
},
117+
)
118+
119+
check_env_success(
120+
local_env,
121+
tunable_groups,
122+
expected_status_run={Status.FAILED},
123+
expected_results=None,
124+
expected_telemetry=[],
125+
)

mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _run_local_env(tunable_groups: TunableGroups, shell_subcmd: str, expected: d
3434
},
3535
)
3636

37-
check_env_success(local_env, tunable_groups, expected, [])
37+
check_env_success(local_env, tunable_groups, expected_results=expected, expected_telemetry=[])
3838

3939

4040
@pytest.mark.skipif(sys.platform == "win32", reason="sh-like shell only")

0 commit comments

Comments (0)