55import asyncio
66from collections .abc import AsyncIterable , Iterable
77from pathlib import Path
8- from typing import TYPE_CHECKING , TypedDict , cast
8+ from typing import TYPE_CHECKING , TypedDict , cast , overload
99
1010import choreographer as choreo
1111import logistro
2424
2525if TYPE_CHECKING :
2626 from types import TracebackType
27- from typing import Any , List , Tuple , TypeVar , Union , ValuesView
27+ from typing import (
28+ Any ,
29+ AsyncGenerator ,
30+ List ,
31+ Literal ,
32+ Tuple ,
33+ TypeVar ,
34+ Union ,
35+ ValuesView ,
36+ )
2837
2938 from typing_extensions import NotRequired , Required , TypeAlias
3039
@@ -313,15 +322,52 @@ async def _render_task(
313322 await self ._return_kaleido_tab (tab )
314323
315324 ### API ###
325+ @overload
326+ async def write_fig_from_object (
327+ self ,
328+ generator : AnyIterable [FigureDict ],
329+ * ,
330+ cancel_on_error : bool = False ,
331+ _write : Literal [False ],
332+ ) -> bytes : ...
333+
334+ @overload
335+ async def write_fig_from_object (
336+ self ,
337+ generator : AnyIterable [FigureDict ],
338+ * ,
339+ cancel_on_error : Literal [True ],
340+ _write : Literal [True ] = True ,
341+ ) -> None : ...
342+
343+ @overload
344+ async def write_fig_from_object (
345+ self ,
346+ generator : AnyIterable [FigureDict ],
347+ * ,
348+ cancel_on_error : Literal [False ] = False ,
349+ _write : Literal [True ] = True ,
350+ ) -> tuple [Exception ]: ...
351+
352+ @overload
353+ async def write_fig_from_object (
354+ self ,
355+ generator : AnyIterable [FigureDict ],
356+ * ,
357+ cancel_on_error : bool ,
358+ _write : Literal [True ] = True ,
359+ ) -> tuple [Exception ] | None : ...
316360
317361 async def write_fig_from_object (
318362 self ,
319363 generator : AnyIterable [FigureDict ], # TODO: must take a FigureDict alone
320364 * ,
321365 cancel_on_error = False ,
322366 _write : bool = True , # backwards compatibility!
323- ) -> tuple [ bytes | None | Exception ]:
367+ ) -> None | bytes | tuple [ Exception ]:
324368 """Temp."""
369+ if not _write :
370+ cancel_on_error = True
325371 if main_task := asyncio .current_task ():
326372 self ._main_render_coroutines .add (main_task )
327373 tasks : set [asyncio .Task ] = set ()
@@ -349,15 +395,20 @@ async def write_fig_from_object(
349395 )
350396 tasks .add (t )
351397
352- return await asyncio .gather (* tasks , return_exceptions = not cancel_on_error )
398+ res = await asyncio .gather (* tasks , return_exceptions = not cancel_on_error )
399+ if not _write :
400+ return cast ("bytes" , res [0 ])
401+ elif cancel_on_error :
402+ return None
403+ else :
404+ return cast ("tuple[Exception]" , tuple (r for r in res if r ))
353405
354406 finally :
355407 for task in tasks :
356408 if not task .done ():
357409 task .cancel ()
358410 if main_task :
359411 self ._main_render_coroutines .remove (main_task )
360- # return errors?
361412
362413 async def write_fig (
363414 self ,
@@ -366,16 +417,16 @@ async def write_fig(
366417 opts : None | _fig_tools .LayoutOpts = None ,
367418 * ,
368419 topojson : str | None = None ,
369- cancel_on_error = False ,
370- ) -> tuple [None | Exception ]: # TODO this should be filtered
420+ cancel_on_error : bool = False ,
421+ ) -> tuple [Exception ] | None :
371422 """Temp."""
372423 if _fig_tools .is_figurish (fig ) or not isinstance (
373424 fig ,
374425 (Iterable , AsyncIterable ),
375426 ):
376427 fig = [fig ]
377428
378- async def _temp_generator ():
429+ async def _temp_generator () -> AsyncGenerator [ FigureDict , None ] :
379430 async for f in _utils .ensure_async_iter (fig ):
380431 yield {
381432 "fig" : f ,
@@ -384,13 +435,11 @@ async def _temp_generator():
384435 "topojson" : topojson ,
385436 }
386437
387- res = await self .write_fig_from_object (
388- generator = _temp_generator (),
438+ generator = cast ("AsyncIterable[FigureDict]" , _temp_generator ())
439+ return await self .write_fig_from_object (
440+ generator = generator ,
389441 cancel_on_error = cancel_on_error ,
390442 )
391- return cast ("tuple[Exception]" , tuple (r for r in res if r is not None ))
392- # we're using cast, but @overload would be better
393- # because with _write = True, return is a tuple[Exception | None]
394443
395444 async def calc_fig (
396445 self ,
@@ -408,11 +457,8 @@ async def _temp_generator():
408457 "topojson" : topojson ,
409458 }
410459
411- res = await self .write_fig_from_object (
460+ return await self .write_fig_from_object (
412461 generator = _temp_generator (),
413462 cancel_on_error = True ,
414463 _write = False ,
415464 )
416- return cast ("bytes" , res [0 ])
417- # Complex type mechanics. Exceptions will raise. None not possible.
418- # Bytes only option
0 commit comments