Skip to content

Commit bd2ff97

Browse files
Jiawei Yangmeta-codesync[bot]
authored andcommitted
Make SQA store SQLAlchemy 2.0-compatible (facebook#5201)
Summary: Pull Request resolved: facebook#5201 The OSS Ax SQA store has a hard guard in with_db_settings_base.py that raises IncompatibleDependencyVersion when SQLAlchemy major > 1, disabling SQL storage entirely. This blocks SA 2.0 adoption (T163607006) and Python 3.13/3.14 (which auto-select SA 2.0.48 from third-party). Two additional SA 2.0 incompatibilities exist in OSS: defer("col_name") in load.py and reduced_state.py, which SA 2.0 rejects in favor of class-bound attribute references. This diff removes the guard and converts the string-based loader options to attribute references. Adds a dual-version Buck test target tests_sa2 via constraint_overrides plus a self-proving TestSQLAlchemyDualVersionCompat class so each target proves its constraint took effect (EXPECTED_SA_MAJOR env var asserted at runtime). Reviewed By: yangjoanna, andycylmeta Differential Revision: D104875017 fbshipit-source-id: 6835c4753e251161c13eff2a452aeb50ea439001
1 parent 14d3475 commit bd2ff97

8 files changed

Lines changed: 75 additions & 25 deletions

File tree

ax/storage/sqa_store/db.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def init_engine_and_session_factory(
164164

165165
if SESSION_FACTORY is not None:
166166
if force_init:
167+
# pyre-ignore[16]: SA 2.0 bind is Union; runtime Engine.
167168
SESSION_FACTORY.bind.dispose()
168169
else:
169170
return
@@ -204,6 +205,7 @@ def init_test_engine_and_session_factory(
204205

205206
if SESSION_FACTORY is not None:
206207
if force_init:
208+
# pyre-ignore[16]: SA 2.0 bind is Union; runtime Engine.
207209
SESSION_FACTORY.bind.dispose()
208210
else:
209211
return
@@ -262,6 +264,7 @@ def get_engine() -> Engine:
262264
global SESSION_FACTORY
263265
if SESSION_FACTORY is None:
264266
raise ValueError("Engine must be initialized first.")
267+
# pyre-ignore[7]: SA 2.0 bind is Union; runtime Engine.
265268
return SESSION_FACTORY.bind
266269

267270

@@ -330,5 +333,6 @@ def session_context(
330333
finally:
331334
# Restore the old session factory
332335
session_factory.close()
336+
# pyre-ignore[16]: SA 2.0 bind is Union; runtime Engine.
333337
session_factory.bind.dispose()
334338
SESSION_FACTORY = old_session

ax/storage/sqa_store/decoder.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,10 @@ def _auxiliary_experiments_by_purpose_from_experiment_sqa(
194194
):
195195
continue
196196
aux_experiment = auxiliary_experiment_from_name(
197+
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
197198
experiment_name=auxiliary_experiment_sqa.source_experiment.name,
198199
config=self.config,
200+
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
199201
is_active=auxiliary_experiment_sqa.is_active,
200202
reduced_state=reduced_state,
201203
)
@@ -227,6 +229,7 @@ def _init_experiment_from_sqa(
227229
# so need to convert it to regular dict.
228230
properties = dict(experiment_sqa.properties or {})
229231
if Keys.LLM_MESSAGES in properties:
232+
# pyre-ignore[6]: SA 2.0 properties values are ColumnElement; runtime list.
230233
properties[Keys.LLM_MESSAGES] = [
231234
LLMMessage(**m) for m in properties[Keys.LLM_MESSAGES]
232235
]
@@ -249,7 +252,9 @@ def _init_experiment_from_sqa(
249252
raise SQADecodeError("Experiment SearchSpace cannot be None.")
250253
status_quo = (
251254
Arm(
255+
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
252256
parameters=experiment_sqa.status_quo_parameters,
257+
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
253258
name=experiment_sqa.status_quo_name,
254259
)
255260
if experiment_sqa.status_quo_parameters is not None
@@ -295,6 +300,7 @@ def _init_mt_experiment_from_sqa(
295300
"""First step of conversion within experiment_from_sqa."""
296301
properties = dict(experiment_sqa.properties or {})
297302
if Keys.LLM_MESSAGES in properties:
303+
# pyre-ignore[6]: SA 2.0 properties values are ColumnElement; runtime list.
298304
properties[Keys.LLM_MESSAGES] = [
299305
LLMMessage(**m) for m in properties[Keys.LLM_MESSAGES]
300306
]
@@ -317,7 +323,9 @@ def _init_mt_experiment_from_sqa(
317323
raise SQADecodeError("Experiment SearchSpace cannot be None.")
318324
status_quo = (
319325
Arm(
326+
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
320327
parameters=experiment_sqa.status_quo_parameters,
328+
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
321329
name=experiment_sqa.status_quo_name,
322330
)
323331
if experiment_sqa.status_quo_parameters is not None

ax/storage/sqa_store/json.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class JSONEncodedLongText(JSONEncodedObject):
9393
impl = Text(LONGTEXT_BYTES)
9494

9595

96+
# pyre-ignore[9]: SA 2.0 typed as_mutable returns TypeEngine; runtime TypeDecorator.
9697
JSONEncodedList: TypeDecorator = MutableList.as_mutable(JSONEncodedObject)
9798
JSONEncodedDict: TypeDecorator = MutableDict.as_mutable(JSONEncodedObject)
9899
JSONEncodedTextDict: TypeDecorator = MutableDict.as_mutable(JSONEncodedText)

ax/storage/sqa_store/load.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -614,10 +614,12 @@ def get_generation_strategy_sqa_reduced_state(
614614
gr_sqa_class.parameter_constraints
615615
),
616616
defaultload(gs_sqa_class.generator_runs).lazyload(gr_sqa_class.metrics),
617-
defaultload(gs_sqa_class.generator_runs).defer("model_kwargs"),
618-
defaultload(gs_sqa_class.generator_runs).defer("bridge_kwargs"),
619-
defaultload(gs_sqa_class.generator_runs).defer("model_state_after_gen"),
620-
defaultload(gs_sqa_class.generator_runs).defer("gen_metadata"),
617+
defaultload(gs_sqa_class.generator_runs).defer(gr_sqa_class.model_kwargs),
618+
defaultload(gs_sqa_class.generator_runs).defer(gr_sqa_class.bridge_kwargs),
619+
defaultload(gs_sqa_class.generator_runs).defer(
620+
gr_sqa_class.model_state_after_gen
621+
),
622+
defaultload(gs_sqa_class.generator_runs).defer(gr_sqa_class.gen_metadata),
621623
],
622624
)
623625

ax/storage/sqa_store/reduced_state.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,5 @@ def get_query_options_to_defer_large_model_cols() -> list[strategy_options.Load]
5858
when loading experiment and generation strategy in reduced state.
5959
"""
6060
return [
61-
defaultload(SQATrial.generator_runs).defer(col.key)
62-
for col in GR_LARGE_MODEL_ATTRS
61+
defaultload(SQATrial.generator_runs).defer(col) for col in GR_LARGE_MODEL_ATTRS
6362
]

ax/storage/sqa_store/tests/test_sqa_store.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,9 +240,12 @@ def test_connection_to_db_with_creator(self) -> None:
240240
)
241241
with session_scope() as session:
242242
engine = session.bind
243+
# pyre-ignore[16]: SA 2.0 bind is Union; runtime Engine.
243244
engine.connect()
244245
self.assertEqual(mocked_dbapi.connect.call_count, 1)
246+
# pyre-ignore[16]: SA 2.0 bind is Union; runtime Engine.
245247
self.assertTrue(engine.echo)
248+
# pyre-ignore[16]: SA 2.0 bind is Union; runtime Engine.
246249
self.assertEqual(engine.pool.size(), 2)
247250

248251
def test_connection_to_db_with_session_context(self) -> None:
@@ -279,10 +282,12 @@ def test_generator_run_type_validation(self) -> None:
279282

280283
generator_run._generator_run_type = "STATUS_QUO"
281284
generator_run_sqa = self.encoder.generator_run_to_sqa(generator_run)
285+
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
282286
generator_run_sqa.generator_run_type = 2
283287
with self.assertRaises(SQADecodeError):
284288
self.decoder.generator_run_from_sqa(generator_run_sqa, False, False)
285289

290+
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
286291
generator_run_sqa.generator_run_type = 0
287292
self.decoder.generator_run_from_sqa(generator_run_sqa, False, False)
288293

ax/storage/sqa_store/tests/test_with_db_settings_base.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,18 @@
66

77
# pyre-strict
88
import logging
9+
import os
910
import random
1011
import string
1112
from unittest.mock import patch
1213

14+
import sqlalchemy
1315
from ax.core.base_trial import BaseTrial
1416
from ax.core.experiment import Experiment
1517
from ax.core.trial import Trial
1618
from ax.core.trial_status import TrialStatus
1719
from ax.generation_strategy.generation_strategy import GenerationStrategy
20+
from ax.storage.sqa_store import with_db_settings_base as _wdb_module
1821
from ax.storage.sqa_store.db import init_test_engine_and_session_factory
1922
from ax.storage.sqa_store.load import (
2023
_load_experiment,
@@ -396,3 +399,47 @@ def test_try_load_generation_strategy(self) -> None:
396399
lg.output[0],
397400
)
398401
self.assertEqual(output, generation_strategy)
402+
403+
404+
class TestSQLAlchemyDualVersionCompat(TestCase):
405+
"""Self-proving checks that the dual-version SA 2.0 BUCK targets actually
406+
resolved their constraint_overrides and that the SA 2.0 hard guard is gone.
407+
408+
Part of the SA 2.0 dual-version migration (T163607006).
409+
"""
410+
411+
def test_module_level_dbsettings_is_defined(self) -> None:
412+
"""The SA 2.0 hard guard previously set DBSettings = None at module level
413+
when SA major > 1, breaking WithDBSettingsBase.__init__ type checks. Now
414+
that the guard is removed, DBSettings must always resolve to the real
415+
type when SQLAlchemy is importable. Uses getattr because DBSettings is
416+
conditionally defined in a try/except in with_db_settings_base.
417+
"""
418+
# pyre-ignore[16]: DBSettings is conditionally defined in with_db_settings_base.
419+
module_dbsettings = getattr(_wdb_module, "DBSettings", None)
420+
self.assertIsNotNone(
421+
module_dbsettings,
422+
"with_db_settings_base.DBSettings is None -- guard removal regressed",
423+
)
424+
self.assertIs(module_dbsettings, DBSettings)
425+
426+
def test_sa_major_matches_buck_target(self) -> None:
427+
"""When the BUCK target sets EXPECTED_SA_MAJOR, assert the runtime
428+
SQLAlchemy major matches. Makes :tests vs :tests_sa2 self-proving.
429+
Skipped when EXPECTED_SA_MAJOR is unset (e.g., local one-off invocations).
430+
"""
431+
expected_major_str = os.environ.get("EXPECTED_SA_MAJOR")
432+
if expected_major_str is None:
433+
self.skipTest(
434+
"EXPECTED_SA_MAJOR not set; only enforced under the dual-version "
435+
"BUCK targets that pin SQLAlchemy via constraint_overrides"
436+
)
437+
# pyre-ignore[16]: Module `sqlalchemy` has no attribute `__version__`.
438+
actual_version = sqlalchemy.__version__
439+
actual_major = int(actual_version.split(".")[0])
440+
self.assertEqual(
441+
actual_major,
442+
int(expected_major_str),
443+
f"BUCK target expected SQLAlchemy major {expected_major_str}, "
444+
f"got {actual_version}",
445+
)

ax/storage/sqa_store/with_db_settings_base.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
# pyre-strict
88

9-
import re
109
import time
1110
from collections.abc import Sequence
1211
from logging import INFO, Logger
@@ -16,11 +15,7 @@
1615
from ax.core.experiment import Experiment
1716
from ax.core.generator_run import GeneratorRun
1817
from ax.core.runner import Runner
19-
from ax.exceptions.core import (
20-
IncompatibleDependencyVersion,
21-
ObjectNotFoundError,
22-
UnsupportedError,
23-
)
18+
from ax.exceptions.core import ObjectNotFoundError, UnsupportedError
2419
from ax.generation_strategy.generation_strategy import GenerationStrategy
2520
from ax.utils.common.executils import retry_on_exception
2621
from ax.utils.common.logger import _round_floats_for_logging, get_logger
@@ -37,18 +32,7 @@
3732
from sqlalchemy import __version__ as sqa_version
3833

3934
# pyre-fixme[16]: Module `sqlalchemy` has no attribute `__version__`.
40-
sqa_major_version = int(none_throws(re.match(r"^\d*", sqa_version))[0])
41-
if sqa_major_version > 1:
42-
msg = (
43-
"Ax currently requires a sqlalchemy version below 2.0. This will be "
44-
"addressed in a future release. Disabling SQL storage in Ax for now, if "
45-
"you would like to use SQL storage please install Ax with mysql extras "
46-
"via `pip install ax-platform[mysql]`."
47-
)
48-
49-
logger.warning(msg)
50-
51-
raise IncompatibleDependencyVersion(msg)
35+
logger.info(f"Ax SQL storage initialized with SQLAlchemy {sqa_version}")
5236

5337
from ax.storage.sqa_store.db import init_engine_and_session_factory
5438
from ax.storage.sqa_store.decoder import Decoder
@@ -78,7 +62,7 @@
7862

7963
# We retry on `OperationalError` if saving to DB.
8064
RETRY_EXCEPTION_TYPES = (OperationalError, StaleDataError)
81-
except (ModuleNotFoundError, IncompatibleDependencyVersion, TypeError):
65+
except (ModuleNotFoundError, TypeError):
8266
DBSettings = None
8367
TDBSettings = None
8468
Decoder = None

0 commit comments

Comments
 (0)