Skip to content

Commit 4ddba9d

Browse files
authored
Merge pull request #375 from chaen/fix_mock_os
tests: fix the mock_osdb
2 parents 2bf4e28 + 2a46eea commit 4ddba9d

File tree

2 files changed

+24
-9
lines changed

2 files changed

+24
-9
lines changed

diracx-testing/src/diracx/testing/mock_osdb.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,22 @@ async def client_context(self) -> AsyncIterator[None]:
7272
yield
7373

7474
async def __aenter__(self):
75-
await self._sql_db.__aenter__()
75+
"""Enter the request context.
76+
77+
This is a no-op as the real OpenSearch class doesn't use transactions.
78+
Instead we enter a transaction in each method that needs it.
79+
"""
7680
return self
7781

7882
async def __aexit__(self, exc_type, exc_value, traceback):
79-
await self._sql_db.__aexit__(exc_type, exc_value, traceback)
83+
pass
8084

8185
async def create_index_template(self) -> None:
8286
async with self._sql_db.engine.begin() as conn:
8387
await conn.run_sync(self._sql_db.metadata.create_all)
8488

8589
async def upsert(self, doc_id, document) -> None:
86-
async with self:
90+
async with self._sql_db:
8791
values = {}
8892
for key, value in document.items():
8993
if key in self.fields:
@@ -106,7 +110,7 @@ async def search(
106110
per_page: int = 100,
107111
page: int | None = None,
108112
) -> tuple[int, list[dict[Any, Any]]]:
109-
async with self:
113+
async with self._sql_db:
110114
# Apply selection
111115
if parameters:
112116
columns = []
@@ -150,7 +154,8 @@ async def search(
150154
return results
151155

152156
async def ping(self):
153-
return await self._sql_db.ping()
157+
async with self._sql_db:
158+
return await self._sql_db.ping()
154159

155160

156161
def fake_available_osdb_implementations(name, *, real_available_implementations):

diracx-testing/src/diracx/testing/utils.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ def configure(self, enabled_dependencies):
252252
assert (
253253
self.app.dependency_overrides == {} and self.app.lifetime_functions == []
254254
), "configure cannot be nested"
255+
255256
for k, v in self.all_dependency_overrides.items():
256257

257258
class_name = k.__self__.__name__
@@ -284,17 +285,26 @@ async def create_db_schemas(self):
284285
import sqlalchemy
285286
from sqlalchemy.util.concurrency import greenlet_spawn
286287

288+
from diracx.db.os.utils import BaseOSDB
287289
from diracx.db.sql.utils import BaseSQLDB
290+
from diracx.testing.mock_osdb import MockOSDBMixin
288291

289292
for k, v in self.app.dependency_overrides.items():
290-
# Ignore dependency overrides which aren't BaseSQLDB.transaction
291-
if (
292-
isinstance(v, UnavailableDependency)
293-
or k.__func__ != BaseSQLDB.transaction.__func__
293+
# Ignore dependency overrides which aren't BaseSQLDB.transaction or BaseOSDB.session
294+
if isinstance(v, UnavailableDependency) or k.__func__ not in (
295+
BaseSQLDB.transaction.__func__,
296+
BaseOSDB.session.__func__,
294297
):
298+
295299
continue
300+
296301
# The first argument of the overridden BaseSQLDB.transaction is the DB object
297302
db = v.args[0]
303+
# We expect the OS DB to be mocked with sqlite, so use the
304+
# internal DB
305+
if isinstance(db, MockOSDBMixin):
306+
db = db._sql_db
307+
298308
assert isinstance(db, BaseSQLDB), (k, db)
299309

300310
# set PRAGMA foreign_keys=ON if sqlite

0 commit comments

Comments
 (0)