-
Notifications
You must be signed in to change notification settings - Fork 58
Expand file tree
/
Copy path_sync_server.py
More file actions
133 lines (113 loc) · 3.97 KB
/
_sync_server.py
File metadata and controls
133 lines (113 loc) · 3.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from __future__ import annotations

import asyncio
import atexit
import warnings
from functools import partial
from queue import Empty, Queue
from threading import Thread
from typing import TYPE_CHECKING, NamedTuple

from .kaleido import Kaleido

if TYPE_CHECKING:
    from typing import Any
class Task(NamedTuple):
    """One request for the server thread: await ``Kaleido.<fn>(*args, **kwargs)``."""

    fn: str  # name of the Kaleido attribute to look up and await
    args: tuple[Any, ...]  # positional arguments, as collected by ``*args``
    kwargs: dict[str, Any]  # keyword arguments, as collected by ``**kwargs``
class _BadFunctionName(BaseException):
    """Signals a programming error: a task named a method that does not exist.

    It derives from ``BaseException`` rather than ``Exception``, presumably so that
    ordinary ``except Exception`` handlers do not quietly absorb a coding mistake.
    """
class GlobalKaleidoServer:
    """Process-wide singleton running one ``Kaleido`` instance on a background thread.

    :meth:`open` starts a daemon thread that runs an asyncio event loop owning a
    ``Kaleido`` instance. :meth:`call_function` forwards a synchronous call to it
    through a task queue and waits on a return queue. :meth:`close` shuts the
    thread down.
    """

    _instance = None

    async def _server(self, *args, **kwargs):
        """Own a ``Kaleido`` instance and service tasks until the ``None`` sentinel.

        Runs on the thread started by :meth:`open`; ``args``/``kwargs`` go to the
        ``Kaleido`` constructor. Every non-sentinel task produces exactly one item
        on ``_return_queue``: the awaited result, or the exception it raised.
        """
        async with Kaleido(*args, **kwargs) as k:  # multiple processor? Enable GPU?
            while True:
                # Blocking get; the daemon thread dies if the main thread dies.
                # NOTE(review): this also stalls the event loop while idle.
                task = self._task_queue.get()
                if task is None:
                    self._task_queue.task_done()
                    return
                try:
                    if hasattr(k, task.fn):
                        result = await getattr(k, task.fn)(*task.args, **task.kwargs)
                    else:
                        # Hand the error to the caller instead of raising here:
                        # raising would kill this thread and leave the caller
                        # waiting forever for a reply.
                        result = _BadFunctionName(
                            f"Kaleido has no attribute {task.fn}",
                        )
                except Exception as e:  # noqa: BLE001
                    result = e
                self._return_queue.put(result)
                self._task_queue.task_done()

    def __new__(cls):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False  # noqa: SLF001
            # Tracks whether open() has already scheduled the exit-time close().
            cls._instance._atexit_registered = False  # noqa: SLF001
        return cls._instance

    def is_running(self):
        """Return True between a successful :meth:`open` and the next :meth:`close`.

        This reflects the open/closed flag only; it does not check whether the
        server thread is still alive.
        """
        return self._initialized

    def open(self, *args: Any, silence_warnings=False, **kwargs: Any) -> None:
        """Start the background server thread.

        ``args`` and ``kwargs`` are forwarded to the ``Kaleido`` constructor on the
        server thread. If the server is already open, warn (unless
        ``silence_warnings``) and return without doing anything.

        This returns as soon as the thread is started; if ``Kaleido`` fails to
        start, the failure surfaces as a ``RuntimeError`` from the next
        :meth:`call_function`.
        """
        if self.is_running():
            if not silence_warnings:
                warnings.warn(
                    "Server already open.",
                    RuntimeWarning,
                    stacklevel=2,
                )
            return
        coroutine = self._server(*args, **kwargs)
        # Daemon thread, so it never keeps the interpreter alive on its own.
        self._thread: Thread = Thread(
            target=asyncio.run,
            args=(coroutine,),
            daemon=True,
        )
        # The queues must exist before the thread starts: _server reads them.
        self._task_queue: Queue[Task | None] = Queue()
        self._return_queue: Queue[Any] = Queue()
        self._thread.start()
        self._initialized = True
        # Register the exit-time shutdown once, not again on every open().
        if not self._atexit_registered:
            atexit.register(partial(self.close, silence_warnings=True))
            self._atexit_registered = True

    def close(self, *, silence_warnings=False):
        """Stop the server thread and return the singleton to its unopened state.

        Sends the ``None`` shutdown sentinel and blocks until the thread exits.
        If the server is not open, warn (unless ``silence_warnings``) and return.
        """
        if not self.is_running():
            if not silence_warnings:
                warnings.warn(
                    "Server already closed.",
                    RuntimeWarning,
                    stacklevel=2,
                )
            return
        self._task_queue.put(None)
        self._thread.join()
        del self._thread
        del self._task_queue
        del self._return_queue
        self._initialized = False

    def call_function(self, cmd: str, *args, **kwargs):
        """Await ``Kaleido.<cmd>(*args, **kwargs)`` on the server thread; return it.

        A ``kopts`` keyword is dropped (with a warning if it is truthy) because the
        server's ``Kaleido`` instance was already configured by :meth:`open`.

        Raises:
            RuntimeError: if the server is not open, or its thread has stopped.
            BaseException: whatever the Kaleido call raised (or
                ``_BadFunctionName`` for an unknown ``cmd``) is re-raised here.

        NOTE(review): replies are matched to callers only by queue order, so this
        is not safe for concurrent callers from several threads.
        """
        if not self.is_running():
            raise RuntimeError("Can't call function on stopped server.")
        if kwargs.pop("kopts", None):
            warnings.warn(
                "The kopts argument is ignored if using a server.",
                UserWarning,
                stacklevel=3,
            )
        self._task_queue.put(Task(cmd, args, kwargs))
        res = self._wait_for_result()
        if isinstance(res, BaseException):
            raise res
        return res

    def _wait_for_result(self, poll_interval: float = 0.1):
        """Block until the server posts a reply; fail if the server thread dies.

        Polling (rather than an unbounded wait) is what lets a dead server thread
        turn into an error instead of a permanent hang.
        """
        while True:
            try:
                return self._return_queue.get(timeout=poll_interval)
            except Empty:
                pass
            if not self._thread.is_alive():
                break
        # The thread may have posted a reply just before exiting; take it if so.
        try:
            return self._return_queue.get_nowait()
        except Empty:
            raise RuntimeError("Kaleido server thread is not running.") from None
def oneshot_async_run(func, args: tuple[Any, ...], kwargs: dict):
    """Run ``func(*args, **kwargs)`` on a fresh event loop in a new thread and wait.

    ``func(*args, **kwargs)`` must produce something :func:`asyncio.run` accepts
    (a coroutine). Presumably a separate thread is used so this also works when
    the calling thread already has a running event loop — TODO confirm.

    Returns:
        Whatever the coroutine returns, including exception instances, which are
        returned rather than raised.

    Raises:
        Any exception (``BaseException`` subclasses included) raised while running
        the coroutine is re-raised in the calling thread.
    """
    outcome: Queue[tuple[bool, Any]] = Queue(maxsize=1)

    def run() -> None:
        # Tag each outcome so a *returned* exception object is not mistaken for a
        # *raised* one.
        try:
            outcome.put((True, asyncio.run(func(*args, **kwargs))))
        except BaseException as e:  # noqa: BLE001
            outcome.put((False, e))

    # The worker closes over args/kwargs instead of receiving them through
    # Thread(kwargs=...): caller keywords such as "func" or "q" would otherwise
    # collide with the worker's own parameters and hang the final get().
    thread = Thread(target=run)
    thread.start()
    thread.join()
    succeeded, value = outcome.get()
    if not succeeded:
        raise value
    return value