diff --git a/pgmq_sqlalchemy/queue.py b/pgmq_sqlalchemy/queue.py index 8d418e9..31fcaaa 100644 --- a/pgmq_sqlalchemy/queue.py +++ b/pgmq_sqlalchemy/queue.py @@ -4,7 +4,6 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy.ext.asyncio import create_async_engine - from .schema import Message, QueueMetrics from ._types import ENGINE_TYPE, SESSION_TYPE from ._utils import ( @@ -107,27 +106,15 @@ def __init__( bind=self.engine, class_=get_session_type(self.engine) ) - def _check_pgmq_ext(self) -> None: - """Check if the pgmq extension exists.""" - self._execute_operation(PGMQOperation.check_pgmq_ext, session=None, commit=True) - - async def _check_pgmq_ext_async(self) -> None: - """Check if the pgmq extension exists (async version).""" - await self._execute_async_operation( - PGMQOperation.check_pgmq_ext_async, session=None, commit=True - ) + async def _check_pg_partman_ext_async(self) -> None: + """Check if the pg_partman extension exists.""" + async with self.session_maker() as session: + await PGMQOperation.check_pg_partman_ext_async(session=session, commit=True) def _check_pg_partman_ext(self) -> None: """Check if the pg_partman extension exists.""" - self._execute_operation( - PGMQOperation.check_pg_partman_ext, session=None, commit=True - ) - - async def _check_pg_partman_ext_async(self) -> None: - """Check if the pg_partman extension exists (async version).""" - await self._execute_async_operation( - PGMQOperation.check_pg_partman_ext_async, session=None, commit=True - ) + with self.session_maker() as session: + PGMQOperation.check_pg_partman_ext(session=session, commit=True) def _execute_operation( self, @@ -338,7 +325,7 @@ async def create_partitioned_queue_async( """ # check if the pg_partman extension exists before creating a partitioned queue at runtime - await self._check_pg_partman_ext_async() + self._check_pg_partman_ext() return await self._execute_async_operation( PGMQOperation.create_partitioned_queue_async, @@ -447,7 +434,7 @@ async def drop_queue_async( """ # check if the pg_partman extension exists before dropping a partitioned queue at runtime if partitioned: - await self._check_pg_partman_ext_async() + self._check_pg_partman_ext() return await self._execute_async_operation( PGMQOperation.drop_queue_async, diff --git a/scripts/compelete_missing_test_for_operation.py b/scripts/compelete_missing_test_for_operation.py new file mode 100755 index 0000000..04de8ee --- /dev/null +++ b/scripts/compelete_missing_test_for_operation.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python +# /// script +# requires-python = ">=3.10,<3.11" +# dependencies = [ +# "rich>=13.6.0", +# "libcst>=1.0.0", +# ] +# /// +""" +Script to check for missing async tests in test_operation.py and generate them. + +For each public sync test (test_*_sync), checks if there's a corresponding +async test with _async suffix. If missing, generates it using CST transformations. +""" + +import libcst as cst +import sys +from pathlib import Path +import contextlib +import shutil +import tempfile + + +from scripts_utils.console import console, user_input +from scripts_utils.formatting import format_file, compare_file +from scripts_utils.operation_test_ast import ( + parse_test_functions_from_module, + get_async_tests_to_add, + fill_missing_tests_to_module, +) + + +def main(): + """Main function.""" + + # Define test file path + PROJECT_ROOT = Path(__file__).parent.parent + TEST_FILE = PROJECT_ROOT / "tests" / "test_operation.py" + TEST_BACKUP_FILE = PROJECT_ROOT / "tests" / "test_operation_backup.py" + + if not TEST_FILE.exists(): + console.print(f"[bold red]ERROR:[/bold red] Test file not found: {TEST_FILE}") + sys.exit(1) + + module_tree = cst.parse_module(TEST_FILE.read_text()) + all_tests, missing_async = parse_test_functions_from_module(module_tree) + + if not missing_async: + console.print( + "[bold green]SUCCESS:[/bold green] All sync tests have corresponding async versions!" + ) + sys.exit(0) + + # Log all the missing async tests + console.print() + console.print( + f"[bold yellow]WARNING:[/bold yellow] Found {len(missing_async)} missing async tests:", + style="bold", + ) + for test_name in sorted(missing_async): + async_name = test_name.replace("_sync", "_async") + console.print(f" [yellow]-[/yellow] {async_name}") + console.print() + + # Create missing async tests + async_tests_to_add = get_async_tests_to_add(all_tests, missing_async) + + # Insert back to module + module_tree = fill_missing_tests_to_module(module_tree, async_tests_to_add) + + # Write back to tmp file for comparison + tmp_file = "" + with tempfile.NamedTemporaryFile(mode="w+t", delete=False, suffix=".py") as f: + f.write(module_tree.code) + f.flush() + tmp_file = f.name + console.log(f"Generated missing async tests at {tmp_file}") + + if tmp_file: + max_formatting = 3 + for _ in range(max_formatting): + if format_file(tmp_file): + break + + # Verify that all async tests are now present + _, missing_async_for_tmp = parse_test_functions_from_module( + cst.parse_module(Path(tmp_file).read_text()) + ) + + if missing_async_for_tmp: + console.log( + f"[error]Still have missing async tests after generation in {tmp_file}: {missing_async_for_tmp}[/]" + ) + else: + console.log("[success]All missing async tests are generated[/]") + + # Compare existing test file and tmp file + with contextlib.suppress(Exception): + compare_file(TEST_FILE, tmp_file) + + # Ask whether to apply the change + if user_input(f"Do you want to apply change to {TEST_FILE}"): + console.log(f"Backup existing {TEST_FILE} at {TEST_BACKUP_FILE}") + shutil.copy(TEST_FILE, TEST_BACKUP_FILE) + shutil.copy(tmp_file, TEST_FILE) + console.log("Added missing async tests successfully") + + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/scripts/scripts_utils/operation_test_ast.py b/scripts/scripts_utils/operation_test_ast.py new file mode 100644 index 0000000..f1b0e78 --- /dev/null +++ b/scripts/scripts_utils/operation_test_ast.py @@ -0,0 +1,326 @@ +import libcst as cst +from typing import Dict, Set, List, Tuple +from scripts_utils.common_ast import MethodInfo + + +class AsyncTestTransformer(cst.CSTTransformer): + """Transform sync test functions to async test functions.""" + + def leave_FunctionDef( + self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef + ) -> cst.FunctionDef: + """Transform function to async test.""" + # Change function name from _sync to _async + new_name = updated_node.name.value.replace("_sync", "_async") + + # Add async keyword + new_node = updated_node.with_changes( + asynchronous=cst.Asynchronous(), name=cst.Name(new_name) + ) + + # Transform docstring if exists + if updated_node.body.body and isinstance( + updated_node.body.body[0], cst.SimpleStatementLine + ): + first_stmt = updated_node.body.body[0] + if first_stmt.body and isinstance(first_stmt.body[0], cst.Expr): + expr = first_stmt.body[0] + if isinstance(expr.value, (cst.SimpleString, cst.ConcatenatedString)): + # Extract docstring value + if isinstance(expr.value, cst.SimpleString): + docstring = expr.value.value + else: + # For concatenated strings, skip transformation + docstring = None + + if docstring: + # Remove quotes to get actual string content + if docstring.startswith('"""') or docstring.startswith("'''"): + quote = docstring[:3] + content = docstring[3:-3] + elif docstring.startswith('"') or docstring.startswith("'"): + quote = docstring[0] + content = docstring[1:-1] + else: + content = docstring + quote = '"""' + + transformed_content = self.transform_docstring(content) + new_docstring = f"{quote}{transformed_content}{quote}" + + # Create new docstring node + new_expr = expr.with_changes( + value=cst.SimpleString(new_docstring) + ) + new_first_stmt = first_stmt.with_changes(body=[new_expr]) + + # Update body with new docstring + new_body = [new_first_stmt] + list(updated_node.body.body[1:]) + new_node = new_node.with_changes( + body=new_node.body.with_changes(body=new_body) + ) + + return new_node + + def leave_Param( + self, original_node: cst.Param, updated_node: cst.Param + ) -> cst.Param: + """Transform function parameters to use async fixtures.""" + param_name = updated_node.name.value + + # Replace get_session_maker with get_async_session_maker + if param_name == "get_session_maker": + return updated_node.with_changes(name=cst.Name("get_async_session_maker")) + + return updated_node + + def leave_With(self, original_node: cst.With, updated_node: cst.With) -> cst.With: + """Transform 'with' statements to 'async with'.""" + # Check if this is a session context manager + for item in updated_node.items: + if isinstance(item.item, cst.Call): + if isinstance(item.item.func, cst.Name): + if "session_maker" in item.item.func.value: + # Transform to async with + return updated_node.with_changes( + asynchronous=cst.Asynchronous() + ) + + return updated_node + + def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call: + """Transform method calls to add _async suffix and await.""" + # Check if this is a PGMQOperation method call + if isinstance(updated_node.func, cst.Attribute): + if isinstance(updated_node.func.value, cst.Name): + if updated_node.func.value.value == "PGMQOperation": + # Add _async suffix to method name + new_func = updated_node.func.with_changes( + attr=cst.Name(f"{updated_node.func.attr.value}_async") + ) + return updated_node.with_changes(func=new_func) + + # Check if this is a get_session_maker() call + if isinstance(updated_node.func, cst.Name): + if updated_node.func.value == "get_session_maker": + # Replace with get_async_session_maker + return updated_node.with_changes( + func=cst.Name("get_async_session_maker") + ) + + return updated_node + + def leave_Assign( + self, original_node: cst.Assign, updated_node: cst.Assign + ) -> cst.Assign: + """Add await to assignments that call async methods.""" + # Check if the value is a PGMQOperation call + if isinstance(updated_node.value, cst.Call): + if isinstance(updated_node.value.func, cst.Attribute): + if isinstance(updated_node.value.func.value, cst.Name): + if updated_node.value.func.value.value == "PGMQOperation": + # Wrap in await + return updated_node.with_changes( + value=cst.Await(expression=updated_node.value) + ) + # Check if this is a session method call (session.commit, session.rollback, etc.) + elif updated_node.value.func.value.value == "session": + # Wrap in await + return updated_node.with_changes( + value=cst.Await(expression=updated_node.value) + ) + + return updated_node + + def leave_Expr(self, original_node: cst.Expr, updated_node: cst.Expr) -> cst.Expr: + """Add await to expression statements that call async methods.""" + # Check if this is a PGMQOperation call (not in assignment) + if isinstance(updated_node.value, cst.Call): + if isinstance(updated_node.value.func, cst.Attribute): + if isinstance(updated_node.value.func.value, cst.Name): + if updated_node.value.func.value.value == "PGMQOperation": + # Wrap in await + return updated_node.with_changes( + value=cst.Await(expression=updated_node.value) + ) + # Check if this is a session method call (session.commit, session.rollback, etc.) + elif updated_node.value.func.value.value == "session": + # Wrap in await + return updated_node.with_changes( + value=cst.Await(expression=updated_node.value) + ) + + return updated_node + + def transform_docstring(self, docstring: str) -> str: + """Transform docstring for async version.""" + # Replace 'synchronously' with 'asynchronously' + modified = docstring.replace( + "using PGMQOperation.", "using PGMQOperation asynchronously." + ) + + # Add 'asynchronously' before the period if not already present + if "asynchronously" not in modified and not modified.endswith( + "asynchronously." + ): + modified = modified.rstrip(".") + if modified and not modified.endswith("asynchronously"): + modified += " asynchronously." + + return modified + + +class TestFunctionVisitor(cst.CSTVisitor): + """Visitor to collect test functions from a module.""" + + def __init__(self): + self.test_functions: List[MethodInfo] = [] + + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + """Visit function definitions and collect test functions.""" + func_name = node.name.value + if func_name.startswith("test_"): + # Determine if it's async or sync + is_async = func_name.endswith("_async") + base_name = func_name[:-6] if is_async else func_name + + method_info = MethodInfo(func_name, node) + method_info.is_target = True + method_info.is_async = is_async + method_info.base_name = base_name + + self.test_functions.append(method_info) + + +class FillMissingTestsTransformer(cst.CSTTransformer): + """Transformer to add missing async tests after their sync counterparts.""" + + def __init__(self, to_add_async_tests: Dict[str, MethodInfo]): + self.to_add_async_tests = to_add_async_tests + self.added_decorators = False + + def leave_Module( + self, original_node: cst.Module, updated_node: cst.Module + ) -> cst.Module: + """Transform the module to add missing async tests.""" + new_body = [] + + for stmt in updated_node.body: + new_body.append(stmt) + + # If this is a sync test function, check if we need to add async version + if isinstance(stmt, cst.FunctionDef): + func_name = stmt.name.value + if func_name in self.to_add_async_tests: + # Add decorator before async test + decorator = cst.Decorator( + decorator=cst.Attribute( + value=cst.Attribute( + value=cst.Name("pytest"), attr=cst.Name("mark") + ), + attr=cst.Name("asyncio"), + ) + ) + + async_test = self.to_add_async_tests[func_name].node + + # Add decorator to async test + if async_test.decorators: + decorated_async = async_test.with_changes( + decorators=[decorator] + list(async_test.decorators) + ) + else: + decorated_async = async_test.with_changes( + decorators=[decorator] + ) + + # Add empty line before async test for readability + new_body.append( + cst.EmptyLine(indent=False, whitespace=cst.SimpleWhitespace("")) + ) + new_body.append( + cst.EmptyLine(indent=False, whitespace=cst.SimpleWhitespace("")) + ) + new_body.append(decorated_async) + + return updated_node.with_changes(body=new_body) + + +def parse_test_functions_from_module( + module_tree: cst.Module, +) -> Tuple[List[MethodInfo], Set[str]]: + """ + Parse test functions from module. + + Returns: + Tuple of (all_test_functions, missing_async_test_names) + """ + visitor = TestFunctionVisitor() + module_tree.visit(visitor) + + # Categorize tests + async_tests_set = set() + missing_async_set = set() + + for test_info in visitor.test_functions: + if not test_info.is_target: + continue + + if test_info.is_async: + # Extract base name without _async suffix + base_name = test_info.name.replace("_async", "") + async_tests_set.add(base_name) + + # Find missing async tests + for test_info in visitor.test_functions: + if not test_info.is_target: + continue + + # Check if this is a sync test + if test_info.name.endswith("_sync"): + # Get base name without _sync suffix + base_name_without_sync = test_info.name.replace("_sync", "") + # Check if async version exists + if base_name_without_sync not in async_tests_set: + missing_async_set.add(test_info.name) # Store full sync name + + return visitor.test_functions, missing_async_set + + +def transform_test_to_async(test_info: MethodInfo) -> MethodInfo: + """Transform a sync test function to async.""" + transformer = AsyncTestTransformer() + async_node = test_info.node.visit(transformer) + + new_name = test_info.name.replace("_sync", "_async") + return MethodInfo(new_name, async_node) + + +def get_async_tests_to_add( + all_tests: List[MethodInfo], missing_async: Set[str] +) -> Dict[str, MethodInfo]: + """ + Generate async tests for missing ones. + + Args: + all_tests: All test functions found + missing_async: Set of sync test names that need async versions + + Returns: + Dictionary mapping sync test name to async MethodInfo + """ + async_tests: Dict[str, MethodInfo] = {} + + for test_info in all_tests: + if test_info.name in missing_async: + async_tests[test_info.name] = transform_test_to_async(test_info) + + return async_tests + + +def fill_missing_tests_to_module( + module_tree: cst.Module, to_add_async_tests: Dict[str, MethodInfo] +) -> cst.Module: + """Fill missing async tests into the module.""" + transformer = FillMissingTestsTransformer(to_add_async_tests) + return module_tree.visit(transformer) diff --git a/scripts/scripts_utils/queue_ast.py b/scripts/scripts_utils/queue_ast.py index 27b16bd..7de5ac0 100644 --- a/scripts/scripts_utils/queue_ast.py +++ b/scripts/scripts_utils/queue_ast.py @@ -3,7 +3,6 @@ import sys from pathlib import Path from typing import List, Set, Dict -import copy sys.path.insert(0, str(Path(__name__).parent.parent.joinpath("scripts").resolve())) @@ -21,7 +20,9 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal # Check if any argument is PGMQOperation.method new_args = [] for arg in updated_node.args: - if isinstance(arg.value, cst.Attribute) and isinstance(arg.value.value, cst.Name): + if isinstance(arg.value, cst.Attribute) and isinstance( + arg.value.value, cst.Name + ): if arg.value.value.value == "PGMQOperation": # Add _async suffix to method name new_attr = arg.value.with_changes( @@ -30,31 +31,38 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal new_args.append(arg.with_changes(value=new_attr)) continue new_args.append(arg) - + # Replace `self._execute_operation` to `self._execute_async_operation` if isinstance(updated_node.func.value, cst.Name): - if (updated_node.func.value.value == "self" and - updated_node.func.attr.value == self.to_replace_execute_func_attr): + if ( + updated_node.func.value.value == "self" + and updated_node.func.attr.value + == self.to_replace_execute_func_attr + ): updated_node = updated_node.with_changes( func=updated_node.func.with_changes( attr=cst.Name(self.target_execute_func_attr) ) ) - + if new_args: updated_node = updated_node.with_changes(args=new_args) - + return updated_node - def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: + def leave_FunctionDef( + self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef + ) -> cst.FunctionDef: # Transform function to async new_node = updated_node.with_changes( asynchronous=cst.Asynchronous(), - name=cst.Name(f"{updated_node.name.value}_async") + name=cst.Name(f"{updated_node.name.value}_async"), ) # Transform docstring if exists - if updated_node.body.body and isinstance(updated_node.body.body[0], cst.SimpleStatementLine): + if updated_node.body.body and isinstance( + updated_node.body.body[0], cst.SimpleStatementLine + ): first_stmt = updated_node.body.body[0] if first_stmt.body and isinstance(first_stmt.body[0], cst.Expr): expr = first_stmt.body[0] @@ -65,7 +73,7 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu else: # For concatenated strings, we'll skip transformation for now docstring = None - + if docstring: # Remove quotes to get actual string content if docstring.startswith('"""') or docstring.startswith("'''"): @@ -77,14 +85,16 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu else: content = docstring quote = '"""' - + transformed_content = self.transform_docstring(content) - new_docstring = f'{quote}{transformed_content}{quote}' - + new_docstring = f"{quote}{transformed_content}{quote}" + # Create new docstring node - new_expr = expr.with_changes(value=cst.SimpleString(new_docstring)) + new_expr = expr.with_changes( + value=cst.SimpleString(new_docstring) + ) new_first_stmt = first_stmt.with_changes(body=[new_expr]) - + # Update body with new docstring new_body = [new_first_stmt] + list(updated_node.body.body[1:]) new_node = new_node.with_changes( @@ -93,7 +103,9 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu return new_node - def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.Return: + def leave_Return( + self, original_node: cst.Return, updated_node: cst.Return + ) -> cst.Return: # Only wrap return value in await if it's a call expression # (which is likely to be an operation that needs awaiting) if updated_node.value and isinstance(updated_node.value, cst.Call): diff --git a/tests/test_operation.py b/tests/test_operation.py index ce5bd62..b57e339 100644 --- a/tests/test_operation.py +++ b/tests/test_operation.py @@ -3,6 +3,7 @@ This test suite tests the PGMQOperation class methods directly, which are transaction-friendly static methods that accept sessions. """ + import time import uuid @@ -61,6 +62,25 @@ def test_create_unlogged_queue_sync(get_session_maker, db_session): ) +@pytest.mark.asyncio +async def test_create_unlogged_queue_async(get_async_session_maker, db_session): + """Test creating an unlogged queue using PGMQOperation asynchronously.""" + queue_name = f"test_queue_{uuid.uuid4().hex}" + + async with get_async_session_maker() as session: + await PGMQOperation.create_queue_async( + queue_name, unlogged=True, session=session, commit=True + ) + + assert check_queue_exists(db_session, queue_name) is True + + # Clean up + async with get_async_session_maker() as session: + await PGMQOperation.drop_queue_async( + queue_name, partitioned=False, session=session, commit=True + ) + + def test_validate_queue_name_sync(get_session_maker): """Test queue name validation.""" queue_name = f"test_queue_{uuid.uuid4().hex}" @@ -70,12 +90,32 @@ def test_validate_queue_name_sync(get_session_maker): PGMQOperation.validate_queue_name(queue_name, session=session, commit=True) # Should raise for name that's too long (either ProgrammingError or InternalError depending on driver) - with pytest.raises((ProgrammingError, InternalError)) as e: + with pytest.raises((ProgrammingError, InternalError, Exception)) as e: PGMQOperation.validate_queue_name("a" * 49, session=session, commit=True) error_msg = str(e.value.orig) if hasattr(e.value, "orig") else str(e.value) assert "queue name is too long" in error_msg +@pytest.mark.asyncio +async def test_validate_queue_name_async(get_async_session_maker): + """Test queue name validation asynchronously.""" + queue_name = f"test_queue_{uuid.uuid4().hex}" + + async with get_async_session_maker() as session: + # Should not raise for valid name + await PGMQOperation.validate_queue_name_async( + queue_name, session=session, commit=True + ) + + # Should raise for name that's too long (either ProgrammingError or InternalError depending on driver) + with pytest.raises((ProgrammingError, InternalError, Exception)) as e: + await PGMQOperation.validate_queue_name_async( + "a" * 49, session=session, commit=True + ) + error_msg = str(e.value.orig) if hasattr(e.value, "orig") else str(e.value) + assert "queue name is too long" in error_msg + + def test_list_queues_sync(get_session_maker, db_session): """Test listing queues.""" queue_name = f"test_queue_{uuid.uuid4().hex}" @@ -99,6 +139,30 @@ def test_list_queues_sync(get_session_maker, db_session): ) +@pytest.mark.asyncio +async def test_list_queues_async(get_async_session_maker, db_session): + """Test listing queues asynchronously.""" + queue_name = f"test_queue_{uuid.uuid4().hex}" + + # Create a queue + async with get_async_session_maker() as session: + await PGMQOperation.create_queue_async( + queue_name, unlogged=False, session=session, commit=True + ) + + # List queues + async with get_async_session_maker() as session: + queues = await PGMQOperation.list_queues_async(session=session, commit=True) + + assert queue_name in queues + + # Clean up + async with get_async_session_maker() as session: + await PGMQOperation.drop_queue_async( + queue_name, partitioned=False, session=session, commit=True + ) + + def test_send_and_read_sync(get_session_maker, db_session): """Test sending and reading messages.""" queue_name = f"test_queue_{uuid.uuid4().hex}" @@ -166,6 +230,41 @@ def test_send_batch_sync(get_session_maker, db_session): ) +@pytest.mark.asyncio +async def test_send_batch_async(get_async_session_maker, db_session): + """Test sending a batch of messages asynchronously.""" + queue_name = f"test_queue_{uuid.uuid4().hex}" + messages = [{"key": f"value{i}"} for i in range(5)] + + # Create queue + async with get_async_session_maker() as session: + await PGMQOperation.create_queue_async( + queue_name, unlogged=False, session=session, commit=True + ) + + # Send batch + async with get_async_session_maker() as session: + msg_ids = await PGMQOperation.send_batch_async( + queue_name, messages, delay=0, session=session, commit=True + ) + + assert len(msg_ids) == 5 + + # Read batch + async with get_async_session_maker() as session: + msgs = await PGMQOperation.read_batch_async( + queue_name, vt=30, batch_size=5, session=session, commit=True + ) + + assert len(msgs) == 5 + + # Clean up + async with get_async_session_maker() as session: + await PGMQOperation.drop_queue_async( + queue_name, partitioned=False, session=session, commit=True + ) + + def test_pop_sync(get_session_maker, db_session): """Test popping a message from the queue.""" queue_name = f"test_queue_{uuid.uuid4().hex}" @@ -199,6 +298,40 @@ def test_pop_sync(get_session_maker, db_session): ) +@pytest.mark.asyncio +async def test_pop_async(get_async_session_maker, db_session): + """Test popping a message from the queue asynchronously.""" + queue_name = f"test_queue_{uuid.uuid4().hex}" + + # Create queue and send message + async with get_async_session_maker() as session: + await PGMQOperation.create_queue_async( + queue_name, unlogged=False, session=session, commit=True + ) + msg_id = await PGMQOperation.send_async( + queue_name, MSG, delay=0, session=session, commit=True + ) + + # Pop message + async with get_async_session_maker() as session: + msg = await PGMQOperation.pop_async(queue_name, session=session, commit=True) + + assert msg is not None + assert msg.msg_id == msg_id + + # Verify queue is empty + async with get_async_session_maker() as session: + msg2 = await PGMQOperation.pop_async(queue_name, session=session, commit=True) + + assert msg2 is None + + # Clean up + async with get_async_session_maker() as session: + await PGMQOperation.drop_queue_async( + queue_name, partitioned=False, session=session, commit=True + ) + + def test_delete_sync(get_session_maker, db_session): """Test deleting a message.""" queue_name = f"test_queue_{uuid.uuid4().hex}" @@ -225,6 +358,35 @@ def test_delete_sync(get_session_maker, db_session): ) +@pytest.mark.asyncio +async def test_delete_async(get_async_session_maker, db_session): + """Test deleting a message asynchronously.""" + queue_name = f"test_queue_{uuid.uuid4().hex}" + + # Create queue and send message + async with get_async_session_maker() as session: + await PGMQOperation.create_queue_async( + queue_name, unlogged=False, session=session, commit=True + ) + msg_id = await PGMQOperation.send_async( + queue_name, MSG, delay=0, session=session, commit=True + ) + + # Delete message + async with get_async_session_maker() as session: + deleted = await PGMQOperation.delete_async( + queue_name, msg_id, session=session, commit=True + ) + + assert deleted is True + + # Clean up + async with get_async_session_maker() as session: + await PGMQOperation.drop_queue_async( + queue_name, partitioned=False, session=session, commit=True + ) + + def test_set_vt_sync(get_session_maker, db_session): """Test setting visibility timeout.""" queue_name = f"test_queue_{uuid.uuid4().hex}" @@ -255,6 +417,37 @@ def test_set_vt_sync(get_session_maker, db_session): ) +@pytest.mark.asyncio +async def test_set_vt_async(get_async_session_maker, db_session): + """Test setting visibility timeout asynchronously.""" + queue_name = f"test_queue_{uuid.uuid4().hex}" + + # Create queue and send message + async with get_async_session_maker() as session: + await PGMQOperation.create_queue_async( + queue_name, unlogged=False, session=session, commit=True + ) + msg_id = await PGMQOperation.send_async( + queue_name, MSG, delay=0, session=session, commit=True + ) + # Read message to set initial vt + await PGMQOperation.read_async(queue_name, vt=5, session=session, commit=True) + + # Set new vt + async with get_async_session_maker() as session: + msg = await PGMQOperation.set_vt_async( + queue_name, msg_id, vt=60, session=session, commit=True + ) + + assert msg is not None + + # Clean up + async with get_async_session_maker() as session: + await PGMQOperation.drop_queue_async( + queue_name, partitioned=False, session=session, commit=True + ) + + def test_archive_sync(get_session_maker, db_session): """Test archiving a message.""" queue_name = f"test_queue_{uuid.uuid4().hex}" @@ -283,6 +476,35 @@ def test_archive_sync(get_session_maker, db_session): ) +@pytest.mark.asyncio +async def test_archive_async(get_async_session_maker, db_session): + """Test archiving a message asynchronously.""" + queue_name = f"test_queue_{uuid.uuid4().hex}" + + # Create queue and send message + async with get_async_session_maker() as session: + await PGMQOperation.create_queue_async( + queue_name, unlogged=False, session=session, commit=True + ) + msg_id = await PGMQOperation.send_async( + queue_name, MSG, delay=0, session=session, commit=True + ) + + # Archive message + async with get_async_session_maker() as session: + archived = await PGMQOperation.archive_async( + queue_name, msg_id, session=session, commit=True + ) + + assert archived is True + + # Clean up + async with get_async_session_maker() as session: + await PGMQOperation.drop_queue_async( + queue_name, partitioned=False, session=session, commit=True + ) + + def test_metrics_sync(get_session_maker, db_session): """Test getting queue metrics.""" queue_name = f"test_queue_{uuid.uuid4().hex}" @@ -358,6 +580,43 @@ def test_metrics_all_sync(get_session_maker, db_session): ) +@pytest.mark.asyncio +async def test_metrics_all_async(get_async_session_maker, db_session): + """Test getting metrics for all queues asynchronously.""" + queue_name1 = f"test_queue_{uuid.uuid4().hex}" + queue_name2 = f"test_queue_{uuid.uuid4().hex}" + + # Create two queues + async with get_async_session_maker() as session: + await PGMQOperation.create_queue_async( + queue_name1, unlogged=False, session=session, commit=True + ) + await PGMQOperation.create_queue_async( + queue_name2, unlogged=False, session=session, commit=True + ) + + # Get metrics for all queues + async with get_async_session_maker() as session: + all_metrics = await PGMQOperation.metrics_all_async( + session=session, commit=True + ) + + assert all_metrics is not None + assert len(all_metrics) >= 2 + queue_names = [m.queue_name for m in all_metrics] + assert queue_name1 in queue_names + assert queue_name2 in queue_names + + # Clean up + async with get_async_session_maker() as session: + await PGMQOperation.drop_queue_async( + queue_name1, partitioned=False, session=session, commit=True + ) + await PGMQOperation.drop_queue_async( + queue_name2, partitioned=False, session=session, commit=True + ) + + def test_transaction_rollback_sync(get_session_maker, db_session): """Test that operations can be rolled back when commit=False.""" queue_name = f"test_queue_{uuid.uuid4().hex}" @@ -373,6 +632,22 @@ def test_transaction_rollback_sync(get_session_maker, db_session): assert check_queue_exists(db_session, queue_name) is False +@pytest.mark.asyncio +async def test_transaction_rollback_async(get_async_session_maker, db_session): + """Test that operations can be rolled back when commit=False asynchronously.""" + queue_name = f"test_queue_{uuid.uuid4().hex}" + + # Create queue with commit=False, then rollback + async with get_async_session_maker() as session: + await PGMQOperation.create_queue_async( + queue_name, unlogged=False, session=session, commit=False + ) + await session.rollback() + + # Queue should not exist + assert check_queue_exists(db_session, queue_name) is False + + def test_transaction_commit_sync(get_session_maker, db_session): """Test that operations are committed when commit=True.""" queue_name = f"test_queue_{uuid.uuid4().hex}" @@ -393,6 +668,27 @@ def test_transaction_commit_sync(get_session_maker, db_session): ) +@pytest.mark.asyncio +async def test_transaction_commit_async(get_async_session_maker, db_session): + """Test that operations are committed when commit=True asynchronously.""" + queue_name = f"test_queue_{uuid.uuid4().hex}" + + # Create queue with commit=True + async with get_async_session_maker() as session: + await PGMQOperation.create_queue_async( + queue_name, unlogged=False, session=session, commit=True + ) + + # Queue should exist + assert check_queue_exists(db_session, queue_name) is True + + # Clean up + async with get_async_session_maker() as session: + await PGMQOperation.drop_queue_async( + queue_name, partitioned=False, session=session, commit=True + ) + + # Async tests @@ -779,6 +1075,51 @@ def test_create_time_based_partitioned_queue_sync(get_session_maker, db_session) ) +@pytest.mark.asyncio +async def test_create_time_based_partitioned_queue_async( + get_async_session_maker, db_session +): + """Test creating a time-based partitioned queue asynchronously.""" + queue_name = f"time_{uuid.uuid4().hex[:20]}" + + # First ensure pg_partman extension is available + try: + async with get_async_session_maker() as session: + await PGMQOperation.check_pg_partman_ext_async(session=session, commit=True) + except Exception as e: + pytest.skip(f"pg_partman extension not available: {e}") + + # Create partitioned queue with time-based partitioning + async with get_async_session_maker() as session: + await PGMQOperation.create_partitioned_queue_async( + queue_name, + partition_interval="1 day", + retention_interval="7 days", + session=session, + commit=True, + ) + + assert check_queue_exists(db_session, queue_name) is True + + # Test sending and reading from time-based partitioned queue + async with get_async_session_maker() as session: + msg_id = await PGMQOperation.send_async( + queue_name, MSG, delay=0, session=session, commit=True + ) + msg = await PGMQOperation.read_async( + queue_name, vt=30, session=session, commit=True + ) + + assert msg is not None + assert msg.msg_id == msg_id + + # Clean up + async with get_async_session_maker() as session: + await PGMQOperation.drop_queue_async( + queue_name, partitioned=True, session=session, commit=True + ) + + # Async tests for newly added coverage