diff --git a/CHANGELOG.rst b/CHANGELOG.rst index be14099cd..36b51e657 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -25,6 +25,8 @@ Fixed - ``MigrationRecorder`` no longer emits tortoise's own ``pk`` field ``DeprecationWarning`` when applying migrations; it now builds its bookkeeping model with ``primary_key=True``. (#2203) - ``QuerySet.count()`` now matches the limited query result for the LIMIT/OFFSET edge cases: it returns ``0`` (instead of a negative number) when ``offset()`` exceeds the total row count, and ``0`` (instead of the total) for ``limit(0)``. (#2208) - Field declarations on models now resolve to their concrete type (e.g. ``CharField[str]``) in Pyright/Pylance instead of ``Field[Unknown]``; the ``Field.__new__`` type-check stub now returns ``Self``. (#2216) +- ``TransactionContext`` now returns a ``TransactionalDBClient`` instead of a raw database connection. This change gives the correct inferred type for the transaction context. (#2232) + 1.1.7 ----- diff --git a/tortoise/backends/base/client.py b/tortoise/backends/base/client.py index 5cc91ad9d..66ad46e69 100644 --- a/tortoise/backends/base/client.py +++ b/tortoise/backends/base/client.py @@ -306,14 +306,14 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self._lock.release() -class TransactionContext(Generic[T_conn]): +class TransactionContext: """A context manager interface for transactions. It is returned from in_transaction and _in_transaction.""" client: TransactionalDBClient @abc.abstractmethod - async def __aenter__(self) -> T_conn: ... + async def __aenter__(self) -> TransactionalDBClient: ... @abc.abstractmethod async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ... diff --git a/tortoise/backends/sqlite/client.py b/tortoise/backends/sqlite/client.py index ca66c43d6..bc9f0e064 100644 --- a/tortoise/backends/sqlite/client.py +++ b/tortoise/backends/sqlite/client.py @@ -16,7 +16,6 @@ Capabilities, ConnectionWrapper, NestedTransactionContext, - T_conn, TransactionalDBClient, TransactionContext, ) @@ -191,7 +190,7 @@ async def ensure_connection(self) -> None: await self.connection._parent.create_connection(with_db=True) self.connection._connection = self.connection._parent._connection - async def __aenter__(self) -> T_conn: + async def __aenter__(self) -> TransactionalDBClient: await self._trxlock.acquire() await self.ensure_connection() self.token = get_connections().set(self.connection_name, self.connection)