|
10 | 10 | import warnings |
11 | 11 | from abc import abstractmethod |
12 | 12 | from collections import defaultdict |
13 | | -from collections.abc import Callable, Container, Iterator, Mapping |
| 13 | +from collections.abc import Callable, Container, Coroutine, Iterator, Mapping |
14 | 14 | from contextlib import contextmanager, suppress |
15 | 15 | from pathlib import Path |
16 | 16 | from typing import ( |
17 | 17 | TYPE_CHECKING, |
| 18 | + Any, |
18 | 19 | ClassVar, |
19 | 20 | Generic, |
20 | 21 | Literal, |
|
24 | 25 | ) |
25 | 26 |
|
26 | 27 | import filetype |
| 28 | +import httpc |
27 | 29 | import httpx |
28 | 30 | import pyfilename as pf |
29 | 31 | from filetype.types import IMAGE |
30 | 32 | from rich import progress |
31 | 33 | from yarl import URL |
32 | 34 |
|
33 | | -import httpc |
34 | | - |
35 | 35 | from ..base import console, logger, platforms |
36 | 36 | from ..directory_state import ( |
37 | 37 | DirectoryState, |
@@ -157,7 +157,6 @@ class Scraper(Generic[WebtoonId]): # MARK: SCRAPER |
157 | 157 | PLATFORM: ClassVar[str] |
158 | 158 | DOWNLOAD_INTERVAL: int | float = 0 |
159 | 159 | EXTRA_INFO_SCRAPER_FACTORY: type[ExtraInfoScraper] = ExtraInfoScraper |
160 | | - TASK_QUEUE_FACTORY: Callable = asyncio.Queue |
161 | 160 | information_vars: dict[str, None | str | Path | Callable] = dict( |
162 | 161 | title=None, |
163 | 162 | platform="PLATFORM", |
@@ -207,8 +206,10 @@ def __init__(self, webtoon_id: WebtoonId) -> None: |
207 | 206 | self.skip_download: list[int] = [] |
208 | 207 | """0-based index를 사용해 다운로드를 생략할 웹툰을 결정합니다.""" |
209 | 208 | self._download_status: Literal["downloading", "nothing", "canceling"] = "nothing" |
210 | | - self._triggers: defaultdict[str, list[Callable]] = defaultdict(list) |
211 | | - self._tasks: asyncio.Queue[asyncio.Future] = self.TASK_QUEUE_FACTORY() |
| 209 | + # self._triggers: defaultdict[tuple[Literal["async", "async_task"], str], list[Callable[..., Coroutine]]] | defaultdict[tuple[Literal["sync"], str], list[Callable]] = defaultdict(list) |
| 210 | + # 적어도 pyright에서는 위의 type expr가 잘 작동하지 않음. 아래의 더 generic한 버전을 사용 |
| 211 | + self._triggers: defaultdict[tuple[Literal["sync", "async", "async_task"], str], list[Callable[..., Coroutine]] | list[Callable]] = defaultdict(list) |
| 212 | + self._tasks: asyncio.Queue[asyncio.Future] = asyncio.Queue() |
212 | 213 | """_tasks에 값을 등록해 두면 스크래퍼가 종료될 때 해당 task들을 완료하거나 취소합니다.""" |
213 | 214 |
|
214 | 215 | # initialize extra info scraper |
@@ -386,6 +387,49 @@ async def fetch_all(self, reload: bool = False) -> None: |
386 | 387 | await self.fetch_webtoon_information(reload=reload) |
387 | 388 | await self.fetch_episode_information(reload=reload) |
388 | 389 |
|
| 390 | + @overload |
| 391 | + def register_async_callback(self, trigger: str, func: CallableT, *, blocking: bool = True) -> CallableT: ... |
| 392 | + |
| 393 | + @overload |
| 394 | + def register_async_callback(self, trigger: str, *, blocking: bool = True) -> Callable[[CallableT], CallableT]: ... |
| 395 | + |
| 396 | + def register_async_callback(self, trigger: str, func: Callable[..., Coroutine] | None = None, *, blocking: bool = True) -> Any: |
| 397 | + """특정 callback 트리거가 발생했을 때 실행할 비동기 콜백을 등록합니다.""" |
| 398 | + if func is None: |
| 399 | + return lambda func: self.register_async_callback(trigger, func, blocking=blocking) |
| 400 | + |
| 401 | + # blocking으로 할지 말지를 callback을 등록할 때 해야 할까, 아님 부를 때 결정해야 할까? |
| 402 | + # 실례를 한번 봐야 할 것 같은데 아직은 잘 모르겠다. |
| 403 | + # 일단 지금은 callback을 등록할 때 결정하는 것으로 한다. |
| 404 | + self._triggers[("async" if blocking else "async_task", trigger)].append(func) |
| 405 | + return func |
| 406 | + |
| 407 | + async def async_callback(self, situation: str, **context) -> list[asyncio.Task] | None: |
| 408 | + if callbacks := self._triggers.get(("async", situation)): |
| 409 | + for callback in callbacks: |
| 410 | + await callback(scraper=self, **context) |
| 411 | + |
| 412 | + if callbacks := self._triggers.get(("async_task", situation)): |
| 413 | + tasks = [] |
| 414 | + print("starting task!") |
| 415 | + for callback in callbacks: |
| 416 | + task = asyncio.create_task(callback(scraper=self, **context)) |
| 417 | + await self._tasks.put(task) |
| 418 | + print("I got task!") |
| 419 | + tasks.append(task) |
| 420 | + return tasks or None |
| 421 | + |
| 422 | + self.callback(situation, **context) |
| 423 | + |
| 424 | + def unregister_callback(self, trigger: str, func: Callable, type: Literal["sync", "async", "async_task"] | None = None) -> None: |
| 425 | + # 굳이 빈 key를 만들 필욘 없으니 get을 사용. 그냥 [] 사용해도 솔직히 상관없음. |
| 426 | + if type == "sync" or type is None: |
| 427 | + self._triggers.get(("sync", trigger), []).remove(func) |
| 428 | + if type == "async" or type is None: |
| 429 | + self._triggers.get(("async", trigger), []).remove(func) |
| 430 | + if type == "async_task" or type is None: |
| 431 | + self._triggers.get(("async_task", trigger), []).remove(func) |
| 432 | + |
389 | 433 | @overload |
390 | 434 | def register_callback(self, trigger: str, func: CallableT) -> CallableT: ... |
391 | 435 |
|
@@ -424,7 +468,7 @@ def startup_message(scraper: Scraper, finishing: bool, **context): |
424 | 468 | if func is None: |
425 | 469 | return lambda func: self.register_callback(trigger, func) |
426 | 470 |
|
427 | | - self._triggers[trigger].append(func) |
| 471 | + self._triggers[("sync", trigger)].append(func) |
428 | 472 | return func |
429 | 473 |
|
430 | 474 | def callback(self, situation: str, **context) -> None: |
@@ -496,7 +540,7 @@ def callback(self, situation: str, **context) -> None: |
496 | 540 | case the_others, _: |
497 | 541 | logger.debug(f"WebtoonScraper status: {the_others}") |
498 | 542 |
|
499 | | - if callbacks := self._triggers.get(situation): |
| 543 | + if callbacks := self._triggers.get(("sync", situation)): |
500 | 544 | for callback in callbacks: |
501 | 545 | callback(scraper=self, **context) |
502 | 546 |
|
|
0 commit comments