-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconftest.py
More file actions
135 lines (108 loc) · 4.03 KB
/
conftest.py
File metadata and controls
135 lines (108 loc) · 4.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""Configuration for the test suite."""
from abc import ABC, abstractmethod
from asyncio.events import BaseDefaultEventLoopPolicy
import multiprocessing
import os
import typing as _t
from unittest.mock import patch
import pytest
import pytest_asyncio
import pytest_cases
import ray
from that_depends import ContextScopes, container_context
import uvloop
from plugboard.component import Component, IOController as IO
from plugboard.component.io_controller import IOStreamClosedError
from plugboard.connector import ZMQConnector
from plugboard.utils.di import DI
from plugboard.utils.settings import Settings
@pytest.fixture(scope="session")
def event_loop_policy() -> BaseDefaultEventLoopPolicy:
    """Provide uvloop's event loop policy for the whole test session.

    This fixture is named ``event_loop_policy`` so that async tests in the
    session run on uvloop instead of the default asyncio loop.
    """
    policy = uvloop.EventLoopPolicy()
    return policy
@pytest.fixture(scope="session", autouse=True)
def mp_set_start_method() -> None:
    """Switch multiprocessing to the 'spawn' start method for every test.

    Applied automatically once per session. Passing ``force=True`` replaces any
    start method that has already been configured in this process.
    """
    try:
        multiprocessing.set_start_method("spawn", force=True)
    except RuntimeError:
        # Best-effort guard: with force=True the "start method already set"
        # error is not raised, but any RuntimeError here is deliberately ignored
        # so this fixture can never fail the session.
        return None
@pytest.fixture(scope="session")
def ray_ctx() -> _t.Iterator[None]:
    """Start Ray for the test session and shut it down at session teardown.

    Ray is started with 4 CPUs, no GPUs and the dashboard enabled.
    """
    ray_options = {"num_cpus": 4, "num_gpus": 0, "include_dashboard": True}
    ray.init(**ray_options)
    yield
    ray.shutdown()
@pytest.fixture(scope="function")
def job_id_ctx() -> _t.Iterator[str]:
    """Yield the job ID resolved inside an app-scoped DI container context.

    The context is opened with ``job_id`` set to ``None`` in its global context
    and remains open while the test runs; it is exited on teardown.
    """
    initial_context = {"job_id": None}
    with container_context(DI, global_context=initial_context, scope=ContextScopes.APP):
        yield DI.job_id.resolve_sync()
@pytest_asyncio.fixture(scope="function", autouse=True)
async def DI_teardown() -> _t.AsyncGenerator[None, None]:
    """Release DI container resources once each test has finished.

    Applied automatically to every test. The ``finally`` clause ensures
    ``DI.tear_down()`` is awaited even if the generator is closed early.
    """
    try:
        yield None
    finally:
        await DI.tear_down()
@pytest_cases.fixture
@pytest_cases.parametrize(zmq_pubsub_proxy=[False, True])
def zmq_connector_cls(zmq_pubsub_proxy: bool) -> _t.Iterator[_t.Type[ZMQConnector]]:
    """Returns the ZMQConnector class with the specified proxy setting.

    Parametrised over ``zmq_pubsub_proxy`` (``False`` and ``True``). For each
    value, the env var ``PLUGBOARD_FLAGS_ZMQ_PUBSUB_PROXY`` is patched, a fresh
    ``Settings`` instance is built while the patch is active, and that instance
    is installed as an override on ``DI.settings``. Both the env patch and the
    DI override stay active for the duration of the test and are undone on
    teardown.

    Args:
        zmq_pubsub_proxy: Value written (as ``str``) to the proxy flag env var.

    Yields:
        The ``ZMQConnector`` class.
    """
    with patch.dict(
        os.environ,
        {"PLUGBOARD_FLAGS_ZMQ_PUBSUB_PROXY": str(zmq_pubsub_proxy)},
    ):
        # Built inside the patch so the new Settings sees the patched env var.
        testing_settings = Settings()
        DI.settings.override_sync(testing_settings)
        try:
            yield ZMQConnector
        finally:
            # Reset even if the generator is closed at the yield (GeneratorExit),
            # so the override cannot leak into subsequent tests.
            DI.settings.reset_override_sync()
class ComponentTestHelper(Component, ABC):
    """Abstract ``Component`` that records its lifecycle progress for tests.

    Tracks whether ``init`` has run, whether ``run`` has left its step loop,
    and how many times the base ``step`` has executed. Subclasses implement
    ``step`` and should call ``super().step()`` so the counter advances.
    """

    io = IO(inputs=[], outputs=[])
    exports = ["_is_initialised", "_is_finished", "_step_count"]

    @property
    def is_initialised(self) -> bool:
        """Whether ``init`` has been called."""
        return self._is_initialised

    @property
    def is_finished(self) -> bool:
        """Whether ``run`` has exited its step loop normally."""
        return self._is_finished

    @property
    def step_count(self) -> int:
        """Number of times the base ``step`` implementation has run."""
        return self._step_count

    def __init__(self, *args: _t.Any, max_steps: int = 0, **kwargs: _t.Any) -> None:
        """Create the helper.

        Args:
            max_steps: Stop ``run`` once this many steps have completed. A value
                of ``0`` (or less) means no limit, so ``run`` only stops when
                ``step`` raises ``IOStreamClosedError``.
            *args: Forwarded to ``Component.__init__``.
            **kwargs: Forwarded to ``Component.__init__``.
        """
        super().__init__(*args, **kwargs)
        self._max_steps = max_steps
        self._step_count = 0
        self._is_initialised = self._is_finished = False

    async def init(self) -> None:
        """Flag the component as initialised, then delegate to ``Component.init``."""
        self._is_initialised = True
        await super().init()

    @abstractmethod
    async def step(self) -> None:
        """Advance the step counter; subclasses extend this via ``super().step()``."""
        self._step_count = self._step_count + 1

    async def run(self) -> None:
        """Call ``step`` repeatedly until an input stream closes or the limit is hit.

        The loop ends when ``step`` raises ``IOStreamClosedError`` or, for a
        positive ``max_steps``, once ``step_count`` reaches that limit. After
        the loop exits, ``is_finished`` is set to ``True``.
        """
        while True:
            try:
                await self.step()
            except IOStreamClosedError:
                break
            # Re-read each iteration so a subclass may adjust the limit mid-run.
            limit = self._max_steps
            if limit > 0 and self._step_count >= limit:
                break
        self._is_finished = True

    def dict(self) -> dict:
        """Return ``Component.dict()`` extended with this helper's progress fields."""
        state = super().dict()
        progress = {
            "is_initialised": self._is_initialised,
            "is_finished": self._is_finished,
            "step_count": self._step_count,
        }
        state.update(progress)
        return state