Skip to content

Commit c060039

Browse files
Refactor storage tests to support testing multiple backends on schema changes (#983)
# Pull Request

## Title

Refactor storage tests to support testing multiple backends on schema changes

______________________________________________________________________

## Description

- adds the ability to run mysql and postgres servers in a container as a part of pytest
- checks that they succeed with schema setup
- fixes the ones that were broken with mysql

______________________________________________________________________

## Type of Change

- 🛠️ Bug fix
- 🔄 Refactor
- 🧪 Tests

______________________________________________________________________

## Testing

- CI tests

______________________________________________________________________

## Additional Notes (optional)

Part of a series of changes to help improve testing for schedulers. Next will be schema changes to allow storing fractional seconds in MySQL.

Merge after #987

______________________________________________________________________

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 869001b commit c060039

23 files changed

Lines changed: 568 additions & 79 deletions

mlos_bench/mlos_bench/config/storage/mysql.jsonc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"log_sql": false, // Write all SQL statements to the log.
77
// Parameters below must match kwargs of `sqlalchemy.URL.create()`:
88
"drivername": "mysql+mysqlconnector",
9-
"database": "osat",
9+
"database": "mlos_bench",
1010
"username": "root",
1111
"password": "PLACERHOLDER PASSWORD", // Comes from global config
1212
"host": "localhost",

mlos_bench/mlos_bench/config/storage/postgresql.jsonc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
"log_sql": false, // Write all SQL statements to the log.
99
// Parameters below must match kwargs of `sqlalchemy.URL.create()`:
1010
"drivername": "postgresql+psycopg2",
11-
"database": "osat",
11+
"database": "mlos_bench",
1212
"username": "postgres",
1313
"password": "PLACERHOLDER PASSWORD", // Comes from global config
1414
"host": "localhost",

mlos_bench/mlos_bench/storage/sql/alembic/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ This document contains some notes on how to use [`alembic`](https://alembic.sqla
4545

4646
1. If the migration script works, commit the changes to the [`mlos_bench/storage/sql/schema.py`](../schema.py) and [`mlos_bench/storage/sql/alembic/versions`](./versions/) files.
4747

48-
> Be sure to update the latest version in the [`test_storage_schemas.py`](../../../tests/storage/test_storage_schemas.py) file as well.
48+
> Be sure to update the latest version in the [`test_storage_schemas.py`](../../../tests/storage/sql/test_storage_schemas.py) file as well.
4949
5050
1. Merge that to the `main` branch.
5151

mlos_bench/mlos_bench/storage/sql/schema.py

Lines changed: 57 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,6 @@ class DbSchema:
7272
# for all DB tables, so it's ok to disable the warnings.
7373
# pylint: disable=too-many-instance-attributes
7474

75-
# Common string column sizes.
76-
_ID_LEN = 512
77-
_PARAM_VALUE_LEN = 1024
78-
_METRIC_VALUE_LEN = 255
79-
_STATUS_LEN = 16
80-
8175
def __init__(self, engine: Engine | None):
8276
"""
8377
Declare the SQLAlchemy schema for the database.
@@ -95,10 +89,24 @@ def __init__(self, engine: Engine | None):
9589
self._engine = engine
9690
self._meta = MetaData()
9791

92+
# Common string column sizes.
93+
self._exp_id_len = 512
94+
self._param_id_len = 512
95+
self._param_value_len = 1024
96+
self._metric_id_len = 512
97+
self._metric_value_len = 255
98+
self._status_len = 16
99+
100+
# Some overrides for certain DB engines:
101+
if engine and engine.dialect.name in {"mysql", "mariadb"}:
102+
self._exp_id_len = 255
103+
self._param_id_len = 255
104+
self._metric_id_len = 255
105+
98106
self.experiment = Table(
99107
"experiment",
100108
self._meta,
101-
Column("exp_id", String(self._ID_LEN), nullable=False),
109+
Column("exp_id", String(self._exp_id_len), nullable=False),
102110
Column("description", String(1024)),
103111
Column("root_env_config", String(1024), nullable=False),
104112
Column("git_repo", String(1024), nullable=False),
@@ -108,7 +116,7 @@ def __init__(self, engine: Engine | None):
108116
Column("ts_end", DateTime),
109117
# Should match the text IDs of `mlos_bench.environments.Status` enum:
110118
# For backwards compatibility, we allow NULL for status.
111-
Column("status", String(self._STATUS_LEN)),
119+
Column("status", String(self._status_len)),
112120
# There may be more than one mlos_benchd_service running on different hosts.
113121
# This column stores the host/container name of the driver that
114122
# picked up the experiment.
@@ -126,7 +134,7 @@ def __init__(self, engine: Engine | None):
126134
"objectives",
127135
self._meta,
128136
Column("exp_id"),
129-
Column("optimization_target", String(self._ID_LEN), nullable=False),
137+
Column("optimization_target", String(self._metric_id_len), nullable=False),
130138
Column("optimization_direction", String(4), nullable=False),
131139
# TODO: Note: weight is not fully supported yet as currently
132140
# multi-objective is expected to explore each objective equally.
@@ -175,14 +183,14 @@ def __init__(self, engine: Engine | None):
175183
self.trial = Table(
176184
"trial",
177185
self._meta,
178-
Column("exp_id", String(self._ID_LEN), nullable=False),
186+
Column("exp_id", String(self._exp_id_len), nullable=False),
179187
Column("trial_id", Integer, nullable=False),
180188
Column("config_id", Integer, nullable=False),
181189
Column("trial_runner_id", Integer, nullable=True, default=None),
182190
Column("ts_start", DateTime, nullable=False),
183191
Column("ts_end", DateTime),
184192
# Should match the text IDs of `mlos_bench.environments.Status` enum:
185-
Column("status", String(self._STATUS_LEN), nullable=False),
193+
Column("status", String(self._status_len), nullable=False),
186194
PrimaryKeyConstraint("exp_id", "trial_id"),
187195
ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
188196
ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
@@ -197,8 +205,8 @@ def __init__(self, engine: Engine | None):
197205
"config_param",
198206
self._meta,
199207
Column("config_id", Integer, nullable=False),
200-
Column("param_id", String(self._ID_LEN), nullable=False),
201-
Column("param_value", String(self._PARAM_VALUE_LEN)),
208+
Column("param_id", String(self._param_id_len), nullable=False),
209+
Column("param_value", String(self._param_value_len)),
202210
PrimaryKeyConstraint("config_id", "param_id"),
203211
ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
204212
)
@@ -212,10 +220,10 @@ def __init__(self, engine: Engine | None):
212220
self.trial_param = Table(
213221
"trial_param",
214222
self._meta,
215-
Column("exp_id", String(self._ID_LEN), nullable=False),
223+
Column("exp_id", String(self._exp_id_len), nullable=False),
216224
Column("trial_id", Integer, nullable=False),
217-
Column("param_id", String(self._ID_LEN), nullable=False),
218-
Column("param_value", String(self._PARAM_VALUE_LEN)),
225+
Column("param_id", String(self._param_id_len), nullable=False),
226+
Column("param_value", String(self._param_value_len)),
219227
PrimaryKeyConstraint("exp_id", "trial_id", "param_id"),
220228
ForeignKeyConstraint(
221229
["exp_id", "trial_id"],
@@ -230,10 +238,10 @@ def __init__(self, engine: Engine | None):
230238
self.trial_status = Table(
231239
"trial_status",
232240
self._meta,
233-
Column("exp_id", String(self._ID_LEN), nullable=False),
241+
Column("exp_id", String(self._exp_id_len), nullable=False),
234242
Column("trial_id", Integer, nullable=False),
235243
Column("ts", DateTime(timezone=True), nullable=False, default="now"),
236-
Column("status", String(self._STATUS_LEN), nullable=False),
244+
Column("status", String(self._status_len), nullable=False),
237245
UniqueConstraint("exp_id", "trial_id", "ts"),
238246
ForeignKeyConstraint(
239247
["exp_id", "trial_id"],
@@ -247,10 +255,10 @@ def __init__(self, engine: Engine | None):
247255
self.trial_result = Table(
248256
"trial_result",
249257
self._meta,
250-
Column("exp_id", String(self._ID_LEN), nullable=False),
258+
Column("exp_id", String(self._exp_id_len), nullable=False),
251259
Column("trial_id", Integer, nullable=False),
252-
Column("metric_id", String(self._ID_LEN), nullable=False),
253-
Column("metric_value", String(self._METRIC_VALUE_LEN)),
260+
Column("metric_id", String(self._metric_id_len), nullable=False),
261+
Column("metric_value", String(self._metric_value_len)),
254262
PrimaryKeyConstraint("exp_id", "trial_id", "metric_id"),
255263
ForeignKeyConstraint(
256264
["exp_id", "trial_id"],
@@ -265,11 +273,11 @@ def __init__(self, engine: Engine | None):
265273
self.trial_telemetry = Table(
266274
"trial_telemetry",
267275
self._meta,
268-
Column("exp_id", String(self._ID_LEN), nullable=False),
276+
Column("exp_id", String(self._exp_id_len), nullable=False),
269277
Column("trial_id", Integer, nullable=False),
270278
Column("ts", DateTime(timezone=True), nullable=False, default="now"),
271-
Column("metric_id", String(self._ID_LEN), nullable=False),
272-
Column("metric_value", String(self._METRIC_VALUE_LEN)),
279+
Column("metric_id", String(self._metric_id_len), nullable=False),
280+
Column("metric_value", String(self._metric_value_len)),
273281
UniqueConstraint("exp_id", "trial_id", "ts", "metric_id"),
274282
ForeignKeyConstraint(
275283
["exp_id", "trial_id"],
@@ -296,6 +304,31 @@ def _get_alembic_cfg(conn: Connection) -> config.Config:
296304
alembic_cfg.attributes["connection"] = conn
297305
return alembic_cfg
298306

307+
def drop_all_tables(self, *, force: bool = False) -> None:
308+
"""
309+
Helper method used in testing to reset the DB schema.
310+
311+
Notes
312+
-----
313+
This method is not intended for production use, as it will drop all tables
314+
in the database. Use with caution.
315+
316+
Parameters
317+
----------
318+
force : bool
319+
If True, drop all tables in the target database.
320+
If False, this method will not drop any tables and will log a warning.
321+
"""
322+
assert self._engine
323+
self.meta.reflect(bind=self._engine)
324+
if force:
325+
self.meta.drop_all(bind=self._engine)
326+
else:
327+
_LOG.warning(
328+
"Resetting the schema without force is not implemented. "
329+
"Use force=True to drop all tables."
330+
)
331+
299332
def create(self) -> "DbSchema":
300333
"""Create the DB schema."""
301334
_LOG.info("Create the DB schema")

mlos_bench/mlos_bench/storage/sql/storage.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""Saving and restoring the benchmark data in SQL database."""
66

77
import logging
8+
from types import TracebackType
89
from typing import Literal
910

1011
from sqlalchemy import URL, Engine, create_engine
@@ -72,6 +73,22 @@ def __setstate__(self, state: dict) -> None:
7273
# Recreate the engine and schema.
7374
self._init_engine()
7475

76+
def dispose(self) -> None:
77+
"""Closes the database connection pool."""
78+
if self._engine:
79+
self._engine.dispose()
80+
_LOG.info("Closed the database connection: %s", self)
81+
82+
def __exit__(
83+
self,
84+
exc_type: type[BaseException] | None, # pylint: disable=unused-argument
85+
exc_val: BaseException | None, # pylint: disable=unused-argument
86+
exc_tb: TracebackType | None, # pylint: disable=unused-argument
87+
) -> Literal[False]:
88+
"""Close the engine connection when exiting the context."""
89+
self.dispose()
90+
return False
91+
7592
@property
7693
def _schema(self) -> DbSchema:
7794
"""Lazily create schema upon first access."""
@@ -82,6 +99,33 @@ def _schema(self) -> DbSchema:
8299
_LOG.debug("DDL statements:\n%s", self._db_schema)
83100
return self._db_schema
84101

102+
def _reset_schema(self, *, force: bool = False) -> None:
103+
"""
104+
Helper method used in testing to reset the DB schema.
105+
106+
Notes
107+
-----
108+
This method is not intended for production use, as it will drop all tables
109+
in the database. Use with caution.
110+
111+
Parameters
112+
----------
113+
force : bool
114+
If True, drop all tables in the target database.
115+
If False, this method will not drop any tables and will log a warning.
116+
"""
117+
assert self._engine
118+
if force:
119+
self._schema.drop_all_tables(force=force)
120+
self._db_schema = DbSchema(self._engine)
121+
self._schema_created = False
122+
self._schema_updated = False
123+
else:
124+
_LOG.warning(
125+
"Resetting the schema without force is not implemented. "
126+
"Use force=True to drop all tables."
127+
)
128+
85129
def update_schema(self) -> None:
86130
"""Update the database schema."""
87131
if not self._schema_updated:

mlos_bench/mlos_bench/tests/__init__.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Used to make mypy happy about multiple conftest.py modules.
99
"""
1010
import filecmp
11+
import json
1112
import os
1213
import shutil
1314
import socket
@@ -17,6 +18,7 @@
1718

1819
import pytest
1920
import pytz
21+
from pytest_docker.plugin import Services as DockerServices
2022

2123
from mlos_bench.util import get_class_from_name, nullable
2224

@@ -87,6 +89,48 @@ def check_class_name(obj: object, expected_class_name: str) -> bool:
8789
return full_class_name == try_resolve_class_name(expected_class_name)
8890

8991

92+
def is_docker_service_healthy(
93+
compose_project_name: str,
94+
service_name: str,
95+
) -> bool:
96+
"""Check if a docker service is healthy."""
97+
docker_ps_out = run(
98+
f"docker compose -p {compose_project_name} " f"ps --format json {service_name}",
99+
shell=True,
100+
check=True,
101+
capture_output=True,
102+
)
103+
docker_ps_json = json.loads(docker_ps_out.stdout.decode().strip())
104+
state = docker_ps_json["State"]
105+
assert isinstance(state, str)
106+
health = docker_ps_json["Health"]
107+
assert isinstance(health, str)
108+
return state == "running" and health == "healthy"
109+
110+
111+
def wait_docker_service_healthy(
112+
docker_services: DockerServices,
113+
project_name: str,
114+
service_name: str,
115+
timeout: float = 30.0,
116+
) -> None:
117+
"""Wait until a docker service is healthy."""
118+
docker_services.wait_until_responsive(
119+
check=lambda: is_docker_service_healthy(project_name, service_name),
120+
timeout=timeout,
121+
pause=0.5,
122+
)
123+
124+
125+
def wait_docker_service_socket(docker_services: DockerServices, hostname: str, port: int) -> None:
126+
"""Wait until a docker service is ready."""
127+
docker_services.wait_until_responsive(
128+
check=lambda: check_socket(hostname, port),
129+
timeout=30.0,
130+
pause=0.5,
131+
)
132+
133+
90134
def check_socket(host: str, port: int, timeout: float = 1.0) -> bool:
91135
"""
92136
Test to see if a socket is open.

mlos_bench/mlos_bench/tests/config/schedulers/test_load_scheduler_config_examples.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from mlos_bench.schedulers.base_scheduler import Scheduler
1515
from mlos_bench.schedulers.trial_runner import TrialRunner
1616
from mlos_bench.services.config_persistence import ConfigPersistenceService
17-
from mlos_bench.storage.sql.storage import SqlStorage
17+
from mlos_bench.storage.base_storage import Storage
1818
from mlos_bench.tests.config import BUILTIN_TEST_CONFIG_PATH, locate_config_examples
1919
from mlos_bench.util import get_class_from_name
2020

@@ -58,7 +58,7 @@ def test_load_scheduler_config_examples(
5858
config_path: str,
5959
mock_env_config_path: str,
6060
trial_runners: list[TrialRunner],
61-
storage: SqlStorage,
61+
storage: Storage,
6262
mock_opt: MockOptimizer,
6363
) -> None:
6464
"""Tests loading a config example."""

mlos_bench/mlos_bench/tests/conftest.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from pytest_docker.plugin import get_docker_services
1616

1717
from mlos_bench.environments.mock_env import MockEnv
18-
from mlos_bench.tests import SEED, tunable_groups_fixtures
18+
from mlos_bench.tests import SEED, resolve_host_name, tunable_groups_fixtures
1919
from mlos_bench.tunables.tunable_groups import TunableGroups
2020

2121
# pylint: disable=redefined-outer-name
@@ -29,6 +29,22 @@
2929
covariant_group = tunable_groups_fixtures.covariant_group
3030

3131

32+
HOST_DOCKER_NAME = "host.docker.internal"
33+
34+
35+
@pytest.fixture(scope="session")
36+
def docker_hostname() -> str:
37+
"""Returns the local hostname to use to connect to the test docker services."""
38+
if sys.platform != "win32" and resolve_host_name(HOST_DOCKER_NAME):
39+
# On Linux, if we're running in a docker container, we can use the
40+
# --add-host (extra_hosts in docker-compose.yml) to refer to the host IP.
41+
return HOST_DOCKER_NAME
42+
# Docker (Desktop) for Windows (WSL2) uses special networking magic
43+
# to refer to the host machine as `localhost` when exposing ports.
44+
# In all other cases, assume we're executing directly inside conda on the host.
45+
return "127.0.0.1" # "localhost"
46+
47+
3248
@pytest.fixture
3349
def mock_env(tunable_groups: TunableGroups) -> MockEnv:
3450
"""Test fixture for MockEnv."""
@@ -90,6 +106,7 @@ def docker_compose_file(pytestconfig: pytest.Config) -> list[str]:
90106
_ = pytestconfig # unused
91107
return [
92108
os.path.join(os.path.dirname(__file__), "services", "remote", "ssh", "docker-compose.yml"),
109+
os.path.join(os.path.dirname(__file__), "storage", "sql", "docker-compose.yml"),
93110
# Add additional configs as necessary here.
94111
]
95112

mlos_bench/mlos_bench/tests/environments/remote/conftest.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,3 @@
88

99
# Expose some of those as local names so they can be picked up as fixtures by pytest.
1010
ssh_test_server = ssh_fixtures.ssh_test_server
11-
ssh_test_server_hostname = ssh_fixtures.ssh_test_server_hostname

0 commit comments

Comments
 (0)