Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions newsfragments/378.internal.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add shared pubsub test fixtures (``GossipSubHarness``, ``gossipsub_nodes``, ``connected_gossipsub_nodes``, ``subscribed_mesh``) and reusable polling helpers (``wait_for``, ``wait_for_convergence``) to support the pubsub test suite refactor.
110 changes: 110 additions & 0 deletions tests/core/pubsub/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""Shared fixtures and helpers for pubsub tests."""

from __future__ import annotations

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
import dataclasses
from typing import Any

import pytest
import trio

from libp2p.abc import IHost
from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub
from tests.utils.factories import PubsubFactory
from tests.utils.pubsub.utils import dense_connect


@dataclasses.dataclass(frozen=True, slots=True)
class GossipSubHarness:
"""Typed wrapper around a batch of GossipSub-backed pubsub instances."""

pubsubs: tuple[Pubsub, ...]

@property
def hosts(self) -> tuple[IHost, ...]:
return tuple(ps.host for ps in self.pubsubs)

@property
def routers(self) -> tuple[GossipSub, ...]:
result: list[GossipSub] = []
for ps in self.pubsubs:
r = ps.router
assert isinstance(r, GossipSub), f"Expected GossipSub, got {type(r)}"
result.append(r)
return tuple(result)

def __len__(self) -> int:
return len(self.pubsubs)


@asynccontextmanager
async def gossipsub_nodes(n: int, **kwargs: Any) -> AsyncIterator[GossipSubHarness]:
    """
    Yield a :class:`GossipSubHarness` wrapping *n* GossipSub-backed pubsub nodes.

    Extra keyword arguments are forwarded to
    ``PubsubFactory.create_batch_with_gossipsub``.

    Usage::

        async with gossipsub_nodes(3, heartbeat_interval=0.5) as h:
            h.pubsubs  # tuple[Pubsub, ...]
            h.hosts    # tuple[IHost, ...]
            h.routers  # tuple[GossipSub, ...]
    """
    async with PubsubFactory.create_batch_with_gossipsub(n, **kwargs) as batch:
        yield GossipSubHarness(pubsubs=batch)


@asynccontextmanager
async def connected_gossipsub_nodes(
    n: int,
    *,
    strict: bool = False,
    peer_wait_timeout: float = 5.0,
    **kwargs: Any,
) -> AsyncIterator[GossipSubHarness]:
    """
    Create *n* GossipSub nodes with dense connectivity.

    By default this waits only until each node has observed one expected
    neighbour (fast path). Pass ``strict=True`` to wait until every node
    has observed every other expected peer — useful for topology-sensitive
    tests that assert exact peer counts or full fanout behaviour.

    :param n: number of nodes to create.
    :param strict: wait for full all-pairs peer visibility when ``True``.
    :param peer_wait_timeout: seconds to wait for peer discovery before
        ``trio.fail_after`` raises ``trio.TooSlowError``.
    :param kwargs: forwarded to :func:`gossipsub_nodes`.
    """
    # NOTE: ``peer_wait_timeout`` is an explicit keyword-only parameter rather
    # than a hidden ``kwargs.pop`` so it shows up in the signature/docs and a
    # typo is rejected here instead of being forwarded to the factory.
    async with gossipsub_nodes(n, **kwargs) as harness:
        await dense_connect(harness.hosts)
        if n > 1:
            with trio.fail_after(peer_wait_timeout):
                if strict:
                    # Full O(n^2) wait: every node observes every other node.
                    for index, pubsub in enumerate(harness.pubsubs):
                        for other_index, other_host in enumerate(harness.hosts):
                            if other_index == index:
                                continue
                            await pubsub.wait_for_peer(other_host.get_id())
                else:
                    # Fast path: each node waits only for its ring successor.
                    for index, pubsub in enumerate(harness.pubsubs):
                        target_host = harness.hosts[(index + 1) % n]
                        await pubsub.wait_for_peer(target_host.get_id())
        yield harness


@asynccontextmanager
async def subscribed_mesh(
    topic: str, n: int, *, settle_time: float = 1.0, **kwargs: Any
) -> AsyncIterator[GossipSubHarness]:
    """
    Yield *n* densely connected GossipSub nodes, each subscribed to *topic*.

    Sleeps *settle_time* seconds after subscribing so the mesh can form
    before the harness is handed to the caller.
    """
    async with connected_gossipsub_nodes(n, **kwargs) as harness:
        for node in harness.pubsubs:
            await node.subscribe(topic)
        # TODO(#378): replace fixed sleep with predicate-based mesh-ready polling
        await trio.sleep(settle_time)
        yield harness


@pytest.fixture
async def connected_gossipsub_pair() -> AsyncIterator[GossipSubHarness]:
    """Fixture: harness of two densely connected GossipSub nodes (defaults)."""
    async with connected_gossipsub_nodes(2) as pair:
        yield pair
80 changes: 3 additions & 77 deletions tests/core/pubsub/test_dummyaccount_demo.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
from collections.abc import (
Callable,
)
import logging

import pytest
import trio

Expand All @@ -12,69 +7,9 @@
from tests.utils.pubsub.dummy_account_node import (
DummyAccountNode,
)

logger = logging.getLogger(__name__)


async def wait_for_convergence(
nodes: tuple[DummyAccountNode, ...],
check: Callable[[DummyAccountNode], bool],
timeout: float = 10.0,
poll_interval: float = 0.02,
log_success: bool = False,
raise_last_exception_on_timeout: bool = True,
) -> None:
"""
Wait until all nodes satisfy the check condition.

Returns as soon as convergence is reached, otherwise raises TimeoutError.
Convergence already guarantees all nodes satisfy the check, so callers need
not run a second assertion pass after this returns.
"""
start_time = trio.current_time()

last_exception: Exception | None = None
last_exception_node: int | None = None

while True:
failed_indices: list[int] = []
for i, node in enumerate(nodes):
try:
ok = check(node)
except Exception as exc:
ok = False
last_exception = exc
last_exception_node = i
if not ok:
failed_indices.append(i)

if not failed_indices:
elapsed = trio.current_time() - start_time
if log_success:
logger.debug("Converged in %.3fs with %d nodes", elapsed, len(nodes))
return

elapsed = trio.current_time() - start_time
if elapsed > timeout:
if raise_last_exception_on_timeout and last_exception is not None:
# Preserve the underlying assertion/exception signal (and its message)
# instead of hiding it behind a generic timeout.
node_hint = (
f" (node index {last_exception_node})"
if last_exception_node is not None
else ""
)
raise AssertionError(
f"Convergence failed{node_hint}: {last_exception}"
) from last_exception

raise TimeoutError(
f"Convergence timeout after {elapsed:.2f}s. "
f"Failed nodes: {failed_indices}. "
f"(Hint: run with -s and pass log_success=True for timing logs)"
)

await trio.sleep(poll_interval)
from tests.utils.pubsub.wait import (
wait_for_convergence,
)


async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
Expand Down Expand Up @@ -116,7 +51,6 @@ def _check_final(node: DummyAccountNode) -> bool:
# Success, terminate pending tasks.


@pytest.mark.trio
async def test_simple_two_nodes():
num_nodes = 2
adj_map = {0: [1]}
Expand All @@ -130,7 +64,6 @@ def assertion_func(dummy_node):
await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
async def test_simple_three_nodes_line_topography():
num_nodes = 3
adj_map = {0: [1], 1: [2]}
Expand All @@ -144,7 +77,6 @@ def assertion_func(dummy_node):
await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
async def test_simple_three_nodes_triangle_topography():
num_nodes = 3
adj_map = {0: [1, 2], 1: [2]}
Expand All @@ -158,7 +90,6 @@ def assertion_func(dummy_node):
await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
async def test_simple_seven_nodes_tree_topography():
num_nodes = 7
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
Expand All @@ -172,7 +103,6 @@ def assertion_func(dummy_node):
await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
async def test_set_then_send_from_root_seven_nodes_tree_topography():
num_nodes = 7
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
Expand All @@ -197,7 +127,6 @@ def assertion_func(dummy_node):
await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography():
num_nodes = 7
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
Expand All @@ -216,7 +145,6 @@ def assertion_func(dummy_node):
await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
async def test_simple_five_nodes_ring_topography():
num_nodes = 5
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
Expand All @@ -230,7 +158,6 @@ def assertion_func(dummy_node):
await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography():
num_nodes = 5
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
Expand All @@ -252,7 +179,6 @@ def assertion_func(dummy_node):
await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
@pytest.mark.slow
async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography():
num_nodes = 5
Expand Down
116 changes: 116 additions & 0 deletions tests/utils/pubsub/wait.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""Polling helpers for pubsub test synchronization."""

from __future__ import annotations

from collections.abc import Callable
import inspect
import logging
from typing import TYPE_CHECKING

import trio

if TYPE_CHECKING:
from tests.utils.pubsub.dummy_account_node import DummyAccountNode

logger = logging.getLogger(__name__)


async def wait_for(
    predicate: Callable[[], object],
    *,
    timeout: float = 10.0,
    poll_interval: float = 0.02,
    fail_msg: str = "",
) -> None:
    """
    Repeatedly evaluate *predicate* until it is truthy or *timeout* expires.

    The predicate may be a plain callable, an async function, or a callable
    that returns an awaitable (e.g. ``lambda: some_async_fn()``). An exception
    raised by the predicate counts as a falsy result; if the wait times out,
    the most recent such exception is chained to the raised ``TimeoutError``.
    """
    started_at = trio.current_time()
    pending_exc: Exception | None = None

    while True:
        try:
            outcome = predicate()
            if inspect.isawaitable(outcome):
                outcome = await outcome
        except Exception as exc:
            pending_exc = exc
        else:
            if outcome:
                return

        elapsed = trio.current_time() - started_at
        if elapsed > timeout:
            msg = fail_msg or f"wait_for timed out after {elapsed:.2f}s"
            timeout_error = TimeoutError(msg)
            if pending_exc is None:
                raise timeout_error
            raise timeout_error from pending_exc

        await trio.sleep(poll_interval)


async def wait_for_convergence(
    nodes: tuple[DummyAccountNode, ...],
    check: Callable[[DummyAccountNode], bool],
    timeout: float = 10.0,
    poll_interval: float = 0.02,
    log_success: bool = False,
    raise_last_exception_on_timeout: bool = True,
) -> None:
    """
    Wait until all *nodes* satisfy *check*.

    Returns as soon as convergence is reached, otherwise raises
    ``TimeoutError`` (or ``AssertionError`` when
    *raise_last_exception_on_timeout* is ``True`` and a node raised).

    Preserves the API of the original inline helper from
    ``test_dummyaccount_demo.py``.
    """
    start_time = trio.current_time()

    last_exception: Exception | None = None
    last_exception_node: int | None = None

    while True:
        # One full pass over the nodes; a raising check counts as a failure.
        failed_indices: list[int] = []
        for i, node in enumerate(nodes):
            try:
                passed = check(node)
            except Exception as exc:
                last_exception, last_exception_node = exc, i
                passed = False
            if not passed:
                failed_indices.append(i)

        elapsed = trio.current_time() - start_time

        if not failed_indices:
            if log_success:
                logger.debug("Converged in %.3fs with %d nodes", elapsed, len(nodes))
            return

        if elapsed > timeout:
            if raise_last_exception_on_timeout and last_exception is not None:
                # Surface the underlying assertion/exception signal instead of
                # hiding it behind a generic timeout.
                node_hint = (
                    f" (node index {last_exception_node})"
                    if last_exception_node is not None
                    else ""
                )
                raise AssertionError(
                    f"Convergence failed{node_hint}: {last_exception}"
                ) from last_exception

            raise TimeoutError(
                f"Convergence timeout after {elapsed:.2f}s. "
                f"Failed nodes: {failed_indices}. "
                f"(Hint: run with -s and pass log_success=True for timing logs)"
            )

        await trio.sleep(poll_interval)
Loading