Skip to content

Commit 46e55e9

Browse files
committed
Fix typing in kaleido.
1 parent b48eed4 commit 46e55e9

1 file changed

Lines changed: 63 additions & 17 deletions

File tree

src/py/kaleido/kaleido.py

Lines changed: 63 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import asyncio
66
from collections.abc import AsyncIterable, Iterable
77
from pathlib import Path
8-
from typing import TYPE_CHECKING, TypedDict, cast
8+
from typing import TYPE_CHECKING, TypedDict, cast, overload
99

1010
import choreographer as choreo
1111
import logistro
@@ -24,7 +24,16 @@
2424

2525
if TYPE_CHECKING:
2626
from types import TracebackType
27-
from typing import Any, List, Tuple, TypeVar, Union, ValuesView
27+
from typing import (
28+
Any,
29+
AsyncGenerator,
30+
List,
31+
Literal,
32+
Tuple,
33+
TypeVar,
34+
Union,
35+
ValuesView,
36+
)
2837

2938
from typing_extensions import NotRequired, Required, TypeAlias
3039

@@ -313,15 +322,52 @@ async def _render_task(
313322
await self._return_kaleido_tab(tab)
314323

315324
### API ###
325+
@overload
326+
async def write_fig_from_object(
327+
self,
328+
generator: AnyIterable[FigureDict],
329+
*,
330+
cancel_on_error: bool = False,
331+
_write: Literal[False],
332+
) -> bytes: ...
333+
334+
@overload
335+
async def write_fig_from_object(
336+
self,
337+
generator: AnyIterable[FigureDict],
338+
*,
339+
cancel_on_error: Literal[True],
340+
_write: Literal[True] = True,
341+
) -> None: ...
342+
343+
@overload
344+
async def write_fig_from_object(
345+
self,
346+
generator: AnyIterable[FigureDict],
347+
*,
348+
cancel_on_error: Literal[False] = False,
349+
_write: Literal[True] = True,
350+
) -> tuple[Exception]: ...
351+
352+
@overload
353+
async def write_fig_from_object(
354+
self,
355+
generator: AnyIterable[FigureDict],
356+
*,
357+
cancel_on_error: bool,
358+
_write: Literal[True] = True,
359+
) -> tuple[Exception] | None: ...
316360

317361
async def write_fig_from_object(
318362
self,
319363
generator: AnyIterable[FigureDict], # TODO: must take a FigureDict alone
320364
*,
321365
cancel_on_error=False,
322366
_write: bool = True, # backwards compatibility!
323-
) -> tuple[bytes | None | Exception]:
367+
) -> None | bytes | tuple[Exception]:
324368
"""Temp."""
369+
if not _write:
370+
cancel_on_error = True
325371
if main_task := asyncio.current_task():
326372
self._main_render_coroutines.add(main_task)
327373
tasks: set[asyncio.Task] = set()
@@ -349,15 +395,20 @@ async def write_fig_from_object(
349395
)
350396
tasks.add(t)
351397

352-
return await asyncio.gather(*tasks, return_exceptions=not cancel_on_error)
398+
res = await asyncio.gather(*tasks, return_exceptions=not cancel_on_error)
399+
if not _write:
400+
return cast("bytes", res[0])
401+
elif cancel_on_error:
402+
return None
403+
else:
404+
return cast("tuple[Exception]", tuple(r for r in res if r))
353405

354406
finally:
355407
for task in tasks:
356408
if not task.done():
357409
task.cancel()
358410
if main_task:
359411
self._main_render_coroutines.remove(main_task)
360-
# return errors?
361412

362413
async def write_fig(
363414
self,
@@ -366,16 +417,16 @@ async def write_fig(
366417
opts: None | _fig_tools.LayoutOpts = None,
367418
*,
368419
topojson: str | None = None,
369-
cancel_on_error=False,
370-
) -> tuple[None | Exception]: # TODO this should be filtered
420+
cancel_on_error: bool = False,
421+
) -> tuple[Exception] | None:
371422
"""Temp."""
372423
if _fig_tools.is_figurish(fig) or not isinstance(
373424
fig,
374425
(Iterable, AsyncIterable),
375426
):
376427
fig = [fig]
377428

378-
async def _temp_generator():
429+
async def _temp_generator() -> AsyncGenerator[FigureDict, None]:
379430
async for f in _utils.ensure_async_iter(fig):
380431
yield {
381432
"fig": f,
@@ -384,13 +435,11 @@ async def _temp_generator():
384435
"topojson": topojson,
385436
}
386437

387-
res = await self.write_fig_from_object(
388-
generator=_temp_generator(),
438+
generator = cast("AsyncIterable[FigureDict]", _temp_generator())
439+
return await self.write_fig_from_object(
440+
generator=generator,
389441
cancel_on_error=cancel_on_error,
390442
)
391-
return cast("tuple[Exception]", tuple(r for r in res if r is not None))
392-
# we're using cast, but @overload would be better
393-
# because with _write = True, return is a tuple[Exception | None]
394443

395444
async def calc_fig(
396445
self,
@@ -408,11 +457,8 @@ async def _temp_generator():
408457
"topojson": topojson,
409458
}
410459

411-
res = await self.write_fig_from_object(
460+
return await self.write_fig_from_object(
412461
generator=_temp_generator(),
413462
cancel_on_error=True,
414463
_write=False,
415464
)
416-
return cast("bytes", res[0])
417-
# Complex type mechanics. Exceptions will raise. None not possible.
418-
# Bytes only option

0 commit comments

Comments
 (0)