diff --git a/newsfragments/378.internal.rst b/newsfragments/378.internal.rst
new file mode 100644
index 000000000..a9b592dec
--- /dev/null
+++ b/newsfragments/378.internal.rst
@@ -0,0 +1 @@
+Add shared pubsub test fixtures (``GossipSubHarness``, ``gossipsub_nodes``, ``connected_gossipsub_nodes``, ``subscribed_mesh``) and reusable polling helpers (``wait_for``, ``wait_for_convergence``) to support the pubsub test suite refactor.
diff --git a/tests/core/pubsub/conftest.py b/tests/core/pubsub/conftest.py
new file mode 100644
index 000000000..1c889b349
--- /dev/null
+++ b/tests/core/pubsub/conftest.py
@@ -0,0 +1,121 @@
+"""Shared fixtures and helpers for pubsub tests."""
+
+from __future__ import annotations
+
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+import dataclasses
+from typing import Any
+
+import pytest
+import trio
+
+from libp2p.abc import IHost
+from libp2p.pubsub.gossipsub import GossipSub
+from libp2p.pubsub.pubsub import Pubsub
+from tests.utils.factories import PubsubFactory
+from tests.utils.pubsub.utils import dense_connect
+
+
+@dataclasses.dataclass(frozen=True, slots=True)
+class GossipSubHarness:
+    """Typed wrapper around a batch of GossipSub-backed pubsub instances."""
+
+    pubsubs: tuple[Pubsub, ...]
+
+    @property
+    def hosts(self) -> tuple[IHost, ...]:
+        """Return the host of every pubsub, in ``pubsubs`` order."""
+        return tuple(ps.host for ps in self.pubsubs)
+
+    @property
+    def routers(self) -> tuple[GossipSub, ...]:
+        """Return every pubsub's router, asserting each is a ``GossipSub``."""
+        result: list[GossipSub] = []
+        for ps in self.pubsubs:
+            r = ps.router
+            assert isinstance(r, GossipSub), f"Expected GossipSub, got {type(r)}"
+            result.append(r)
+        return tuple(result)
+
+    def __len__(self) -> int:
+        return len(self.pubsubs)
+
+
+@asynccontextmanager
+async def gossipsub_nodes(n: int, **kwargs: Any) -> AsyncIterator[GossipSubHarness]:
+    """
+    Create *n* GossipSub-backed pubsub nodes wrapped in a harness.
+
+    Usage::
+
+        async with gossipsub_nodes(3, heartbeat_interval=0.5) as h:
+            h.pubsubs  # tuple[Pubsub, ...]
+            h.hosts  # tuple[IHost, ...]
+            h.routers  # tuple[GossipSub, ...]
+    """
+    async with PubsubFactory.create_batch_with_gossipsub(n, **kwargs) as pubsubs:
+        yield GossipSubHarness(pubsubs=pubsubs)
+
+
+@asynccontextmanager
+async def connected_gossipsub_nodes(
+    n: int,
+    *,
+    strict: bool = False,
+    peer_wait_timeout: float = 5.0,
+    **kwargs: Any,
+) -> AsyncIterator[GossipSubHarness]:
+    """
+    Create *n* GossipSub nodes with dense connectivity.
+
+    By default this waits only until each node has observed one expected
+    neighbour (fast path). Pass ``strict=True`` to wait until every node
+    has observed every other expected peer — useful for topology-sensitive
+    tests that assert exact peer counts or full fanout behaviour.
+
+    *peer_wait_timeout* bounds the total time (in seconds) spent waiting for
+    peers; ``trio.TooSlowError`` is raised when it expires. Remaining keyword
+    arguments are forwarded to :func:`gossipsub_nodes`.
+    """
+    async with gossipsub_nodes(n, **kwargs) as harness:
+        # ``hosts`` is a computed property; build the tuple once and reuse it.
+        hosts = harness.hosts
+        await dense_connect(hosts)
+        if n > 1:
+            with trio.fail_after(peer_wait_timeout):
+                if strict:
+                    for index, pubsub in enumerate(harness.pubsubs):
+                        for other_index, other_host in enumerate(hosts):
+                            if other_index == index:
+                                continue
+                            await pubsub.wait_for_peer(other_host.get_id())
+                else:
+                    for index, pubsub in enumerate(harness.pubsubs):
+                        target_host = hosts[(index + 1) % n]
+                        await pubsub.wait_for_peer(target_host.get_id())
+        yield harness
+
+
+@asynccontextmanager
+async def subscribed_mesh(
+    topic: str, n: int, *, settle_time: float = 1.0, **kwargs: Any
+) -> AsyncIterator[GossipSubHarness]:
+    """
+    Create *n* connected GossipSub nodes all subscribed to *topic*.
+
+    Waits *settle_time* seconds for mesh formation before yielding.
+    """
+    async with connected_gossipsub_nodes(n, **kwargs) as harness:
+        for ps in harness.pubsubs:
+            await ps.subscribe(topic)
+        # TODO(#378): replace fixed sleep with predicate-based mesh-ready polling
+        await trio.sleep(settle_time)
+        yield harness
+
+
+@pytest.fixture
+async def connected_gossipsub_pair() -> AsyncIterator[GossipSubHarness]:
+    """Fixture: two connected GossipSub nodes with default config."""
+    async with connected_gossipsub_nodes(2) as harness:
+        yield harness
diff --git a/tests/core/pubsub/test_dummyaccount_demo.py b/tests/core/pubsub/test_dummyaccount_demo.py
index 0018ba80f..2136f0ab3 100644
--- a/tests/core/pubsub/test_dummyaccount_demo.py
+++ b/tests/core/pubsub/test_dummyaccount_demo.py
@@ -1,8 +1,3 @@
-from collections.abc import (
-    Callable,
-)
-import logging
-
 import pytest
 import trio
 
@@ -12,69 +7,9 @@
 from tests.utils.pubsub.dummy_account_node import (
     DummyAccountNode,
 )
-
-logger = logging.getLogger(__name__)
-
-
-async def wait_for_convergence(
-    nodes: tuple[DummyAccountNode, ...],
-    check: Callable[[DummyAccountNode], bool],
-    timeout: float = 10.0,
-    poll_interval: float = 0.02,
-    log_success: bool = False,
-    raise_last_exception_on_timeout: bool = True,
-) -> None:
-    """
-    Wait until all nodes satisfy the check condition.
-
-    Returns as soon as convergence is reached, otherwise raises TimeoutError.
-    Convergence already guarantees all nodes satisfy the check, so callers need
-    not run a second assertion pass after this returns.
-    """
-    start_time = trio.current_time()
-
-    last_exception: Exception | None = None
-    last_exception_node: int | None = None
-
-    while True:
-        failed_indices: list[int] = []
-        for i, node in enumerate(nodes):
-            try:
-                ok = check(node)
-            except Exception as exc:
-                ok = False
-                last_exception = exc
-                last_exception_node = i
-            if not ok:
-                failed_indices.append(i)
-
-        if not failed_indices:
-            elapsed = trio.current_time() - start_time
-            if log_success:
-                logger.debug("Converged in %.3fs with %d nodes", elapsed, len(nodes))
-            return
-
-        elapsed = trio.current_time() - start_time
-        if elapsed > timeout:
-            if raise_last_exception_on_timeout and last_exception is not None:
-                # Preserve the underlying assertion/exception signal (and its message)
-                # instead of hiding it behind a generic timeout.
-                node_hint = (
-                    f" (node index {last_exception_node})"
-                    if last_exception_node is not None
-                    else ""
-                )
-                raise AssertionError(
-                    f"Convergence failed{node_hint}: {last_exception}"
-                ) from last_exception
-
-            raise TimeoutError(
-                f"Convergence timeout after {elapsed:.2f}s. "
-                f"Failed nodes: {failed_indices}. "
-                f"(Hint: run with -s and pass log_success=True for timing logs)"
-            )
-
-        await trio.sleep(poll_interval)
+from tests.utils.pubsub.wait import (
+    wait_for_convergence,
+)
 
 
 async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
@@ -116,7 +51,6 @@ def _check_final(node: DummyAccountNode) -> bool:
     # Success, terminate pending tasks.
 
 
-@pytest.mark.trio
 async def test_simple_two_nodes():
     num_nodes = 2
     adj_map = {0: [1]}
@@ -130,7 +64,6 @@ def assertion_func(dummy_node):
     await perform_test(num_nodes, adj_map, action_func, assertion_func)
 
 
-@pytest.mark.trio
 async def test_simple_three_nodes_line_topography():
     num_nodes = 3
     adj_map = {0: [1], 1: [2]}
@@ -144,7 +77,6 @@ def assertion_func(dummy_node):
     await perform_test(num_nodes, adj_map, action_func, assertion_func)
 
 
-@pytest.mark.trio
 async def test_simple_three_nodes_triangle_topography():
     num_nodes = 3
     adj_map = {0: [1, 2], 1: [2]}
@@ -158,7 +90,6 @@ def assertion_func(dummy_node):
     await perform_test(num_nodes, adj_map, action_func, assertion_func)
 
 
-@pytest.mark.trio
 async def test_simple_seven_nodes_tree_topography():
     num_nodes = 7
     adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
@@ -172,7 +103,6 @@ def assertion_func(dummy_node):
     await perform_test(num_nodes, adj_map, action_func, assertion_func)
 
 
-@pytest.mark.trio
 async def test_set_then_send_from_root_seven_nodes_tree_topography():
     num_nodes = 7
     adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
@@ -197,7 +127,6 @@ def assertion_func(dummy_node):
     await perform_test(num_nodes, adj_map, action_func, assertion_func)
 
 
-@pytest.mark.trio
 async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography():
     num_nodes = 7
     adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
@@ -216,7 +145,6 @@ def assertion_func(dummy_node):
     await perform_test(num_nodes, adj_map, action_func, assertion_func)
 
 
-@pytest.mark.trio
 async def test_simple_five_nodes_ring_topography():
     num_nodes = 5
     adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
@@ -230,7 +158,6 @@ def assertion_func(dummy_node):
     await perform_test(num_nodes, adj_map, action_func, assertion_func)
 
 
-@pytest.mark.trio
 async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography():
     num_nodes = 5
     adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
@@ -252,7 +179,6 @@ def assertion_func(dummy_node):
     await perform_test(num_nodes, adj_map, action_func, assertion_func)
 
 
-@pytest.mark.trio
 @pytest.mark.slow
 async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography():
     num_nodes = 5
diff --git a/tests/utils/pubsub/wait.py b/tests/utils/pubsub/wait.py
new file mode 100644
index 000000000..028d141a7
--- /dev/null
+++ b/tests/utils/pubsub/wait.py
@@ -0,0 +1,115 @@
+"""Polling helpers for pubsub test synchronization."""
+
+from __future__ import annotations
+
+from collections.abc import Callable
+import inspect
+import logging
+from typing import TYPE_CHECKING
+
+import trio
+
+if TYPE_CHECKING:
+    from tests.utils.pubsub.dummy_account_node import DummyAccountNode
+
+logger = logging.getLogger(__name__)
+
+
+async def wait_for(
+    predicate: Callable[[], object],
+    *,
+    timeout: float = 10.0,
+    poll_interval: float = 0.02,
+    fail_msg: str = "",
+) -> None:
+    """
+    Poll until *predicate()* returns a truthy value, or raise ``TimeoutError``.
+
+    Supports sync predicates, async predicates, and callables that return
+    awaitables (e.g. ``lambda: some_async_fn()``). If the predicate raises
+    an exception it is treated as falsy; on timeout the last such exception
+    is chained to the ``TimeoutError``.
+    """
+    start = trio.current_time()
+    last_exc: Exception | None = None
+
+    while True:
+        try:
+            result = predicate()
+            if inspect.isawaitable(result):
+                result = await result
+            if result:
+                return
+        except Exception as exc:
+            last_exc = exc
+
+        elapsed = trio.current_time() - start
+        if elapsed > timeout:
+            msg = fail_msg or f"wait_for timed out after {elapsed:.2f}s"
+            err = TimeoutError(msg)
+            if last_exc is not None:
+                raise err from last_exc
+            raise err
+
+        await trio.sleep(poll_interval)
+
+
+async def wait_for_convergence(
+    nodes: tuple[DummyAccountNode, ...],
+    check: Callable[[DummyAccountNode], bool],
+    timeout: float = 10.0,
+    poll_interval: float = 0.02,
+    log_success: bool = False,
+    raise_last_exception_on_timeout: bool = True,
+) -> None:
+    """
+    Wait until all *nodes* satisfy *check*.
+
+    Returns as soon as convergence is reached, otherwise raises
+    ``TimeoutError`` (or ``AssertionError`` when
+    *raise_last_exception_on_timeout* is ``True`` and a node raised
+    during the final polling round).
+
+    Preserves the API of the original inline helper from
+    ``test_dummyaccount_demo.py``.
+    """
+    start_time = trio.current_time()
+
+    while True:
+        # Reset per round so a timeout reports the state of the final poll
+        # rather than a stale exception from a node that has since recovered.
+        last_exception: Exception | None = None
+        last_exception_node: int | None = None
+        failed_indices: list[int] = []
+        for i, node in enumerate(nodes):
+            try:
+                ok = check(node)
+            except Exception as exc:
+                ok = False
+                last_exception = exc
+                last_exception_node = i
+            if not ok:
+                failed_indices.append(i)
+
+        if not failed_indices:
+            elapsed = trio.current_time() - start_time
+            if log_success:
+                logger.debug("Converged in %.3fs with %d nodes", elapsed, len(nodes))
+            return
+
+        elapsed = trio.current_time() - start_time
+        if elapsed > timeout:
+            if raise_last_exception_on_timeout and last_exception is not None:
+                # The exception and its node index are always set together.
+                raise AssertionError(
+                    f"Convergence failed (node index {last_exception_node}): "
+                    f"{last_exception}"
+                ) from last_exception
+
+            raise TimeoutError(
+                f"Convergence timeout after {elapsed:.2f}s. "
+                f"Failed nodes: {failed_indices}. "
+                f"(Hint: run with -s and pass log_success=True for timing logs)"
+            )
+
+        await trio.sleep(poll_interval)