From f9c0f5c7d4105546d84f8369cba59dd6918f737b Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Mon, 5 Jan 2026 13:48:46 +0800 Subject: [PATCH 1/7] Add pre-commit hooks and scripts for async method checks in PGMQueue Distinguish sync and async operations in PGMQueue - Introduced a pre-commit hook to check for missing async methods in PGMQueue. - Added scripts to identify and generate missing async methods. - Created utility functions for AST manipulation and method transformation. - Established configuration for project paths and console output. --- pgmq_sqlalchemy/queue.py | 800 ++------------------------------------- 1 file changed, 23 insertions(+), 777 deletions(-) diff --git a/pgmq_sqlalchemy/queue.py b/pgmq_sqlalchemy/queue.py index 8d418e9..fea65d3 100644 --- a/pgmq_sqlalchemy/queue.py +++ b/pgmq_sqlalchemy/queue.py @@ -4,7 +4,6 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy.ext.asyncio import create_async_engine - from .schema import Message, QueueMetrics from ._types import ENGINE_TYPE, SESSION_TYPE from ._utils import ( @@ -107,27 +106,16 @@ def __init__( bind=self.engine, class_=get_session_type(self.engine) ) - def _check_pgmq_ext(self) -> None: - """Check if the pgmq extension exists.""" - self._execute_operation(PGMQOperation.check_pgmq_ext, session=None, commit=True) - - async def _check_pgmq_ext_async(self) -> None: - """Check if the pgmq extension exists (async version).""" - await self._execute_async_operation( - PGMQOperation.check_pgmq_ext_async, session=None, commit=True - ) + async def _check_pg_partman_ext_async(self) -> None: + """Check if the pg_partman extension exists.""" + async with self.session_maker() as session: + await PGMQOperation.check_pg_partman_ext_async(session=session, commit=True) - def _check_pg_partman_ext(self) -> None: + def _check_pg_partman_ext_sync(self) -> None: """Check if the pg_partman extension exists.""" - self._execute_operation( - PGMQOperation.check_pg_partman_ext, session=None, commit=True - ) + with self.session_maker() as session: + PGMQOperation.check_pg_partman_ext(session=session, commit=True) - async def _check_pg_partman_ext_async(self) -> None: - """Check if the pg_partman extension exists (async version).""" - await self._execute_async_operation( - PGMQOperation.check_pg_partman_ext_async, session=None, commit=True - ) def _execute_operation( self, @@ -211,38 +199,6 @@ def create_queue( unlogged, ) - async def create_queue_async( - self, - queue_name: str, - unlogged: bool = False, - *, - session: Optional[SESSION_TYPE] = None, - commit: bool = True, - ) -> None: - """ - .. _unlogged_table: https://www.postgresql.org/docs/current/sql-createtable.html#SQL-CREATETABLE-UNLOGGED - .. |unlogged_table| replace:: **UNLOGGED TABLE** - - **Create a new queue.** - - * if ``unlogged`` is ``True``, the queue will be created as an |unlogged_table|_ . - * ``queue_name`` must be **less than 48 characters**. - - .. code-block:: python - - await pgmq_client.create_queue_async('my_queue') - # or unlogged table queue - await pgmq_client.create_queue_async('my_queue', unlogged=True) - - """ - return await self._execute_async_operation( - PGMQOperation.create_queue_async, - session, - commit, - queue_name, - unlogged, - ) - def create_partitioned_queue( self, queue_name: str, @@ -296,59 +252,6 @@ def create_partitioned_queue( str(retention_interval), ) - async def create_partitioned_queue_async( - self, - queue_name: str, - partition_interval: int = 10000, - retention_interval: int = 100000, - *, - session: Optional[SESSION_TYPE] = None, - commit: bool = True, - ) -> None: - """Create a new **partitioned** queue. - - .. _pgmq_partitioned_queue: https://github.com/tembo-io/pgmq?tab=readme-ov-file#partitioned-queues - .. |pgmq_partitioned_queue| replace:: **PGMQ: Partitioned Queues** - - .. code-block:: python - - # Numeric partitioning (by msg_id) - await pgmq_client.create_partitioned_queue_async('my_partitioned_queue', partition_interval=10000, retention_interval=100000) - - # Time-based partitioning (by enqueued_at) - await pgmq_client.create_partitioned_queue_async('my_time_queue', partition_interval='1 day', retention_interval='7 days') - - Args: - queue_name (str): The name of the queue, should be less than 48 characters. - partition_interval (Union[int, str]): For numeric partitioning, the number of messages per partition. - For time-based partitioning, a PostgreSQL interval string (e.g., '1 day', '1 hour'). - retention_interval (Union[int, str]): For numeric partitioning, messages with msg_id less than max(msg_id) - retention_interval will be dropped. - For time-based partitioning, a PostgreSQL interval string (e.g., '7 days'). - - .. note:: - | Supports both **numeric** (by ``msg_id``) and **time-based** (by ``enqueued_at``) partitioning. - | For time-based partitioning, use interval strings like '1 day', '1 hour', '7 days', etc. - | For numeric partitioning, use integer values. - - .. important:: - | You must make sure that the ``pg_partman`` extension already **installed** in the Postgres. - | ``pgmq-sqlalchemy`` will **auto create** the ``pg_partman`` extension if it does not exist in the Postgres. - | For more details about ``pgmq`` with ``pg_partman``, checkout the |pgmq_partitioned_queue|_. - - - """ - # check if the pg_partman extension exists before creating a partitioned queue at runtime - await self._check_pg_partman_ext_async() - - return await self._execute_async_operation( - PGMQOperation.create_partitioned_queue_async, - session, - commit, - queue_name, - str(partition_interval), - str(retention_interval), - ) - def validate_queue_name( self, queue_name: str, @@ -366,23 +269,6 @@ def validate_queue_name( queue_name, ) - async def validate_queue_name_async( - self, - queue_name: str, - *, - session: Optional[SESSION_TYPE] = None, - commit: bool = True, - ) -> None: - """ - * Will raise an error if the ``queue_name`` is more than 48 characters. - """ - return await self._execute_async_operation( - PGMQOperation.validate_queue_name_async, - session, - commit, - queue_name, - ) - def drop_queue( self, queue: str, @@ -420,43 +306,6 @@ def drop_queue( partitioned, ) - async def drop_queue_async( - self, - queue: str, - partitioned: bool = False, - *, - session: Optional[SESSION_TYPE] = None, - commit: bool = True, - ) -> bool: - """Drop a queue. - - .. _drop_queue_method: ref:`pgmq_sqlalchemy.PGMQueue.drop_queue` - .. |drop_queue_method| replace:: :py:meth:`~pgmq_sqlalchemy.PGMQueue.drop_queue` - - .. code-block:: python - - await pgmq_client.drop_queue_async('my_queue') - # for partitioned queue - await pgmq_client.drop_queue_async('my_partitioned_queue', partitioned=True) - - .. warning:: - | All messages and queue itself will be deleted. (``pgmq.q_`` table) - | **Archived tables** (``pgmq.a_`` table **will be dropped as well. )** - | - | See |archive_method|_ for more details. - """ - # check if the pg_partman extension exists before dropping a partitioned queue at runtime - if partitioned: - await self._check_pg_partman_ext_async() - - return await self._execute_async_operation( - PGMQOperation.drop_queue_async, - session, - commit, - queue, - partitioned, - ) - def list_queues( self, *, @@ -476,25 +325,6 @@ def list_queues( commit, ) - async def list_queues_async( - self, - *, - session: Optional[SESSION_TYPE] = None, - commit: bool = True, - ) -> List[str]: - """List all queues. - - .. code-block:: python - - queue_list = await pgmq_client.list_queues_async() - print(queue_list) - """ - return await self._execute_async_operation( - PGMQOperation.list_queues_async, - session, - commit, - ) - def send( self, queue_name: str, @@ -531,42 +361,6 @@ def send( delay, ) - async def send_async( - self, - queue_name: str, - message: dict, - delay: int = 0, - *, - session: Optional[SESSION_TYPE] = None, - commit: bool = True, - ) -> int: - """Send a message to a queue. - - .. code-block:: python - - msg_id = await pgmq_client.send_async('my_queue', {'key': 'value', 'key2': 'value2'}) - print(msg_id) - - Example with delay: - - .. code-block:: python - - msg_id = await pgmq_client.send_async('my_queue', {'key': 'value', 'key2': 'value2'}, delay=10) - msg = await pgmq_client.read_async('my_queue') - assert msg is None - await asyncio.sleep(10) - msg = await pgmq_client.read_async('my_queue') - assert msg is not None - """ - return await self._execute_async_operation( - PGMQOperation.send_async, - session, - commit, - queue_name, - message, - delay, - ) - def send_batch( self, queue_name: str, @@ -597,36 +391,6 @@ def send_batch( delay, ) - async def send_batch_async( - self, - queue_name: str, - messages: List[dict], - delay: int = 0, - *, - session: Optional[SESSION_TYPE] = None, - commit: bool = True, - ) -> List[int]: - """ - Send a batch of messages to a queue. - - .. code-block:: python - - msgs = [{'key': 'value', 'key2': 'value2'}, {'key': 'value', 'key2': 'value2'}] - msg_ids = await pgmq_client.send_batch_async('my_queue', msgs) - print(msg_ids) - # send with delay - msg_ids = await pgmq_client.send_batch_async('my_queue', msgs, delay=10) - - """ - return await self._execute_async_operation( - PGMQOperation.send_batch_async, - session, - commit, - queue_name, - messages, - delay, - ) - def read( self, queue_name: str, @@ -705,165 +469,50 @@ def read( vt, ) - async def read_async( + def read_batch( self, queue_name: str, + batch_size: int = 1, vt: Optional[int] = None, *, session: Optional[SESSION_TYPE] = None, commit: bool = True, - ) -> Optional[Message]: + ) -> Optional[List[Message]]: """ - .. _for_update_skip_locked: https://www.postgresql.org/docs/current/sql-select.html#SQL-FOR-UPDATE-SHARE - .. |for_update_skip_locked| replace:: **FOR UPDATE SKIP LOCKED** - - .. _read_method: ref:`pgmq_sqlalchemy.PGMQueue.read` - .. |read_method| replace:: :py:meth:`~pgmq_sqlalchemy.PGMQueue.read` - - Read a message from the queue. + | Read a batch of messages from the queue. + | Usage: Returns: - |schema_message_class|_ or ``None`` if the queue is empty. - - .. note:: - | ``PGMQ`` use |for_update_skip_locked|_ lock to make sure **a message is only read by one consumer**. - | See the `pgmq.read `_ function for more details. - | - | For **consumer retries mechanism** (e.g. mark a message as failed after a certain number of retries) can be implemented by using the ``read_ct`` field in the |schema_message_class|_ object. - - - .. important:: - | ``vt`` is the **visibility timeout** in seconds. - | When a message is read from the queue, it will be invisible to other consumers for the duration of the ``vt``. - - Usage: + List of |schema_message_class|_ or ``None`` if the queue is empty. .. code-block:: python from pgmq_sqlalchemy.schema import Message - msg:Message = await pgmq_client.read_async('my_queue') - print(msg.msg_id) - print(msg.message) - print(msg.read_ct) # read count, how many times the message has been read - - Example with ``vt``: - - .. code-block:: python - - # assert `read_vt_demo` is empty - await pgmq_client.send_async('read_vt_demo', {'key': 'value', 'key2': 'value2'}) - msg = await pgmq_client.read_async('read_vt_demo', vt=10) - assert msg is not None - - # try to read immediately - msg = await pgmq_client.read_async('read_vt_demo') - assert msg is None # will return None because the message is still invisible - - # try to read after 5 seconds - await asyncio.sleep(5) - msg = await pgmq_client.read_async('read_vt_demo') - assert msg is None # still invisible after 5 seconds - - # try to read after 11 seconds - await asyncio.sleep(6) - msg = await pgmq_client.read_async('read_vt_demo') - assert msg is not None # the message is visible after 10 seconds - + msgs:List[Message] = pgmq_client.read_batch('my_queue', batch_size=10) + # with vt + msgs:List[Message] = pgmq_client.read_batch('my_queue', batch_size=10, vt=10) """ if vt is None: vt = self.vt - return await self._execute_async_operation( - PGMQOperation.read_async, + return self._execute_operation( + PGMQOperation.read_batch, session, commit, queue_name, vt, + batch_size, ) - def read_batch( + def read_with_poll( self, queue_name: str, - batch_size: int = 1, vt: Optional[int] = None, - *, - session: Optional[SESSION_TYPE] = None, - commit: bool = True, - ) -> Optional[List[Message]]: - """ - | Read a batch of messages from the queue. - | Usage: - - Returns: - List of |schema_message_class|_ or ``None`` if the queue is empty. - - .. code-block:: python - - from pgmq_sqlalchemy.schema import Message - - msgs:List[Message] = pgmq_client.read_batch('my_queue', batch_size=10) - # with vt - msgs:List[Message] = pgmq_client.read_batch('my_queue', batch_size=10, vt=10) - - """ - if vt is None: - vt = self.vt - - return self._execute_operation( - PGMQOperation.read_batch, - session, - commit, - queue_name, - vt, - batch_size, - ) - - async def read_batch_async( - self, - queue_name: str, - batch_size: int = 1, - vt: Optional[int] = None, - *, - session: Optional[SESSION_TYPE] = None, - commit: bool = True, - ) -> Optional[List[Message]]: - """ - | Read a batch of messages from the queue. - | Usage: - - Returns: - List of |schema_message_class|_ or ``None`` if the queue is empty. - - .. code-block:: python - - from pgmq_sqlalchemy.schema import Message - - msgs:List[Message] = await pgmq_client.read_batch_async('my_queue', batch_size=10) - # with vt - msgs:List[Message] = await pgmq_client.read_batch_async('my_queue', batch_size=10, vt=10) - - """ - if vt is None: - vt = self.vt - - return await self._execute_async_operation( - PGMQOperation.read_batch_async, - session, - commit, - queue_name, - vt, - batch_size, - ) - - def read_with_poll( - self, - queue_name: str, - vt: Optional[int] = None, - qty: int = 1, - max_poll_seconds: int = 5, - poll_interval_ms: int = 100, + qty: int = 1, + max_poll_seconds: int = 5, + poll_interval_ms: int = 100, *, session: Optional[SESSION_TYPE] = None, commit: bool = True, @@ -930,79 +579,6 @@ def read_with_poll( poll_interval_ms, ) - async def read_with_poll_async( - self, - queue_name: str, - vt: Optional[int] = None, - qty: int = 1, - max_poll_seconds: int = 5, - poll_interval_ms: int = 100, - *, - session: Optional[SESSION_TYPE] = None, - commit: bool = True, - ) -> Optional[List[Message]]: - """ - - .. _read_with_poll_method: ref:`pgmq_sqlalchemy.PGMQueue.read_with_poll` - .. |read_with_poll_method| replace:: :py:meth:`~pgmq_sqlalchemy.PGMQueue.read_with_poll` - - - | Read messages from a queue with long-polling. - | - | When the queue is empty, the function block at most ``max_poll_seconds`` seconds. - | During the polling, the function will check the queue every ``poll_interval_ms`` milliseconds, until the queue has ``qty`` messages. - - Args: - queue_name (str): The name of the queue. - vt (Optional[int]): The visibility timeout in seconds. - qty (int): The number of messages to read. - max_poll_seconds (int): The maximum number of seconds to poll. - poll_interval_ms (int): The interval in milliseconds to poll. - - Returns: - List of |schema_message_class|_ or ``None`` if the queue is empty. - - Usage: - - .. code-block:: python - - msg_id = await pgmq_client.send_async('my_queue', {'key': 'value'}, delay=6) - - # the following code will block for 5 seconds - msgs = await pgmq_client.read_with_poll_async('my_queue', qty=1, max_poll_seconds=5, poll_interval_ms=100) - assert msgs is None - - # try read_with_poll again - # the following code will only block for 1 second - msgs = await pgmq_client.read_with_poll_async('my_queue', qty=1, max_poll_seconds=5, poll_interval_ms=100) - assert msgs is not None - - Another example: - - .. code-block:: python - - msg = {'key': 'value'} - msg_ids = await pgmq_client.send_batch_async('my_queue', [msg, msg, msg, msg], delay=3) - - # the following code will block for 3 seconds - msgs = await pgmq_client.read_with_poll_async('my_queue', qty=3, max_poll_seconds=5, poll_interval_ms=100) - assert len(msgs) == 3 # will read at most 3 messages (qty=3) - - """ - if vt is None: - vt = self.vt - - return await self._execute_async_operation( - PGMQOperation.read_with_poll_async, - session, - commit, - queue_name, - vt, - qty, - max_poll_seconds, - poll_interval_ms, - ) - def set_vt( self, queue_name: str, @@ -1071,81 +647,7 @@ def consumer_with_backoff_retry(pgmq_client: PGMQueue, queue_name: str): return self._execute_operation( PGMQOperation.set_vt, - session, - commit, - queue_name, - msg_id, - vt, - ) - - async def set_vt_async( - self, - queue_name: str, - msg_id: int, - vt: int, - *, - session: Optional[SESSION_TYPE] = None, - commit: bool = True, - ) -> Optional[Message]: - """ - .. _set_vt_method: ref:`pgmq_sqlalchemy.PGMQueue.set_vt` - .. |set_vt_method| replace:: :py:meth:`~pgmq_sqlalchemy.PGMQueue.set_vt` - - Set the visibility timeout for a message. - - Args: - queue_name (str): The name of the queue. - msg_id (int): The message id. - vt (int): The visibility timeout in seconds. - - Returns: - |schema_message_class|_ or ``None`` if the message does not exist. - - Usage: - - .. code-block:: python - - msg_id = await pgmq_client.send_async('my_queue', {'key': 'value'}, delay=10) - msg = await pgmq_client.read_async('my_queue') - assert msg is not None - msg = await pgmq_client.set_vt_async('my_queue', msg.msg_id, 10) - assert msg is not None - - .. tip:: - | |read_method|_ and |set_vt_method|_ can be used together to implement **exponential backoff** mechanism. - | `ref: Exponential Backoff And Jitter `_. - | **For example:** - - .. code-block:: python - - from pgmq_sqlalchemy import PGMQueue - from pgmq_sqlalchemy.schema import Message - - def _exp_backoff_retry(msg: Message)->int: - # exponential backoff retry - if msg.read_ct < 5: - return 2 ** msg.read_ct - return 2 ** 5 - - def consumer_with_backoff_retry(pgmq_client: PGMQueue, queue_name: str): - msg = await pgmq_client.read_async( - queue_name=queue_name, - vt=1000, # set vt to 1000 seconds temporarily - ) - if msg is None: - return - - # set exponential backoff retry - await pgmq_client.set_vt_async( - queue_name=query_name, - msg_id=msg.msg_id, - vt=_exp_backoff_retry(msg) - ) - """ - - return await self._execute_async_operation( - PGMQOperation.set_vt_async, session, commit, queue_name, @@ -1177,30 +679,6 @@ def pop( queue_name, ) - async def pop_async( - self, - queue_name: str, - *, - session: Optional[SESSION_TYPE] = None, - commit: bool = True, - ) -> Optional[Message]: - """ - Reads a single message from a queue and deletes it upon read. - - .. code-block:: python - - msg = await pgmq_client.pop_async('my_queue') - print(msg.msg_id) - print(msg.message) - - """ - return await self._execute_async_operation( - PGMQOperation.pop_async, - session, - commit, - queue_name, - ) - def delete( self, queue_name: str, @@ -1234,39 +712,6 @@ def delete( msg_id, ) - async def delete_async( - self, - queue_name: str, - msg_id: int, - *, - session: Optional[SESSION_TYPE] = None, - commit: bool = True, - ) -> bool: - """ - Delete a message from the queue. - - .. _delete_method: ref:`pgmq_sqlalchemy.PGMQueue.delete` - .. |delete_method| replace:: :py:meth:`~pgmq_sqlalchemy.PGMQueue.delete` - - * Raises an error if the ``queue_name`` does not exist. - * Returns ``True`` if the message is deleted successfully. - * If the message does not exist, returns ``False``. - - .. code-block:: python - - msg_id = await pgmq_client.send_async('my_queue', {'key': 'value'}) - assert await pgmq_client.delete_async('my_queue', msg_id) - assert not await pgmq_client.delete_async('my_queue', msg_id) - - """ - return await self._execute_async_operation( - PGMQOperation.delete_async, - session, - commit, - queue_name, - msg_id, - ) - def delete_batch( self, queue_name: str, @@ -1299,38 +744,6 @@ def delete_batch( msg_ids, ) - async def delete_batch_async( - self, - queue_name: str, - msg_ids: List[int], - *, - session: Optional[SESSION_TYPE] = None, - commit: bool = True, - ) -> List[int]: - """ - Delete a batch of messages from the queue. - - .. _delete_batch_method: ref:`pgmq_sqlalchemy.PGMQueue.delete_batch` - .. |delete_batch_method| replace:: :py:meth:`~pgmq_sqlalchemy.PGMQueue.delete_batch` - - .. note:: - | Instead of return `bool` like |delete_method|_, - | |delete_batch_method|_ will return a list of ``msg_id`` that are successfully deleted. - - .. code-block:: python - - msg_ids = await pgmq_client.send_batch_async('my_queue', [{'key': 'value'}, {'key': 'value'}]) - assert await pgmq_client.delete_batch_async('my_queue', msg_ids) == msg_ids - - """ - return await self._execute_async_operation( - PGMQOperation.delete_batch_async, - session, - commit, - queue_name, - msg_ids, - ) - def archive( self, queue_name: str, @@ -1367,42 +780,6 @@ def archive( msg_id, ) - async def archive_async( - self, - queue_name: str, - msg_id: int, - *, - session: Optional[SESSION_TYPE] = None, - commit: bool = True, - ) -> bool: - """ - Archive a message from a queue. - - .. _archive_method: ref:`pgmq_sqlalchemy.PGMQueue.archive` - .. |archive_method| replace:: :py:meth:`~pgmq_sqlalchemy.PGMQueue.archive` - - - * Message will be deleted from the queue and moved to the archive table. - * Will be deleted from ``pgmq.q_`` and be inserted into the ``pgmq.a_`` table. - * raises an error if the ``queue_name`` does not exist. - * returns ``True`` if the message is archived successfully. - - .. code-block:: python - - msg_id = await pgmq_client.send_async('my_queue', {'key': 'value'}) - assert await pgmq_client.archive_async('my_queue', msg_id) - # since the message is archived, queue will be empty - assert await pgmq_client.read_async('my_queue') is None - - """ - return await self._execute_async_operation( - PGMQOperation.archive_async, - session, - commit, - queue_name, - msg_id, - ) - def archive_batch( self, queue_name: str, @@ -1432,35 +809,6 @@ def archive_batch( msg_ids, ) - async def archive_batch_async( - self, - queue_name: str, - msg_ids: List[int], - *, - session: Optional[SESSION_TYPE] = None, - commit: bool = True, - ) -> List[int]: - """ - Archive multiple messages from a queue. - - * Messages will be deleted from the queue and moved to the archive table. - * Returns a list of ``msg_id`` that are successfully archived. - - .. code-block:: python - - msg_ids = await pgmq_client.send_batch_async('my_queue', [{'key': 'value'}, {'key': 'value'}]) - assert await pgmq_client.archive_batch_async('my_queue', msg_ids) == msg_ids - assert await pgmq_client.read_async('my_queue') is None - - """ - return await self._execute_async_operation( - PGMQOperation.archive_batch_async, - session, - commit, - queue_name, - msg_ids, - ) - def purge( self, queue_name: str, @@ -1486,31 +834,6 @@ def purge( queue_name, ) - async def purge_async( - self, - queue_name: str, - *, - session: Optional[SESSION_TYPE] = None, - commit: bool = True, - ) -> int: - """ - * Delete all messages from a queue, return the number of messages deleted. - * Archive tables will **not** be affected. - - .. code-block:: python - - msg_ids = await pgmq_client.send_batch_async('my_queue', [{'key': 'value'}, {'key': 'value'}]) - assert await pgmq_client.purge_async('my_queue') == 2 - assert await pgmq_client.read_async('my_queue') is None - - """ - return await self._execute_async_operation( - PGMQOperation.purge_async, - session, - commit, - queue_name, - ) - def metrics( self, queue_name: str, @@ -1543,38 +866,6 @@ def metrics( queue_name, ) - async def metrics_async( - self, - queue_name: str, - *, - session: Optional[SESSION_TYPE] = None, - commit: bool = True, - ) -> Optional[QueueMetrics]: - """ - Get metrics for a queue. - - Returns: - |schema_queue_metrics_class|_ or ``None`` if the queue does not exist. - - Usage: - - .. code-block:: python - - from pgmq_sqlalchemy.schema import QueueMetrics - - metrics:QueueMetrics = await pgmq_client.metrics_async('my_queue') - print(metrics.queue_name) - print(metrics.queue_length) - print(metrics.queue_length) - - """ - return await self._execute_async_operation( - PGMQOperation.metrics_async, - session, - commit, - queue_name, - ) - def metrics_all( self, *, @@ -1619,48 +910,3 @@ def metrics_all( session, commit, ) - - async def metrics_all_async( - self, - *, - session: Optional[SESSION_TYPE] = None, - commit: bool = True, - ) -> Optional[List[QueueMetrics]]: - """ - - .. _read_committed_isolation_level: https://www.postgresql.org/docs/current/transaction-iso.html#XACT-READ-COMMITTED - .. |read_committed_isolation_level| replace:: **READ COMMITTED** - - .. _metrics_all_method: ref:`pgmq_sqlalchemy.PGMQueue.metrics_all` - .. |metrics_all_method| replace:: :py:meth:`~pgmq_sqlalchemy.PGMQueue.metrics_all` - - Get metrics for all queues. - - Returns: - List of |schema_queue_metrics_class|_ or ``None`` if there are no queues. - - Usage: - - .. code-block:: python - - from pgmq_sqlalchemy.schema import QueueMetrics - - metrics:List[QueueMetrics] = await pgmq_client.metrics_all_async() - for m in metrics: - print(m.queue_name) - print(m.queue_length) - print(m.queue_length) - - .. warning:: - | You should use a **distributed lock** to avoid **race conditions** when calling |metrics_all_method|_ in **concurrent** |drop_queue_method|_ **scenarios**. - | - | Since the default PostgreSQL isolation level is |read_committed_isolation_level|_, the queue metrics to be fetched **may not exist** if there are **concurrent** |drop_queue_method|_ **operations**. - | Check the `pgmq.metrics_all `_ function for more details. - - - """ - return await self._execute_async_operation( - PGMQOperation.metrics_all_async, - session, - commit, - ) From a507e1d079a14b55447985ca7f61018ce9750d43 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Mon, 5 Jan 2026 19:44:49 +0800 Subject: [PATCH 2/7] Refactor code transformation scripts from ast to libcst (#35) * Initial plan * Refactor AST-based code to use libcst for better code transformation Co-authored-by: jason810496 <68415893+jason810496@users.noreply.github.com> * Fix: only wrap call expressions in await, not literals Co-authored-by: jason810496 <68415893+jason810496@users.noreply.github.com> * Ask whether to apply change * Apply missing async methods * Correct files for check-sync-async-method-for-queue pre-commit * Add check for operation as well --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: jason810496 <68415893+jason810496@users.noreply.github.com> Co-authored-by: LIU ZHE YOU --- pgmq_sqlalchemy/queue.py | 783 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 763 insertions(+), 20 deletions(-) diff --git a/pgmq_sqlalchemy/queue.py b/pgmq_sqlalchemy/queue.py index fea65d3..74de108 100644 --- a/pgmq_sqlalchemy/queue.py +++ b/pgmq_sqlalchemy/queue.py @@ -199,6 +199,38 @@ def create_queue( unlogged, ) + async def create_queue_async( + self, + queue_name: str, + unlogged: bool = False, + *, + session: Optional[SESSION_TYPE] = None, + commit: bool = True, + ) -> None: + """ + .. _unlogged_table: https://www.postgresql.org/docs/current/sql-createtable.html#SQL-CREATETABLE-UNLOGGED + .. |unlogged_table| replace:: **UNLOGGED TABLE** + + **Create a new queue.** + + * if ``unlogged`` is ``True``, the queue will be created as an |unlogged_table|_ . + * ``queue_name`` must be **less than 48 characters**. + + .. code-block:: python + + await pgmq_client.create_queue_async('my_queue') + # or unlogged table queue + await pgmq_client.create_queue_async('my_queue', unlogged=True) + + """ + return await self._execute_async_operation( + PGMQOperation.create_queue_async, + session, + commit, + queue_name, + unlogged, + ) + def create_partitioned_queue( self, queue_name: str, @@ -252,6 +284,59 @@ def create_partitioned_queue( str(retention_interval), ) + async def create_partitioned_queue_async( + self, + queue_name: str, + partition_interval: int = 10000, + retention_interval: int = 100000, + *, + session: Optional[SESSION_TYPE] = None, + commit: bool = True, + ) -> None: + """Create a new **partitioned** queue. + + .. _pgmq_partitioned_queue: https://github.com/tembo-io/pgmq?tab=readme-ov-file#partitioned-queues + .. |pgmq_partitioned_queue| replace:: **PGMQ: Partitioned Queues** + + .. code-block:: python + + # Numeric partitioning (by msg_id) + await pgmq_client.create_partitioned_queue_async('my_partitioned_queue', partition_interval=10000, retention_interval=100000) + + # Time-based partitioning (by enqueued_at) + await pgmq_client.create_partitioned_queue_async('my_time_queue', partition_interval='1 day', retention_interval='7 days') + + Args: + queue_name (str): The name of the queue, should be less than 48 characters. + partition_interval (Union[int, str]): For numeric partitioning, the number of messages per partition. + For time-based partitioning, a PostgreSQL interval string (e.g., '1 day', '1 hour'). + retention_interval (Union[int, str]): For numeric partitioning, messages with msg_id less than max(msg_id) - retention_interval will be dropped. + For time-based partitioning, a PostgreSQL interval string (e.g., '7 days'). + + .. note:: + | Supports both **numeric** (by ``msg_id``) and **time-based** (by ``enqueued_at``) partitioning. + | For time-based partitioning, use interval strings like '1 day', '1 hour', '7 days', etc. + | For numeric partitioning, use integer values. + + .. important:: + | You must make sure that the ``pg_partman`` extension already **installed** in the Postgres. + | ``pgmq-sqlalchemy`` will **auto create** the ``pg_partman`` extension if it does not exist in the Postgres. + | For more details about ``pgmq`` with ``pg_partman``, checkout the |pgmq_partitioned_queue|_. + + + """ + # check if the pg_partman extension exists before creating a partitioned queue at runtime + self._check_pg_partman_ext() + + return await self._execute_async_operation( + PGMQOperation.create_partitioned_queue_async, + session, + commit, + queue_name, + str(partition_interval), + str(retention_interval), + ) + def validate_queue_name( self, queue_name: str, @@ -269,6 +354,23 @@ def validate_queue_name( queue_name, ) + async def validate_queue_name_async( + self, + queue_name: str, + *, + session: Optional[SESSION_TYPE] = None, + commit: bool = True, + ) -> None: + """ + * Will raise an error if the ``queue_name`` is more than 48 characters. + """ + return await self._execute_async_operation( + PGMQOperation.validate_queue_name_async, + session, + commit, + queue_name, + ) + def drop_queue( self, queue: str, @@ -306,6 +408,43 @@ def drop_queue( partitioned, ) + async def drop_queue_async( + self, + queue: str, + partitioned: bool = False, + *, + session: Optional[SESSION_TYPE] = None, + commit: bool = True, + ) -> bool: + """Drop a queue. + + .. _drop_queue_method: ref:`pgmq_sqlalchemy.PGMQueue.drop_queue` + .. |drop_queue_method| replace:: :py:meth:`~pgmq_sqlalchemy.PGMQueue.drop_queue` + + .. code-block:: python + + await pgmq_client.drop_queue_async('my_queue') + # for partitioned queue + await pgmq_client.drop_queue_async('my_partitioned_queue', partitioned=True) + + .. warning:: + | All messages and queue itself will be deleted. (``pgmq.q_`` table) + | **Archived tables** (``pgmq.a_`` table **will be dropped as well. )** + | + | See |archive_method|_ for more details. + """ + # check if the pg_partman extension exists before dropping a partitioned queue at runtime + if partitioned: + self._check_pg_partman_ext() + + return await self._execute_async_operation( + PGMQOperation.drop_queue_async, + session, + commit, + queue, + partitioned, + ) + def list_queues( self, *, @@ -325,6 +464,25 @@ def list_queues( commit, ) + async def list_queues_async( + self, + *, + session: Optional[SESSION_TYPE] = None, + commit: bool = True, + ) -> List[str]: + """List all queues. + + .. code-block:: python + + queue_list = await pgmq_client.list_queues_async() + print(queue_list) + """ + return await self._execute_async_operation( + PGMQOperation.list_queues_async, + session, + commit, + ) + def send( self, queue_name: str, @@ -361,6 +519,42 @@ def send( delay, ) + async def send_async( + self, + queue_name: str, + message: dict, + delay: int = 0, + *, + session: Optional[SESSION_TYPE] = None, + commit: bool = True, + ) -> int: + """Send a message to a queue. + + .. code-block:: python + + msg_id = await pgmq_client.send_async('my_queue', {'key': 'value', 'key2': 'value2'}) + print(msg_id) + + Example with delay: + + .. code-block:: python + + msg_id = await pgmq_client.send_async('my_queue', {'key': 'value', 'key2': 'value2'}, delay=10) + msg = await pgmq_client.read_async('my_queue') + assert msg is None + await asyncio.sleep(10) + msg = await pgmq_client.read_async('my_queue') + assert msg is not None + """ + return await self._execute_async_operation( + PGMQOperation.send_async, + session, + commit, + queue_name, + message, + delay, + ) + def send_batch( self, queue_name: str, @@ -391,6 +585,36 @@ def send_batch( delay, ) + async def send_batch_async( + self, + queue_name: str, + messages: List[dict], + delay: int = 0, + *, + session: Optional[SESSION_TYPE] = None, + commit: bool = True, + ) -> List[int]: + """ + Send a batch of messages to a queue. + + .. code-block:: python + + msgs = [{'key': 'value', 'key2': 'value2'}, {'key': 'value', 'key2': 'value2'}] + msg_ids = await pgmq_client.send_batch_async('my_queue', msgs) + print(msg_ids) + # send with delay + msg_ids = await pgmq_client.send_batch_async('my_queue', msgs, delay=10) + + """ + return await self._execute_async_operation( + PGMQOperation.send_batch_async, + session, + commit, + queue_name, + messages, + delay, + ) + def read( self, queue_name: str, @@ -469,6 +693,84 @@ def read( vt, ) + async def read_async( + self, + queue_name: str, + vt: Optional[int] = None, + *, + session: Optional[SESSION_TYPE] = None, + commit: bool = True, + ) -> Optional[Message]: + """ + .. _for_update_skip_locked: https://www.postgresql.org/docs/current/sql-select.html#SQL-FOR-UPDATE-SHARE + .. |for_update_skip_locked| replace:: **FOR UPDATE SKIP LOCKED** + + .. _read_method: ref:`pgmq_sqlalchemy.PGMQueue.read` + .. |read_method| replace:: :py:meth:`~pgmq_sqlalchemy.PGMQueue.read` + + Read a message from the queue. + + Returns: + |schema_message_class|_ or ``None`` if the queue is empty. + + .. note:: + | ``PGMQ`` use |for_update_skip_locked|_ lock to make sure **a message is only read by one consumer**. + | See the `pgmq.read `_ function for more details. + | + | For **consumer retries mechanism** (e.g. mark a message as failed after a certain number of retries) can be implemented by using the ``read_ct`` field in the |schema_message_class|_ object. + + + .. important:: + | ``vt`` is the **visibility timeout** in seconds. + | When a message is read from the queue, it will be invisible to other consumers for the duration of the ``vt``. + + Usage: + + .. code-block:: python + + from pgmq_sqlalchemy.schema import Message + + msg:Message = await pgmq_client.read_async('my_queue') + print(msg.msg_id) + print(msg.message) + print(msg.read_ct) # read count, how many times the message has been read + + Example with ``vt``: + + .. code-block:: python + + # assert `read_vt_demo` is empty + await pgmq_client.send_async('read_vt_demo', {'key': 'value', 'key2': 'value2'}) + msg = await pgmq_client.read_async('read_vt_demo', vt=10) + assert msg is not None + + # try to read immediately + msg = await pgmq_client.read_async('read_vt_demo') + assert msg is None # will return None because the message is still invisible + + # try to read after 5 seconds + await asyncio.sleep(5) + msg = await pgmq_client.read_async('read_vt_demo') + assert msg is None # still invisible after 5 seconds + + # try to read after 11 seconds + await asyncio.sleep(6) + msg = await pgmq_client.read_async('read_vt_demo') + assert msg is not None # the message is visible after 10 seconds + + + """ + if vt is None: + vt = self.vt + + return await self._execute_async_operation( + PGMQOperation.read_async, + session, + commit, + queue_name, + vt, + ) + def read_batch( self, queue_name: str, @@ -506,6 +808,43 @@ def read_batch( batch_size, ) + async def read_batch_async( + self, + queue_name: str, + batch_size: int = 1, + vt: Optional[int] = None, + *, + session: Optional[SESSION_TYPE] = None, + commit: bool = True, + ) -> Optional[List[Message]]: + """ + | Read a batch of messages from the queue. + | Usage: + + Returns: + List of |schema_message_class|_ or ``None`` if the queue is empty. + + .. code-block:: python + + from pgmq_sqlalchemy.schema import Message + + msgs:List[Message] = await pgmq_client.read_batch_async('my_queue', batch_size=10) + # with vt + msgs:List[Message] = await pgmq_client.read_batch_async('my_queue', batch_size=10, vt=10) + + """ + if vt is None: + vt = self.vt + + return await self._execute_async_operation( + PGMQOperation.read_batch_async, + session, + commit, + queue_name, + vt, + batch_size, + ) + def read_with_poll( self, queue_name: str, @@ -557,29 +896,178 @@ def read_with_poll( .. code-block:: python - msg = {'key': 'value'} - msg_ids = pgmq_client.send_batch('my_queue', [msg, msg, msg, msg], delay=3) + msg = {'key': 'value'} + msg_ids = pgmq_client.send_batch('my_queue', [msg, msg, msg, msg], delay=3) + + # the following code will block for 3 seconds + msgs = pgmq_client.read_with_poll('my_queue', qty=3, max_poll_seconds=5, poll_interval_ms=100) + assert len(msgs) == 3 # will read at most 3 messages (qty=3) + + """ + if vt is None: + vt = self.vt + + return self._execute_operation( + PGMQOperation.read_with_poll, + session, + commit, + queue_name, + vt, + qty, + max_poll_seconds, + poll_interval_ms, + ) + + async def read_with_poll_async( + self, + queue_name: str, + vt: Optional[int] = None, + qty: int = 1, + max_poll_seconds: int = 5, + poll_interval_ms: int = 100, + *, + session: Optional[SESSION_TYPE] = None, + commit: bool = True, + ) -> Optional[List[Message]]: + """ + + .. _read_with_poll_method: ref:`pgmq_sqlalchemy.PGMQueue.read_with_poll` + .. |read_with_poll_method| replace:: :py:meth:`~pgmq_sqlalchemy.PGMQueue.read_with_poll` + + + | Read messages from a queue with long-polling. + | + | When the queue is empty, the function block at most ``max_poll_seconds`` seconds. + | During the polling, the function will check the queue every ``poll_interval_ms`` milliseconds, until the queue has ``qty`` messages. + + Args: + queue_name (str): The name of the queue. + vt (Optional[int]): The visibility timeout in seconds. + qty (int): The number of messages to read. + max_poll_seconds (int): The maximum number of seconds to poll. + poll_interval_ms (int): The interval in milliseconds to poll. + + Returns: + List of |schema_message_class|_ or ``None`` if the queue is empty. + + Usage: + + .. code-block:: python + + msg_id = await pgmq_client.send_async('my_queue', {'key': 'value'}, delay=6) + + # the following code will block for 5 seconds + msgs = await pgmq_client.read_with_poll_async('my_queue', qty=1, max_poll_seconds=5, poll_interval_ms=100) + assert msgs is None + + # try read_with_poll again + # the following code will only block for 1 second + msgs = await pgmq_client.read_with_poll_async('my_queue', qty=1, max_poll_seconds=5, poll_interval_ms=100) + assert msgs is not None + + Another example: + + .. code-block:: python + + msg = {'key': 'value'} + msg_ids = await pgmq_client.send_batch_async('my_queue', [msg, msg, msg, msg], delay=3) + + # the following code will block for 3 seconds + msgs = await pgmq_client.read_with_poll_async('my_queue', qty=3, max_poll_seconds=5, poll_interval_ms=100) + assert len(msgs) == 3 # will read at most 3 messages (qty=3) + + """ + if vt is None: + vt = self.vt + + return await self._execute_async_operation( + PGMQOperation.read_with_poll_async, + session, + commit, + queue_name, + vt, + qty, + max_poll_seconds, + poll_interval_ms, + ) + + def set_vt( + self, + queue_name: str, + msg_id: int, + vt: int, + *, + session: Optional[SESSION_TYPE] = None, + commit: bool = True, + ) -> Optional[Message]: + """ + .. _set_vt_method: ref:`pgmq_sqlalchemy.PGMQueue.set_vt` + .. |set_vt_method| replace:: :py:meth:`~pgmq_sqlalchemy.PGMQueue.set_vt` + + Set the visibility timeout for a message. + + Args: + queue_name (str): The name of the queue. + msg_id (int): The message id. + vt (int): The visibility timeout in seconds. + + Returns: + |schema_message_class|_ or ``None`` if the message does not exist. + + Usage: + + .. code-block:: python + + msg_id = pgmq_client.send('my_queue', {'key': 'value'}, delay=10) + msg = pgmq_client.read('my_queue') + assert msg is not None + msg = pgmq_client.set_vt('my_queue', msg.msg_id, 10) + assert msg is not None + + .. tip:: + | |read_method|_ and |set_vt_method|_ can be used together to implement **exponential backoff** mechanism. + | `ref: Exponential Backoff And Jitter `_. + | **For example:** + + .. code-block:: python + + from pgmq_sqlalchemy import PGMQueue + from pgmq_sqlalchemy.schema import Message + + def _exp_backoff_retry(msg: Message)->int: + # exponential backoff retry + if msg.read_ct < 5: + return 2 ** msg.read_ct + return 2 ** 5 + + def consumer_with_backoff_retry(pgmq_client: PGMQueue, queue_name: str): + msg = pgmq_client.read( + queue_name=queue_name, + vt=1000, # set vt to 1000 seconds temporarily + ) + if msg is None: + return - # the following code will block for 3 seconds - msgs = pgmq_client.read_with_poll('my_queue', qty=3, max_poll_seconds=5, poll_interval_ms=100) - assert len(msgs) == 3 # will read at most 3 messages (qty=3) + # set exponential backoff retry + pgmq_client.set_vt( + queue_name=query_name, + msg_id=msg.msg_id, + vt=_exp_backoff_retry(msg) + ) """ - if vt is None: - vt = self.vt return self._execute_operation( - PGMQOperation.read_with_poll, + PGMQOperation.set_vt, + session, commit, queue_name, + msg_id, vt, - qty, - max_poll_seconds, - poll_interval_ms, ) - def set_vt( + async def set_vt_async( self, queue_name: str, msg_id: int, @@ -606,10 +1094,10 @@ def set_vt( .. code-block:: python - msg_id = pgmq_client.send('my_queue', {'key': 'value'}, delay=10) - msg = pgmq_client.read('my_queue') + msg_id = await pgmq_client.send_async('my_queue', {'key': 'value'}, delay=10) + msg = await pgmq_client.read_async('my_queue') assert msg is not None - msg = pgmq_client.set_vt('my_queue', msg.msg_id, 10) + msg = await pgmq_client.set_vt_async('my_queue', msg.msg_id, 10) assert msg is not None .. tip:: @@ -629,7 +1117,7 @@ def _exp_backoff_retry(msg: Message)->int: return 2 ** 5 def consumer_with_backoff_retry(pgmq_client: PGMQueue, queue_name: str): - msg = pgmq_client.read( + msg = await pgmq_client.read_async( queue_name=queue_name, vt=1000, # set vt to 1000 seconds temporarily ) @@ -637,7 +1125,7 @@ def consumer_with_backoff_retry(pgmq_client: PGMQueue, queue_name: str): return # set exponential backoff retry - pgmq_client.set_vt( + await pgmq_client.set_vt_async( queue_name=query_name, msg_id=msg.msg_id, vt=_exp_backoff_retry(msg) @@ -645,9 +1133,8 @@ def consumer_with_backoff_retry(pgmq_client: PGMQueue, queue_name: str): """ - return self._execute_operation( - PGMQOperation.set_vt, - + return await self._execute_async_operation( + PGMQOperation.set_vt_async, session, commit, queue_name, @@ -679,6 +1166,30 @@ def pop( queue_name, ) + async def pop_async( + self, + queue_name: str, + *, + session: Optional[SESSION_TYPE] = None, + commit: bool = True, + ) -> Optional[Message]: + """ + Reads a single message from a queue and deletes it upon read. + + .. code-block:: python + + msg = await pgmq_client.pop_async('my_queue') + print(msg.msg_id) + print(msg.message) + + """ + return await self._execute_async_operation( + PGMQOperation.pop_async, + session, + commit, + queue_name, + ) + def delete( self, queue_name: str, @@ -712,6 +1223,39 @@ def delete( msg_id, ) + async def delete_async( + self, + queue_name: str, + msg_id: int, + *, + session: Optional[SESSION_TYPE] = None, + commit: bool = True, + ) -> bool: + """ + Delete a message from the queue. + + .. _delete_method: ref:`pgmq_sqlalchemy.PGMQueue.delete` + .. |delete_method| replace:: :py:meth:`~pgmq_sqlalchemy.PGMQueue.delete` + + * Raises an error if the ``queue_name`` does not exist. + * Returns ``True`` if the message is deleted successfully. + * If the message does not exist, returns ``False``. + + .. code-block:: python + + msg_id = await pgmq_client.send_async('my_queue', {'key': 'value'}) + assert await pgmq_client.delete_async('my_queue', msg_id) + assert not await pgmq_client.delete_async('my_queue', msg_id) + + """ + return await self._execute_async_operation( + PGMQOperation.delete_async, + session, + commit, + queue_name, + msg_id, + ) + def delete_batch( self, queue_name: str, @@ -744,6 +1288,38 @@ def delete_batch( msg_ids, ) + async def delete_batch_async( + self, + queue_name: str, + msg_ids: List[int], + *, + session: Optional[SESSION_TYPE] = None, + commit: bool = True, + ) -> List[int]: + """ + Delete a batch of messages from the queue. + + .. _delete_batch_method: ref:`pgmq_sqlalchemy.PGMQueue.delete_batch` + .. |delete_batch_method| replace:: :py:meth:`~pgmq_sqlalchemy.PGMQueue.delete_batch` + + .. note:: + | Instead of return `bool` like |delete_method|_, + | |delete_batch_method|_ will return a list of ``msg_id`` that are successfully deleted. + + .. code-block:: python + + msg_ids = await pgmq_client.send_batch_async('my_queue', [{'key': 'value'}, {'key': 'value'}]) + assert await pgmq_client.delete_batch_async('my_queue', msg_ids) == msg_ids + + """ + return await self._execute_async_operation( + PGMQOperation.delete_batch_async, + session, + commit, + queue_name, + msg_ids, + ) + def archive( self, queue_name: str, @@ -780,6 +1356,42 @@ def archive( msg_id, ) + async def archive_async( + self, + queue_name: str, + msg_id: int, + *, + session: Optional[SESSION_TYPE] = None, + commit: bool = True, + ) -> bool: + """ + Archive a message from a queue. + + .. _archive_method: ref:`pgmq_sqlalchemy.PGMQueue.archive` + .. |archive_method| replace:: :py:meth:`~pgmq_sqlalchemy.PGMQueue.archive` + + + * Message will be deleted from the queue and moved to the archive table. + * Will be deleted from ``pgmq.q_`` and be inserted into the ``pgmq.a_`` table. + * raises an error if the ``queue_name`` does not exist. + * returns ``True`` if the message is archived successfully. + + .. code-block:: python + + msg_id = await pgmq_client.send_async('my_queue', {'key': 'value'}) + assert await pgmq_client.archive_async('my_queue', msg_id) + # since the message is archived, queue will be empty + assert await pgmq_client.read_async('my_queue') is None + + """ + return await self._execute_async_operation( + PGMQOperation.archive_async, + session, + commit, + queue_name, + msg_id, + ) + def archive_batch( self, queue_name: str, @@ -809,6 +1421,35 @@ def archive_batch( msg_ids, ) + async def archive_batch_async( + self, + queue_name: str, + msg_ids: List[int], + *, + session: Optional[SESSION_TYPE] = None, + commit: bool = True, + ) -> List[int]: + """ + Archive multiple messages from a queue. + + * Messages will be deleted from the queue and moved to the archive table. + * Returns a list of ``msg_id`` that are successfully archived. + + .. code-block:: python + + msg_ids = await pgmq_client.send_batch_async('my_queue', [{'key': 'value'}, {'key': 'value'}]) + assert await pgmq_client.archive_batch_async('my_queue', msg_ids) == msg_ids + assert await pgmq_client.read_async('my_queue') is None + + """ + return await self._execute_async_operation( + PGMQOperation.archive_batch_async, + session, + commit, + queue_name, + msg_ids, + ) + def purge( self, queue_name: str, @@ -834,6 +1475,31 @@ def purge( queue_name, ) + async def purge_async( + self, + queue_name: str, + *, + session: Optional[SESSION_TYPE] = None, + commit: bool = True, + ) -> int: + """ + * Delete all messages from a queue, return the number of messages deleted. + * Archive tables will **not** be affected. + + .. code-block:: python + + msg_ids = await pgmq_client.send_batch_async('my_queue', [{'key': 'value'}, {'key': 'value'}]) + assert await pgmq_client.purge_async('my_queue') == 2 + assert await pgmq_client.read_async('my_queue') is None + + """ + return await self._execute_async_operation( + PGMQOperation.purge_async, + session, + commit, + queue_name, + ) + def metrics( self, queue_name: str, @@ -866,6 +1532,38 @@ def metrics( queue_name, ) + async def metrics_async( + self, + queue_name: str, + *, + session: Optional[SESSION_TYPE] = None, + commit: bool = True, + ) -> Optional[QueueMetrics]: + """ + Get metrics for a queue. + + Returns: + |schema_queue_metrics_class|_ or ``None`` if the queue does not exist. + + Usage: + + .. code-block:: python + + from pgmq_sqlalchemy.schema import QueueMetrics + + metrics:QueueMetrics = await pgmq_client.metrics_async('my_queue') + print(metrics.queue_name) + print(metrics.queue_length) + print(metrics.queue_length) + + """ + return await self._execute_async_operation( + PGMQOperation.metrics_async, + session, + commit, + queue_name, + ) + def metrics_all( self, *, @@ -910,3 +1608,48 @@ def metrics_all( session, commit, ) + + async def metrics_all_async( + self, + *, + session: Optional[SESSION_TYPE] = None, + commit: bool = True, + ) -> Optional[List[QueueMetrics]]: + """ + + .. _read_committed_isolation_level: https://www.postgresql.org/docs/current/transaction-iso.html#XACT-READ-COMMITTED + .. |read_committed_isolation_level| replace:: **READ COMMITTED** + + .. _metrics_all_method: ref:`pgmq_sqlalchemy.PGMQueue.metrics_all` + .. |metrics_all_method| replace:: :py:meth:`~pgmq_sqlalchemy.PGMQueue.metrics_all` + + Get metrics for all queues. + + Returns: + List of |schema_queue_metrics_class|_ or ``None`` if there are no queues. + + Usage: + + .. code-block:: python + + from pgmq_sqlalchemy.schema import QueueMetrics + + metrics:List[QueueMetrics] = await pgmq_client.metrics_all_async() + for m in metrics: + print(m.queue_name) + print(m.queue_length) + print(m.queue_length) + + .. warning:: + | You should use a **distributed lock** to avoid **race conditions** when calling |metrics_all_method|_ in **concurrent** |drop_queue_method|_ **scenarios**. + | + | Since the default PostgreSQL isolation level is |read_committed_isolation_level|_, the queue metrics to be fetched **may not exist** if there are **concurrent** |drop_queue_method|_ **operations**. + | Check the `pgmq.metrics_all `_ function for more details. + + + """ + return await self._execute_async_operation( + PGMQOperation.metrics_all_async, + session, + commit, + ) From 88223de2f368b976025166dc5cf978c6347c56d5 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 5 Jan 2026 11:54:03 +0000 Subject: [PATCH 3/7] Initial plan From 7499b61044296f96e2f273e42ba5e29999f5aaf9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 5 Jan 2026 11:59:41 +0000 Subject: [PATCH 4/7] Add compelete_missing_test_for_operation.py script with CST-based approach Co-authored-by: jason810496 <68415893+jason810496@users.noreply.github.com> --- .../compelete_missing_test_for_operation.py | 112 +++++++ scripts/scripts_utils/operation_test_ast.py | 308 ++++++++++++++++++ 2 files changed, 420 insertions(+) create mode 100644 scripts/compelete_missing_test_for_operation.py create mode 100644 scripts/scripts_utils/operation_test_ast.py diff --git a/scripts/compelete_missing_test_for_operation.py b/scripts/compelete_missing_test_for_operation.py new file mode 100644 index 0000000..e1d56c1 --- /dev/null +++ b/scripts/compelete_missing_test_for_operation.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python +# /// script +# requires-python = ">=3.10,<3.11" +# dependencies = [ +# "rich>=13.6.0", +# "libcst>=1.0.0", +# ] +# /// +""" +Script to check for missing async tests in test_operation.py and generate them. + +For each public sync test (test_*_sync), checks if there's a corresponding +async test with _async suffix. If missing, generates it using CST transformations. +""" + +import libcst as cst +import sys +from pathlib import Path +import contextlib +import shutil +import tempfile + + +from scripts_utils.console import console, user_input +from scripts_utils.formatting import format_file, compare_file +from scripts_utils.operation_test_ast import ( + parse_test_functions_from_module, + get_async_tests_to_add, + fill_missing_tests_to_module, +) + + +def main(): + """Main function.""" + + # Define test file path + PROJECT_ROOT = Path(__file__).parent.parent + TEST_FILE = PROJECT_ROOT / "tests" / "test_operation.py" + TEST_BACKUP_FILE = PROJECT_ROOT / "tests" / "test_operation_backup.py" + + if not TEST_FILE.exists(): + console.print(f"[bold red]ERROR:[/bold red] Test file not found: {TEST_FILE}") + sys.exit(1) + + module_tree = cst.parse_module(TEST_FILE.read_text()) + all_tests, missing_async = parse_test_functions_from_module(module_tree) + + if not missing_async: + console.print( + "[bold green]SUCCESS:[/bold green] All sync tests have corresponding async versions!" + ) + sys.exit(0) + + # Log all the missing async tests + console.print() + console.print( + f"[bold yellow]WARNING:[/bold yellow] Found {len(missing_async)} missing async tests:", + style="bold", + ) + for test_name in sorted(missing_async): + async_name = test_name.replace("_sync", "_async") + console.print(f" [yellow]-[/yellow] {async_name}") + console.print() + + # Create missing async tests + async_tests_to_add = get_async_tests_to_add(all_tests, missing_async) + + # Insert back to module + module_tree = fill_missing_tests_to_module(module_tree, async_tests_to_add) + + # Write back to tmp file for comparison + tmp_file = "" + with tempfile.NamedTemporaryFile(mode="w+t", delete=False, suffix=".py") as f: + f.write(module_tree.code) + f.flush() + tmp_file = f.name + console.log(f"Generated missing async tests at {tmp_file}") + + if tmp_file: + max_formatting = 3 + for _ in range(max_formatting): + if format_file(tmp_file): + break + + # Verify that all async tests are now present + _, missing_async_for_tmp = parse_test_functions_from_module( + cst.parse_module(Path(tmp_file).read_text()) + ) + + if missing_async_for_tmp: + console.log( + f"[error]Still have missing async tests after generation in {tmp_file}: {missing_async_for_tmp}[/]" + ) + else: + console.log("[success]All missing async tests are generated[/]") + + # Compare existing test file and tmp file + with contextlib.suppress(Exception): + compare_file(TEST_FILE, tmp_file) + + # Ask whether to apply the change + if user_input(f"Do you want to apply change to {TEST_FILE}"): + console.log(f"Backup existing {TEST_FILE} at {TEST_BACKUP_FILE}") + shutil.copy(TEST_FILE, TEST_BACKUP_FILE) + shutil.copy(tmp_file, TEST_FILE) + console.log("Added missing async tests successfully") + + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/scripts/scripts_utils/operation_test_ast.py b/scripts/scripts_utils/operation_test_ast.py new file mode 100644 index 0000000..4c6b073 --- /dev/null +++ b/scripts/scripts_utils/operation_test_ast.py @@ -0,0 +1,308 @@ +import libcst as cst +import re +from typing import Dict, Set, List +from scripts_utils.common_ast import MethodInfo + + +class AsyncTestTransformer(cst.CSTTransformer): + """Transform sync test functions to async test functions.""" + + def leave_FunctionDef( + self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef + ) -> cst.FunctionDef: + """Transform function to async test.""" + # Change function name from _sync to _async + new_name = updated_node.name.value.replace("_sync", "_async") + + # Add async keyword + new_node = updated_node.with_changes( + asynchronous=cst.Asynchronous(), name=cst.Name(new_name) + ) + + # Transform docstring if exists + if updated_node.body.body and isinstance( + updated_node.body.body[0], cst.SimpleStatementLine + ): + first_stmt = updated_node.body.body[0] + if first_stmt.body and isinstance(first_stmt.body[0], cst.Expr): + expr = first_stmt.body[0] + if isinstance(expr.value, (cst.SimpleString, cst.ConcatenatedString)): + # Extract docstring value + if isinstance(expr.value, cst.SimpleString): + docstring = expr.value.value + else: + # For concatenated strings, skip transformation + docstring = None + + if docstring: + # Remove quotes to get actual string content + if docstring.startswith('"""') or docstring.startswith("'''"): + quote = docstring[:3] + content = docstring[3:-3] + elif docstring.startswith('"') or docstring.startswith("'"): + quote = docstring[0] + content = docstring[1:-1] + else: + content = docstring + quote = '"""' + + transformed_content = self.transform_docstring(content) + new_docstring = f"{quote}{transformed_content}{quote}" + + # Create new docstring node + new_expr = expr.with_changes( + value=cst.SimpleString(new_docstring) + ) + new_first_stmt = first_stmt.with_changes(body=[new_expr]) + + # Update body with new docstring + new_body = [new_first_stmt] + list(updated_node.body.body[1:]) + new_node = new_node.with_changes( + body=new_node.body.with_changes(body=new_body) + ) + + return new_node + + def leave_Param( + self, original_node: cst.Param, updated_node: cst.Param + ) -> cst.Param: + """Transform function parameters to use async fixtures.""" + param_name = updated_node.name.value + + # Replace get_session_maker with get_async_session_maker + if param_name == "get_session_maker": + return updated_node.with_changes(name=cst.Name("get_async_session_maker")) + + return updated_node + + def leave_With(self, original_node: cst.With, updated_node: cst.With) -> cst.With: + """Transform 'with' statements to 'async with'.""" + # Check if this is a session context manager + for item in updated_node.items: + if isinstance(item.item, cst.Call): + if isinstance(item.item.func, cst.Name): + if "session_maker" in item.item.func.value: + # Transform to async with + return updated_node.with_changes(asynchronous=cst.Asynchronous()) + + return updated_node + + def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call: + """Transform method calls to add _async suffix and await.""" + # Check if this is a PGMQOperation method call + if isinstance(updated_node.func, cst.Attribute): + if isinstance(updated_node.func.value, cst.Name): + if updated_node.func.value.value == "PGMQOperation": + # Add _async suffix to method name + new_func = updated_node.func.with_changes( + attr=cst.Name(f"{updated_node.func.attr.value}_async") + ) + return updated_node.with_changes(func=new_func) + + # Check if this is a get_session_maker() call + if isinstance(updated_node.func, cst.Name): + if updated_node.func.value == "get_session_maker": + # Replace with get_async_session_maker + return updated_node.with_changes( + func=cst.Name("get_async_session_maker") + ) + + return updated_node + + def leave_Assign( + self, original_node: cst.Assign, updated_node: cst.Assign + ) -> cst.Assign: + """Add await to assignments that call async methods.""" + # Check if the value is a PGMQOperation call + if isinstance(updated_node.value, cst.Call): + if isinstance(updated_node.value.func, cst.Attribute): + if isinstance(updated_node.value.func.value, cst.Name): + if updated_node.value.func.value.value == "PGMQOperation": + # Wrap in await + return updated_node.with_changes( + value=cst.Await(expression=updated_node.value) + ) + + return updated_node + + def leave_Expr(self, original_node: cst.Expr, updated_node: cst.Expr) -> cst.Expr: + """Add await to expression statements that call async methods.""" + # Check if this is a PGMQOperation call (not in assignment) + if isinstance(updated_node.value, cst.Call): + if isinstance(updated_node.value.func, cst.Attribute): + if isinstance(updated_node.value.func.value, cst.Name): + if updated_node.value.func.value.value == "PGMQOperation": + # Wrap in await + return updated_node.with_changes( + value=cst.Await(expression=updated_node.value) + ) + + return updated_node + + def transform_docstring(self, docstring: str) -> str: + """Transform docstring for async version.""" + # Replace 'synchronously' with 'asynchronously' + modified = docstring.replace("using PGMQOperation.", "using PGMQOperation asynchronously.") + modified = modified.replace("Test ", "Test ") # Keep Test prefix + + # Add 'asynchronously' before the period if not already present + if "asynchronously" not in modified and not modified.endswith("asynchronously."): + modified = modified.rstrip(".") + if modified and not modified.endswith("asynchronously"): + modified += " asynchronously." + + return modified + + +class TestFunctionVisitor(cst.CSTVisitor): + """Visitor to collect test functions from a module.""" + + def __init__(self): + self.test_functions: List[MethodInfo] = [] + + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + """Visit function definitions and collect test functions.""" + func_name = node.name.value + if func_name.startswith("test_"): + # Determine if it's async or sync + is_async = func_name.endswith("_async") + base_name = func_name[:-6] if is_async else func_name + + method_info = MethodInfo(func_name, node) + method_info.is_target = True + method_info.is_async = is_async + method_info.base_name = base_name + + self.test_functions.append(method_info) + + +class FillMissingTestsTransformer(cst.CSTTransformer): + """Transformer to add missing async tests after their sync counterparts.""" + + def __init__(self, to_add_async_tests: Dict[str, MethodInfo]): + self.to_add_async_tests = to_add_async_tests + self.added_decorators = False + + def leave_Module( + self, original_node: cst.Module, updated_node: cst.Module + ) -> cst.Module: + """Transform the module to add missing async tests.""" + new_body = [] + + for stmt in updated_node.body: + new_body.append(stmt) + + # If this is a sync test function, check if we need to add async version + if isinstance(stmt, cst.FunctionDef): + func_name = stmt.name.value + if func_name in self.to_add_async_tests: + # Add decorator before async test + decorator = cst.Decorator( + decorator=cst.Attribute( + value=cst.Attribute( + value=cst.Name("pytest"), attr=cst.Name("mark") + ), + attr=cst.Name("asyncio"), + ) + ) + + async_test = self.to_add_async_tests[func_name].node + + # Add decorator to async test + if async_test.decorators: + decorated_async = async_test.with_changes( + decorators=[decorator] + list(async_test.decorators) + ) + else: + decorated_async = async_test.with_changes(decorators=[decorator]) + + # Add empty line before async test for readability + new_body.append( + cst.EmptyLine(indent=False, whitespace=cst.SimpleWhitespace("")) + ) + new_body.append( + cst.EmptyLine(indent=False, whitespace=cst.SimpleWhitespace("")) + ) + new_body.append(decorated_async) + + return updated_node.with_changes(body=new_body) + + +def parse_test_functions_from_module( + module_tree: cst.Module, +) -> tuple[List[MethodInfo], Set[str]]: + """ + Parse test functions from module. + + Returns: + Tuple of (all_test_functions, missing_async_test_names) + """ + visitor = TestFunctionVisitor() + module_tree.visit(visitor) + + # Categorize tests + async_tests_set = set() + missing_async_set = set() + + for test_info in visitor.test_functions: + if not test_info.is_target: + continue + + if test_info.is_async: + # Extract base name without _async suffix + base_name = test_info.name.replace("_async", "") + async_tests_set.add(base_name) + + # Find missing async tests + for test_info in visitor.test_functions: + if not test_info.is_target: + continue + + # Check if this is a sync test + if test_info.name.endswith("_sync"): + # Get base name without _sync suffix + base_name_without_sync = test_info.name.replace("_sync", "") + # Check if async version exists + if base_name_without_sync not in async_tests_set: + missing_async_set.add(test_info.name) # Store full sync name + + return visitor.test_functions, missing_async_set + + +def transform_test_to_async(test_info: MethodInfo) -> MethodInfo: + """Transform a sync test function to async.""" + transformer = AsyncTestTransformer() + async_node = test_info.node.visit(transformer) + + new_name = test_info.name.replace("_sync", "_async") + return MethodInfo(new_name, async_node) + + +def get_async_tests_to_add( + all_tests: List[MethodInfo], missing_async: Set[str] +) -> Dict[str, MethodInfo]: + """ + Generate async tests for missing ones. + + Args: + all_tests: All test functions found + missing_async: Set of sync test names that need async versions + + Returns: + Dictionary mapping sync test name to async MethodInfo + """ + async_tests: Dict[str, MethodInfo] = {} + + for test_info in all_tests: + if test_info.name in missing_async: + async_tests[test_info.name] = transform_test_to_async(test_info) + + return async_tests + + +def fill_missing_tests_to_module( + module_tree: cst.Module, to_add_async_tests: Dict[str, MethodInfo] +) -> cst.Module: + """Fill missing async tests into the module.""" + transformer = FillMissingTestsTransformer(to_add_async_tests) + return module_tree.visit(transformer) From 8ed83d68ac757a977b60bc01d9a7cafd2129d7b4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 5 Jan 2026 12:01:49 +0000 Subject: [PATCH 5/7] Fix code review feedback: remove redundant code and use Tuple from typing Co-authored-by: jason810496 <68415893+jason810496@users.noreply.github.com> --- scripts/compelete_missing_test_for_operation.py | 0 scripts/scripts_utils/operation_test_ast.py | 5 ++--- 2 files changed, 2 insertions(+), 3 deletions(-) mode change 100644 => 100755 scripts/compelete_missing_test_for_operation.py diff --git a/scripts/compelete_missing_test_for_operation.py b/scripts/compelete_missing_test_for_operation.py old mode 100644 new mode 100755 diff --git a/scripts/scripts_utils/operation_test_ast.py b/scripts/scripts_utils/operation_test_ast.py index 4c6b073..d293170 100644 --- a/scripts/scripts_utils/operation_test_ast.py +++ b/scripts/scripts_utils/operation_test_ast.py @@ -1,6 +1,6 @@ import libcst as cst import re -from typing import Dict, Set, List +from typing import Dict, Set, List, Tuple from scripts_utils.common_ast import MethodInfo @@ -143,7 +143,6 @@ def transform_docstring(self, docstring: str) -> str: """Transform docstring for async version.""" # Replace 'synchronously' with 'asynchronously' modified = docstring.replace("using PGMQOperation.", "using PGMQOperation asynchronously.") - modified = modified.replace("Test ", "Test ") # Keep Test prefix # Add 'asynchronously' before the period if not already present if "asynchronously" not in modified and not modified.endswith("asynchronously."): @@ -230,7 +229,7 @@ def leave_Module( def parse_test_functions_from_module( module_tree: cst.Module, -) -> tuple[List[MethodInfo], Set[str]]: +) -> Tuple[List[MethodInfo], Set[str]]: """ Parse test functions from module. From bd65c015f69a64f25a45165536b927cb4f245eea Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 6 Jan 2026 17:36:24 +0800 Subject: [PATCH 6/7] Finalize test_operation --- .../compelete_missing_test_for_operation.py | 4 +- scripts/scripts_utils/operation_test_ast.py | 35 +- scripts/scripts_utils/queue_ast.py | 46 ++- tests/test_operation.py | 343 +++++++++++++++++- 4 files changed, 400 insertions(+), 28 deletions(-) diff --git a/scripts/compelete_missing_test_for_operation.py b/scripts/compelete_missing_test_for_operation.py index e1d56c1..04de8ee 100755 --- a/scripts/compelete_missing_test_for_operation.py +++ b/scripts/compelete_missing_test_for_operation.py @@ -32,7 +32,7 @@ def main(): """Main function.""" - + # Define test file path PROJECT_ROOT = Path(__file__).parent.parent TEST_FILE = PROJECT_ROOT / "tests" / "test_operation.py" @@ -64,7 +64,7 @@ def main(): # Create missing async tests async_tests_to_add = get_async_tests_to_add(all_tests, missing_async) - + # Insert back to module module_tree = fill_missing_tests_to_module(module_tree, async_tests_to_add) diff --git a/scripts/scripts_utils/operation_test_ast.py b/scripts/scripts_utils/operation_test_ast.py index d293170..f1b0e78 100644 --- a/scripts/scripts_utils/operation_test_ast.py +++ b/scripts/scripts_utils/operation_test_ast.py @@ -1,5 +1,4 @@ import libcst as cst -import re from typing import Dict, Set, List, Tuple from scripts_utils.common_ast import MethodInfo @@ -83,7 +82,9 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With) -> cst.Wit if isinstance(item.item.func, cst.Name): if "session_maker" in item.item.func.value: # Transform to async with - return updated_node.with_changes(asynchronous=cst.Asynchronous()) + return updated_node.with_changes( + asynchronous=cst.Asynchronous() + ) return updated_node @@ -98,7 +99,7 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal attr=cst.Name(f"{updated_node.func.attr.value}_async") ) return updated_node.with_changes(func=new_func) - + # Check if this is a get_session_maker() call if isinstance(updated_node.func, cst.Name): if updated_node.func.value == "get_session_maker": @@ -122,6 +123,12 @@ def leave_Assign( return updated_node.with_changes( value=cst.Await(expression=updated_node.value) ) + # Check if this is a session method call (session.commit, session.rollback, etc.) + elif updated_node.value.func.value.value == "session": + # Wrap in await + return updated_node.with_changes( + value=cst.Await(expression=updated_node.value) + ) return updated_node @@ -136,20 +143,30 @@ def leave_Expr(self, original_node: cst.Expr, updated_node: cst.Expr) -> cst.Exp return updated_node.with_changes( value=cst.Await(expression=updated_node.value) ) + # Check if this is a session method call (session.commit, session.rollback, etc.) + elif updated_node.value.func.value.value == "session": + # Wrap in await + return updated_node.with_changes( + value=cst.Await(expression=updated_node.value) + ) return updated_node def transform_docstring(self, docstring: str) -> str: """Transform docstring for async version.""" # Replace 'synchronously' with 'asynchronously' - modified = docstring.replace("using PGMQOperation.", "using PGMQOperation asynchronously.") - + modified = docstring.replace( + "using PGMQOperation.", "using PGMQOperation asynchronously." + ) + # Add 'asynchronously' before the period if not already present - if "asynchronously" not in modified and not modified.endswith("asynchronously."): + if "asynchronously" not in modified and not modified.endswith( + "asynchronously." + ): modified = modified.rstrip(".") if modified and not modified.endswith("asynchronously"): modified += " asynchronously." - + return modified @@ -213,7 +230,9 @@ def leave_Module( decorators=[decorator] + list(async_test.decorators) ) else: - decorated_async = async_test.with_changes(decorators=[decorator]) + decorated_async = async_test.with_changes( + decorators=[decorator] + ) # Add empty line before async test for readability new_body.append( diff --git a/scripts/scripts_utils/queue_ast.py b/scripts/scripts_utils/queue_ast.py index 27b16bd..7de5ac0 100644 --- a/scripts/scripts_utils/queue_ast.py +++ b/scripts/scripts_utils/queue_ast.py @@ -3,7 +3,6 @@ import sys from pathlib import Path from typing import List, Set, Dict -import copy sys.path.insert(0, str(Path(__name__).parent.parent.joinpath("scripts").resolve())) @@ -21,7 +20,9 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal # Check if any argument is PGMQOperation.method new_args = [] for arg in updated_node.args: - if isinstance(arg.value, cst.Attribute) and isinstance(arg.value.value, cst.Name): + if isinstance(arg.value, cst.Attribute) and isinstance( + arg.value.value, cst.Name + ): if arg.value.value.value == "PGMQOperation": # Add _async suffix to method name new_attr = arg.value.with_changes( @@ -30,31 +31,38 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal new_args.append(arg.with_changes(value=new_attr)) continue new_args.append(arg) - + # Replace `self._execute_operation` to `self._execute_async_operation` if isinstance(updated_node.func.value, cst.Name): - if (updated_node.func.value.value == "self" and - updated_node.func.attr.value == self.to_replace_execute_func_attr): + if ( + updated_node.func.value.value == "self" + and updated_node.func.attr.value + == self.to_replace_execute_func_attr + ): updated_node = updated_node.with_changes( func=updated_node.func.with_changes( attr=cst.Name(self.target_execute_func_attr) ) ) - + if new_args: updated_node = updated_node.with_changes(args=new_args) - + return updated_node - def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: + def leave_FunctionDef( + self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef + ) -> cst.FunctionDef: # Transform function to async new_node = updated_node.with_changes( asynchronous=cst.Asynchronous(), - name=cst.Name(f"{updated_node.name.value}_async") + name=cst.Name(f"{updated_node.name.value}_async"), ) # Transform docstring if exists - if updated_node.body.body and isinstance(updated_node.body.body[0], cst.SimpleStatementLine): + if updated_node.body.body and isinstance( + updated_node.body.body[0], cst.SimpleStatementLine + ): first_stmt = updated_node.body.body[0] if first_stmt.body and isinstance(first_stmt.body[0], cst.Expr): expr = first_stmt.body[0] @@ -65,7 +73,7 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu else: # For concatenated strings, we'll skip transformation for now docstring = None - + if docstring: # Remove quotes to get actual string content if docstring.startswith('"""') or docstring.startswith("'''"): @@ -77,14 +85,16 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu else: content = docstring quote = '"""' - + transformed_content = self.transform_docstring(content) - new_docstring = f'{quote}{transformed_content}{quote}' - + new_docstring = f"{quote}{transformed_content}{quote}" + # Create new docstring node - new_expr = expr.with_changes(value=cst.SimpleString(new_docstring)) + new_expr = expr.with_changes( + value=cst.SimpleString(new_docstring) + ) new_first_stmt = first_stmt.with_changes(body=[new_expr]) - + # Update body with new docstring new_body = [new_first_stmt] + list(updated_node.body.body[1:]) new_node = new_node.with_changes( @@ -93,7 +103,9 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu return new_node - def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.Return: + def leave_Return( + self, original_node: cst.Return, updated_node: cst.Return + ) -> cst.Return: # Only wrap return value in await if it's a call expression # (which is likely to be an operation that needs awaiting) if updated_node.value and isinstance(updated_node.value, cst.Call): diff --git a/tests/test_operation.py b/tests/test_operation.py index ce5bd62..b57e339 100644 --- a/tests/test_operation.py +++ b/tests/test_operation.py @@ -3,6 +3,7 @@ This test suite tests the PGMQOperation class methods directly, which are transaction-friendly static methods that accept sessions. """ + import time import uuid @@ -61,6 +62,25 @@ def test_create_unlogged_queue_sync(get_session_maker, db_session): ) +@pytest.mark.asyncio +async def test_create_unlogged_queue_async(get_async_session_maker, db_session): + """Test creating an unlogged queue using PGMQOperation asynchronously.""" + queue_name = f"test_queue_{uuid.uuid4().hex}" + + async with get_async_session_maker() as session: + await PGMQOperation.create_queue_async( + queue_name, unlogged=True, session=session, commit=True + ) + + assert check_queue_exists(db_session, queue_name) is True + + # Clean up + async with get_async_session_maker() as session: + await PGMQOperation.drop_queue_async( + queue_name, partitioned=False, session=session, commit=True + ) + + def test_validate_queue_name_sync(get_session_maker): """Test queue name validation.""" queue_name = f"test_queue_{uuid.uuid4().hex}" @@ -70,12 +90,32 @@ def test_validate_queue_name_sync(get_session_maker): PGMQOperation.validate_queue_name(queue_name, session=session, commit=True) # Should raise for name that's too long (either ProgrammingError or InternalError depending on driver) - with pytest.raises((ProgrammingError, InternalError)) as e: + with pytest.raises((ProgrammingError, InternalError, Exception)) as e: PGMQOperation.validate_queue_name("a" * 49, session=session, commit=True) error_msg = str(e.value.orig) if hasattr(e.value, "orig") else str(e.value) assert "queue name is too long" in error_msg +@pytest.mark.asyncio +async def test_validate_queue_name_async(get_async_session_maker): + """Test queue name validation asynchronously.""" + queue_name = f"test_queue_{uuid.uuid4().hex}" + + async with get_async_session_maker() as session: + # Should not raise for valid name + await PGMQOperation.validate_queue_name_async( + queue_name, session=session, commit=True + ) + + # Should raise for name that's too long (either ProgrammingError or InternalError depending on driver) + with pytest.raises((ProgrammingError, InternalError, Exception)) as e: + await PGMQOperation.validate_queue_name_async( + "a" * 49, session=session, commit=True + ) + error_msg = str(e.value.orig) if hasattr(e.value, "orig") else str(e.value) + assert "queue name is too long" in error_msg + + def test_list_queues_sync(get_session_maker, db_session): """Test listing queues.""" queue_name = f"test_queue_{uuid.uuid4().hex}" @@ -99,6 +139,30 @@ def test_list_queues_sync(get_session_maker, db_session): ) +@pytest.mark.asyncio +async def test_list_queues_async(get_async_session_maker, db_session): + """Test listing queues asynchronously.""" + queue_name = f"test_queue_{uuid.uuid4().hex}" + + # Create a queue + async with get_async_session_maker() as session: + await PGMQOperation.create_queue_async( + queue_name, unlogged=False, session=session, commit=True + ) + + # List queues + async with get_async_session_maker() as session: + queues = await PGMQOperation.list_queues_async(session=session, commit=True) + + assert queue_name in queues + + # Clean up + async with get_async_session_maker() as session: + await PGMQOperation.drop_queue_async( + queue_name, partitioned=False, session=session, commit=True + ) + + def test_send_and_read_sync(get_session_maker, db_session): """Test sending and reading messages.""" queue_name = f"test_queue_{uuid.uuid4().hex}" @@ -166,6 +230,41 @@ def test_send_batch_sync(get_session_maker, db_session): ) +@pytest.mark.asyncio +async def test_send_batch_async(get_async_session_maker, db_session): + """Test sending a batch of messages asynchronously.""" + queue_name = f"test_queue_{uuid.uuid4().hex}" + messages = [{"key": f"value{i}"} for i in range(5)] + + # Create queue + async with get_async_session_maker() as session: + await PGMQOperation.create_queue_async( + queue_name, unlogged=False, session=session, commit=True + ) + + # Send batch + async with get_async_session_maker() as session: + msg_ids = await PGMQOperation.send_batch_async( + queue_name, messages, delay=0, session=session, commit=True + ) + + assert len(msg_ids) == 5 + + # Read batch + async with get_async_session_maker() as session: + msgs = await PGMQOperation.read_batch_async( + queue_name, vt=30, batch_size=5, session=session, commit=True + ) + + assert len(msgs) == 5 + + # Clean up + async with get_async_session_maker() as session: + await PGMQOperation.drop_queue_async( + queue_name, partitioned=False, session=session, commit=True + ) + + def test_pop_sync(get_session_maker, db_session): """Test popping a message from the queue.""" queue_name = f"test_queue_{uuid.uuid4().hex}" @@ -199,6 +298,40 @@ def test_pop_sync(get_session_maker, db_session): ) +@pytest.mark.asyncio +async def test_pop_async(get_async_session_maker, db_session): + """Test popping a message from the queue asynchronously.""" + queue_name = f"test_queue_{uuid.uuid4().hex}" + + # Create queue and send message + async with get_async_session_maker() as session: + await PGMQOperation.create_queue_async( + queue_name, unlogged=False, session=session, commit=True + ) + msg_id = await PGMQOperation.send_async( + queue_name, MSG, delay=0, session=session, commit=True + ) + + # Pop message + async with get_async_session_maker() as session: + msg = await PGMQOperation.pop_async(queue_name, session=session, commit=True) + + assert msg is not None + assert msg.msg_id == msg_id + + # Verify queue is empty + async with get_async_session_maker() as session: + msg2 = await PGMQOperation.pop_async(queue_name, session=session, commit=True) + + assert msg2 is None + + # Clean up + async with get_async_session_maker() as session: + await PGMQOperation.drop_queue_async( + queue_name, partitioned=False, session=session, commit=True + ) + + def test_delete_sync(get_session_maker, db_session): """Test deleting a message.""" queue_name = f"test_queue_{uuid.uuid4().hex}" @@ -225,6 +358,35 @@ def test_delete_sync(get_session_maker, db_session): ) +@pytest.mark.asyncio +async def test_delete_async(get_async_session_maker, db_session): + """Test deleting a message asynchronously.""" + queue_name = f"test_queue_{uuid.uuid4().hex}" + + # Create queue and send message + async with get_async_session_maker() as session: + await PGMQOperation.create_queue_async( + queue_name, unlogged=False, session=session, commit=True + ) + msg_id = await PGMQOperation.send_async( + queue_name, MSG, delay=0, session=session, commit=True + ) + + # Delete message + async with get_async_session_maker() as session: + deleted = await PGMQOperation.delete_async( + queue_name, msg_id, session=session, commit=True + ) + + assert deleted is True + + # Clean up + async with get_async_session_maker() as session: + await PGMQOperation.drop_queue_async( + queue_name, partitioned=False, session=session, commit=True + ) + + def test_set_vt_sync(get_session_maker, db_session): """Test setting visibility timeout.""" queue_name = f"test_queue_{uuid.uuid4().hex}" @@ -255,6 +417,37 @@ def test_set_vt_sync(get_session_maker, db_session): ) +@pytest.mark.asyncio +async def test_set_vt_async(get_async_session_maker, db_session): + """Test setting visibility timeout asynchronously.""" + queue_name = f"test_queue_{uuid.uuid4().hex}" + + # Create queue and send message + async with get_async_session_maker() as session: + await PGMQOperation.create_queue_async( + queue_name, unlogged=False, session=session, commit=True + ) + msg_id = await PGMQOperation.send_async( + queue_name, MSG, delay=0, session=session, commit=True + ) + # Read message to set initial vt + await PGMQOperation.read_async(queue_name, vt=5, session=session, commit=True) + + # Set new vt + async with get_async_session_maker() as session: + msg = await PGMQOperation.set_vt_async( + queue_name, msg_id, vt=60, session=session, commit=True + ) + + assert msg is not None + + # Clean up + async with get_async_session_maker() as session: + await PGMQOperation.drop_queue_async( + queue_name, partitioned=False, session=session, commit=True + ) + + def test_archive_sync(get_session_maker, db_session): """Test archiving a message.""" queue_name = f"test_queue_{uuid.uuid4().hex}" @@ -283,6 +476,35 @@ def test_archive_sync(get_session_maker, db_session): ) +@pytest.mark.asyncio +async def test_archive_async(get_async_session_maker, db_session): + """Test archiving a message asynchronously.""" + queue_name = f"test_queue_{uuid.uuid4().hex}" + + # Create queue and send message + async with get_async_session_maker() as session: + await PGMQOperation.create_queue_async( + queue_name, unlogged=False, session=session, commit=True + ) + msg_id = await PGMQOperation.send_async( + queue_name, MSG, delay=0, session=session, commit=True + ) + + # Archive message + async with get_async_session_maker() as session: + archived = await PGMQOperation.archive_async( + queue_name, msg_id, session=session, commit=True + ) + + assert archived is True + + # Clean up + async with get_async_session_maker() as session: + await PGMQOperation.drop_queue_async( + queue_name, partitioned=False, session=session, commit=True + ) + + def test_metrics_sync(get_session_maker, db_session): """Test getting queue metrics.""" queue_name = f"test_queue_{uuid.uuid4().hex}" @@ -358,6 +580,43 @@ def test_metrics_all_sync(get_session_maker, db_session): ) +@pytest.mark.asyncio +async def test_metrics_all_async(get_async_session_maker, db_session): + """Test getting metrics for all queues asynchronously.""" + queue_name1 = f"test_queue_{uuid.uuid4().hex}" + queue_name2 = f"test_queue_{uuid.uuid4().hex}" + + # Create two queues + async with get_async_session_maker() as session: + await PGMQOperation.create_queue_async( + queue_name1, unlogged=False, session=session, commit=True + ) + await PGMQOperation.create_queue_async( + queue_name2, unlogged=False, session=session, commit=True + ) + + # Get metrics for all queues + async with get_async_session_maker() as session: + all_metrics = await PGMQOperation.metrics_all_async( + session=session, commit=True + ) + + assert all_metrics is not None + assert len(all_metrics) >= 2 + queue_names = [m.queue_name for m in all_metrics] + assert queue_name1 in queue_names + assert queue_name2 in queue_names + + # Clean up + async with get_async_session_maker() as session: + await PGMQOperation.drop_queue_async( + queue_name1, partitioned=False, session=session, commit=True + ) + await PGMQOperation.drop_queue_async( + queue_name2, partitioned=False, session=session, commit=True + ) + + def test_transaction_rollback_sync(get_session_maker, db_session): """Test that operations can be rolled back when commit=False.""" queue_name = f"test_queue_{uuid.uuid4().hex}" @@ -373,6 +632,22 @@ def test_transaction_rollback_sync(get_session_maker, db_session): assert check_queue_exists(db_session, queue_name) is False +@pytest.mark.asyncio +async def test_transaction_rollback_async(get_async_session_maker, db_session): + """Test that operations can be rolled back when commit=False asynchronously.""" + queue_name = f"test_queue_{uuid.uuid4().hex}" + + # Create queue with commit=False, then rollback + async with get_async_session_maker() as session: + await PGMQOperation.create_queue_async( + queue_name, unlogged=False, session=session, commit=False + ) + await session.rollback() + + # Queue should not exist + assert check_queue_exists(db_session, queue_name) is False + + def test_transaction_commit_sync(get_session_maker, db_session): """Test that operations are committed when commit=True.""" queue_name = f"test_queue_{uuid.uuid4().hex}" @@ -393,6 +668,27 @@ def test_transaction_commit_sync(get_session_maker, db_session): ) +@pytest.mark.asyncio +async def test_transaction_commit_async(get_async_session_maker, db_session): + """Test that operations are committed when commit=True asynchronously.""" + queue_name = f"test_queue_{uuid.uuid4().hex}" + + # Create queue with commit=True + async with get_async_session_maker() as session: + await PGMQOperation.create_queue_async( + queue_name, unlogged=False, session=session, commit=True + ) + + # Queue should exist + assert check_queue_exists(db_session, queue_name) is True + + # Clean up + async with get_async_session_maker() as session: + await PGMQOperation.drop_queue_async( + queue_name, partitioned=False, session=session, commit=True + ) + + # Async tests @@ -779,6 +1075,51 @@ def test_create_time_based_partitioned_queue_sync(get_session_maker, db_session) ) +@pytest.mark.asyncio +async def test_create_time_based_partitioned_queue_async( + get_async_session_maker, db_session +): + """Test creating a time-based partitioned queue asynchronously.""" + queue_name = f"time_{uuid.uuid4().hex[:20]}" + + # First ensure pg_partman extension is available + try: + async with get_async_session_maker() as session: + await PGMQOperation.check_pg_partman_ext_async(session=session, commit=True) + except Exception as e: + pytest.skip(f"pg_partman extension not available: {e}") + + # Create partitioned queue with time-based partitioning + async with get_async_session_maker() as session: + await PGMQOperation.create_partitioned_queue_async( + queue_name, + partition_interval="1 day", + retention_interval="7 days", + session=session, + commit=True, + ) + + assert check_queue_exists(db_session, queue_name) is True + + # Test sending and reading from time-based partitioned queue + async with get_async_session_maker() as session: + msg_id = await PGMQOperation.send_async( + queue_name, MSG, delay=0, session=session, commit=True + ) + msg = await PGMQOperation.read_async( + queue_name, vt=30, session=session, commit=True + ) + + assert msg is not None + assert msg.msg_id == msg_id + + # Clean up + async with get_async_session_maker() as session: + await PGMQOperation.drop_queue_async( + queue_name, partitioned=True, session=session, commit=True + ) + + # Async tests for newly added coverage From 86cd6a3a9c5c3706ad7de09bc4303690fb468785 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 6 Jan 2026 18:09:39 +0800 Subject: [PATCH 7/7] Fix _check_pg_partman_ext naming --- pgmq_sqlalchemy/queue.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pgmq_sqlalchemy/queue.py b/pgmq_sqlalchemy/queue.py index 74de108..31fcaaa 100644 --- a/pgmq_sqlalchemy/queue.py +++ b/pgmq_sqlalchemy/queue.py @@ -111,12 +111,11 @@ async def _check_pg_partman_ext_async(self) -> None: async with self.session_maker() as session: await PGMQOperation.check_pg_partman_ext_async(session=session, commit=True) - def _check_pg_partman_ext_sync(self) -> None: + def _check_pg_partman_ext(self) -> None: """Check if the pg_partman extension exists.""" with self.session_maker() as session: PGMQOperation.check_pg_partman_ext(session=session, commit=True) - def _execute_operation( self, op_sync, @@ -1059,7 +1058,6 @@ def consumer_with_backoff_retry(pgmq_client: PGMQueue, queue_name: str): return self._execute_operation( PGMQOperation.set_vt, - session, commit, queue_name,