Skip to content

Commit c0ccfc5

Browse files
committed
Add new sync server to separate file.
1 parent fff302e commit c0ccfc5

2 files changed

Lines changed: 156 additions & 30 deletions

File tree

src/py/kaleido/__init__.py

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,30 @@
66

77
from __future__ import annotations
88

9-
import asyncio
10-
import queue
11-
from threading import Thread
12-
from typing import TYPE_CHECKING, NamedTuple
13-
149
from choreographer.cli import get_chrome, get_chrome_sync
1510

11+
from . import _sync_server
1612
from ._page_generator import PageGenerator
1713
from .kaleido import Kaleido
1814

19-
if TYPE_CHECKING:
20-
from typing import Any
15+
_global_server = _sync_server.GlobalKaleidoServer()
16+
17+
18+
def start_kaleido_sync_server(*args, **kwargs):
19+
"""
20+
Start a kaleido server which will process all sync generation requests.
21+
22+
Only one server can be started at a time.
23+
24+
This wrapper function takes the exact same arguments as kaleido.Kaleido().
25+
"""
26+
_global_server.open(*args, **kwargs)
27+
28+
29+
def stop_kaleido_sync_server():
30+
"""Stop the kaleido server. It can be restarted."""
31+
_global_server.close()
32+
2133

2234
__all__ = [
2335
"Kaleido",
@@ -26,6 +38,8 @@
2638
"calc_fig_sync",
2739
"get_chrome",
2840
"get_chrome_sync",
41+
"start_kaleido_sync_server",
42+
"stop_kaleido_sync_server",
2943
"write_fig",
3044
"write_fig_from_object",
3145
"write_fig_from_object_sync",
@@ -126,36 +140,25 @@ async def write_fig_from_object(
126140
)
127141

128142

129-
def _async_thread_run(func, args: tuple[Any, ...], kwargs: dict):
130-
q: queue.Queue[Any] = queue.Queue(maxsize=1)
131-
132-
def run(func, q, *args, **kwargs):
133-
# func is a closure
134-
try:
135-
q.put(asyncio.run(func(*args, **kwargs)))
136-
except BaseException as e: # noqa: BLE001
137-
q.put(e)
138-
139-
t = Thread(target=run, args=(func, q, *args), kwargs=kwargs)
140-
t.start()
141-
t.join()
142-
res = q.get()
143-
if isinstance(res, BaseException):
144-
raise res
145-
else:
146-
return res
147-
148-
149143
def calc_fig_sync(*args, **kwargs):
150144
"""Call `calc_fig` but blocking."""
151-
return _async_thread_run(calc_fig, args=args, kwargs=kwargs)
145+
if _global_server.is_running():
146+
return _global_server.call_function("calc_fig", *args, **kwargs)
147+
else:
148+
return _sync_server.oneshot_async_run(calc_fig, args=args, kwargs=kwargs)
152149

153150

154151
def write_fig_sync(*args, **kwargs):
155152
"""Call `write_fig` but blocking."""
156-
_async_thread_run(write_fig, args=args, kwargs=kwargs)
153+
if _global_server.is_running():
154+
_global_server.call_function("write_fig", *args, **kwargs)
155+
else:
156+
_sync_server.oneshot_async_run(write_fig, args=args, kwargs=kwargs)
157157

158158

159159
def write_fig_from_object_sync(*args, **kwargs):
160160
"""Call `write_fig_from_object` but blocking."""
161-
_async_thread_run(write_fig_from_object, args=args, kwargs=kwargs)
161+
if _global_server.is_running():
162+
_global_server.call_function("write_fig_from_object", *args, **kwargs)
163+
else:
164+
_sync_server.oneshot_async_run(write_fig_from_object, args=args, kwargs=kwargs)

src/py/kaleido/_sync_server.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import warnings
5+
from queue import Queue
6+
from threading import Thread
7+
from typing import TYPE_CHECKING, NamedTuple
8+
9+
from .kaleido import Kaleido
10+
11+
if TYPE_CHECKING:
12+
from typing import Any
13+
14+
15+
class Task(NamedTuple):
16+
fn: str
17+
args: Any
18+
kwargs: Any
19+
20+
21+
class _BadFunctionName(BaseException):
22+
"""For use when programmed poorly."""
23+
24+
25+
class GlobalKaleidoServer:
26+
_instance = None
27+
28+
async def _server(self, *args, **kwargs):
29+
async with Kaleido(*args, **kwargs) as k: # multiple processor? Enable GPU?
30+
while True:
31+
task = self._task_queue.get() # thread dies if main thread dies
32+
if task is None:
33+
self._task_queue.task_done()
34+
return
35+
if not hasattr(k, task.fn):
36+
raise _BadFunctionName(f"Kaleido has no attribute {task.fn}")
37+
try:
38+
self._return_queue.put(
39+
await getattr(k, task.fn)(*task.args, **task.kwargs),
40+
)
41+
except Exception as e: # noqa: BLE001
42+
self._return_queue.put(e)
43+
44+
self._task_queue.task_done()
45+
46+
def __new__(cls):
47+
# Create the singleton on first instantiation
48+
if cls._instance is None:
49+
cls._instance = super().__new__(cls)
50+
cls._instance._initialized = False # noqa: SLF001
51+
return cls._instance
52+
53+
def is_running(self):
54+
return self._initialized
55+
56+
def open(self, *args, **kwargs):
57+
"""Initialize the singleton with three values."""
58+
if self.is_running():
59+
warnings.warn(
60+
"Server already open.",
61+
RuntimeWarning,
62+
stacklevel=2,
63+
)
64+
return
65+
coroutine = self._server(*args, **kwargs)
66+
self._thread: Thread = Thread(target=asyncio.run, args=(coroutine,))
67+
self._task_queue: Queue[Task | None] = Queue()
68+
self._return_queue: Queue[Any] = Queue()
69+
self._thread.start()
70+
self._initialized = True
71+
72+
def close(self):
73+
"""Reset the singleton back to an uninitialized state."""
74+
if not self.is_running():
75+
warnings.warn(
76+
"Server already closed.",
77+
RuntimeWarning,
78+
stacklevel=2,
79+
)
80+
return
81+
self._task_queue.put(None)
82+
self._thread.join()
83+
del self._thread
84+
del self._task_queue
85+
del self._return_queue
86+
self._initialized = False
87+
88+
def call_function(self, cmd: str, *args, **kwargs):
89+
if not self._is_running():
90+
raise RuntimeError("Can't call function on stopped server.")
91+
if kwargs.pop("kopts"):
92+
warnings.warn(
93+
"The kopts argument is ignored if using a server.",
94+
UserWarning,
95+
stacklevel=3,
96+
)
97+
self._task_queue.put(Task(cmd, args, kwargs))
98+
self._task_queue.join()
99+
res = self._return_queue.get()
100+
if isinstance(res, BaseException):
101+
raise res
102+
else:
103+
return res
104+
105+
106+
def oneshot_async_run(func, args: tuple[Any, ...], kwargs: dict):
107+
q: Queue[Any] = Queue(maxsize=1)
108+
109+
def run(func, q, *args, **kwargs):
110+
# func is a closure
111+
try:
112+
q.put(asyncio.run(func(*args, **kwargs)))
113+
except BaseException as e: # noqa: BLE001
114+
q.put(e)
115+
116+
t = Thread(target=run, args=(func, q, *args), kwargs=kwargs)
117+
t.start()
118+
t.join()
119+
res = q.get()
120+
if isinstance(res, BaseException):
121+
raise res
122+
else:
123+
return res

0 commit comments

Comments
 (0)