Skip to content

Commit 2e56d25

Browse files
authored
Introduce DataSource protocol for extensible polling architecture (#65)
* addressed issue #1 of week 3 * fixed lint errors * addressed ai reviews * fixed lint error
1 parent ee54b2b commit 2e56d25

7 files changed

Lines changed: 389 additions & 90 deletions

File tree

src/paperscout/__main__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
register_handlers,
2525
)
2626
from .shutdown import shutdown_services
27-
from .sources import ISOProber, WG21Index
27+
from .sources import ISOProber, OpenStdSource, WG21Index
2828
from .storage import ProbeState, UserWatchlist
2929

3030
log = logging.getLogger("paperscout")
@@ -222,6 +222,9 @@ async def _async_main() -> None:
222222
user_watchlist = UserWatchlist(pool)
223223
index = WG21Index(pool, cfg=settings)
224224
prober = ISOProber(index, state, user_watchlist)
225+
sources: list = [index, prober]
226+
if settings.enable_open_std:
227+
sources.append(OpenStdSource())
225228
app = create_app()
226229
mq = MessageQueue(app)
227230
mq.start()
@@ -252,10 +255,10 @@ def _pool_status(p) -> dict:
252255
return status
253256

254257
scheduler = Scheduler(
255-
index=index,
256-
prober=prober,
258+
sources=sources,
257259
user_watchlist=user_watchlist,
258260
state=state,
261+
cfg=settings,
259262
notify_callback=_on_poll_result,
260263
ops_alert_fn=_ops_alert,
261264
)

src/paperscout/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class Settings(BaseSettings):
3535
poll_overrun_cooldown_seconds: int = Field(default=300, ge=1) # 5 min
3636
enable_bulk_wg21: bool = True
3737
enable_iso_probe: bool = True
38+
enable_open_std: bool = False
3839

3940
# -- Paper prefixes / extensions (globals used for gap/unknown numbers) --
4041
probe_prefixes: list[str] = Field(default_factory=lambda: ["D", "P"])

src/paperscout/monitor.py

Lines changed: 108 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,20 @@
88
import logging
99
import threading
1010
import time
11-
from collections.abc import Callable, Mapping
11+
from collections.abc import Callable, Mapping, Sequence
1212
from dataclasses import dataclass
1313
from datetime import datetime, timezone
1414
from types import MappingProxyType
15-
from typing import Any
15+
from typing import Any, cast
1616

1717
import httpx
1818

1919
from .concurrency import run_blocking_io
2020
from .config import Settings, settings
2121
from .errors import ConfigurationError, FailureCategory
2222
from .models import CycleResult, CycleStatus, Paper, PerUserMatches, ProbeHit
23-
from .sources import ISOProber, WG21Index
23+
from .protocols import SOURCE_ISO_PROBE, SOURCE_OPEN_STD, SOURCE_WG21_INDEX, DataSource
24+
from .sources import ISOProber, OpenStdEntry, WG21Index
2425
from .storage import ProbeState, UserWatchlist
2526

2627
log = logging.getLogger(__name__)
@@ -37,7 +38,7 @@ class DiffResult:
3738
updated_papers: list[Paper]
3839

3940

40-
def diff_snapshots(
41+
def _diff_paper_maps(
4142
previous: dict[str, Paper],
4243
current: dict[str, Paper],
4344
) -> DiffResult:
@@ -67,6 +68,14 @@ def _paper_sort_key(p: Paper) -> tuple[str, str]:
6768
return DiffResult(new_papers=new_papers, updated_papers=updated_papers)
6869

6970

71+
def diff_snapshots(
72+
previous: dict[str, Paper],
73+
current: dict[str, Paper],
74+
) -> DiffResult:
75+
"""Compare two id→paper maps; detect additions and metadata changes."""
76+
return _diff_paper_maps(previous, current)
77+
78+
7079
# ── Poll Result ──────────────────────────────────────────────────────────────
7180

7281

@@ -150,22 +159,20 @@ class Scheduler:
150159

151160
def __init__(
152161
self,
153-
index: WG21Index,
154-
prober: ISOProber,
162+
sources: Sequence[DataSource],
155163
user_watchlist: UserWatchlist,
156164
state: ProbeState,
157165
cfg: Settings | None = None,
158166
notify_callback=None,
159167
ops_alert_fn: Callable[[str], None] | None = None,
160168
):
161-
self.index = index
162-
self.prober = prober
169+
self.sources = list(sources)
163170
self.user_watchlist = user_watchlist
164171
self.state = state
165172
self.cfg = cfg or settings
166173
self.notify_callback = notify_callback
167174
self.ops_alert_fn = ops_alert_fn
168-
self._previous_papers: dict[str, Paper] = {}
175+
self._snapshots: dict[str, Any] = {}
169176
self._seeded = False
170177
self._poll_count = 0
171178
self._last_successful_poll: float | None = None
@@ -176,6 +183,83 @@ def __init__(
176183
self._health_lock = threading.Lock()
177184
self._health_snapshot: SchedulerSnapshot | None = None
178185

186+
def _source_by_id(self, source_id: str) -> DataSource | None:
187+
for source in self.sources:
188+
if source.source_id == source_id:
189+
return source
190+
return None
191+
192+
def _wg21_index(self) -> WG21Index | None:
193+
source = self._source_by_id(SOURCE_WG21_INDEX)
194+
return cast(WG21Index, source) if source is not None else None
195+
196+
def _iso_prober(self) -> ISOProber | None:
197+
source = self._source_by_id(SOURCE_ISO_PROBE)
198+
return cast(ISOProber, source) if source is not None else None
199+
200+
def _source_enabled(self, source_id: str) -> bool:
201+
if source_id == SOURCE_WG21_INDEX:
202+
return self.cfg.enable_bulk_wg21
203+
if source_id == SOURCE_ISO_PROBE:
204+
return self.cfg.enable_iso_probe
205+
if source_id == SOURCE_OPEN_STD:
206+
return self.cfg.enable_open_std
207+
return True
208+
209+
def _log_index_diff(self, diff: DiffResult) -> None:
210+
for paper in diff.new_papers:
211+
log.info(
212+
"INDEX-NEW id=%-14s author=%-20s date=%s title=%r",
213+
paper.id,
214+
paper.author or "?",
215+
paper.date or "?",
216+
(paper.title or "")[:80],
217+
)
218+
for paper in diff.updated_papers:
219+
log.debug(
220+
"INDEX-UPD id=%-14s author=%-20s date=%s",
221+
paper.id,
222+
paper.author or "?",
223+
paper.date or "?",
224+
)
225+
226+
async def _poll_sources(self, *, baseline: bool = False) -> tuple[DiffResult, list[ProbeHit]]:
227+
diff = DiffResult(new_papers=[], updated_papers=[])
228+
probe_hits: list[ProbeHit] = []
229+
230+
for source in self.sources:
231+
if not self._source_enabled(source.source_id):
232+
continue
233+
234+
current = await source.fetch()
235+
if baseline:
236+
self._snapshots[source.source_id] = current
237+
if source.source_id == SOURCE_ISO_PROBE:
238+
cycle = cast(CycleResult, current)
239+
probe_hits = self._probe_hits_from_cycle(cycle)
240+
self._record_probe_cycle_completion()
241+
continue
242+
243+
previous = self._snapshots.get(source.source_id)
244+
result = source.diff(previous, current)
245+
self._snapshots[source.source_id] = current
246+
247+
if source.source_id == SOURCE_WG21_INDEX:
248+
diff = result
249+
papers = cast(dict[str, Paper], current)
250+
log.info("INDEX-LOAD papers=%d", len(papers))
251+
self._log_index_diff(diff)
252+
elif source.source_id == SOURCE_ISO_PROBE:
253+
cycle = cast(CycleResult, current)
254+
probe_hits = self._probe_hits_from_cycle(cycle)
255+
self._record_probe_cycle_completion()
256+
elif source.source_id == SOURCE_OPEN_STD:
257+
new_entries = cast(list[OpenStdEntry], result)
258+
if new_entries:
259+
log.info("OPEN-STD new=%d", len(new_entries))
260+
261+
return diff, probe_hits
262+
179263
def _probe_hits_from_cycle(self, cycle: CycleResult) -> list[ProbeHit]:
180264
"""Extract hits and record last cycle status for health / staleness."""
181265
self._last_cycle_status = cycle.status
@@ -193,7 +277,9 @@ def _probe_hits_from_cycle(self, cycle: CycleResult) -> list[ProbeHit]:
193277

194278
def _record_probe_cycle_completion(self) -> None:
195279
"""Update probe stats after any completed cycle (including FAILED)."""
196-
self._last_probe_stats = self.prober.snapshot_stats()
280+
prober = self._iso_prober()
281+
if prober is not None:
282+
self._last_probe_stats = prober.snapshot_stats()
197283

198284
def _mark_poll_successful_if_probe_ok(self) -> None:
199285
"""Advance staleness clock only when the last probe cycle did not fail."""
@@ -245,24 +331,25 @@ async def seed(self) -> SeedResult:
245331
t0 = time.monotonic()
246332
log.info("SEED-START seeding local database from all sources")
247333

248-
if self.cfg.enable_bulk_wg21:
249-
await self.index.refresh()
250-
log.info("SEED wg21.link loaded papers=%d", len(self.index.papers))
334+
diff, hits = await self._poll_sources(baseline=True)
335+
del diff
251336

252-
self._previous_papers = dict(self.index.papers)
253-
254-
hits: list[ProbeHit] = []
337+
wg21 = self._wg21_index()
338+
paper_count = (
339+
len(wg21.papers)
340+
if wg21 is not None
341+
else len(self._snapshots.get(SOURCE_WG21_INDEX, {}))
342+
)
343+
if self.cfg.enable_bulk_wg21 and wg21 is not None:
344+
log.info("SEED wg21.link loaded papers=%d", len(wg21.papers))
255345
if self.cfg.enable_iso_probe:
256-
cycle = await self.prober.run_cycle()
257-
hits = self._probe_hits_from_cycle(cycle)
258-
self._record_probe_cycle_completion()
259346
log.info("SEED isocpp.org probe existing=%d", len(hits))
260347

261348
self._seeded = True
262349
log.info(
263350
"SEED-DONE elapsed=%.1fs papers=%d discovered=%d had_prior_state=%s",
264351
time.monotonic() - t0,
265-
len(self._previous_papers),
352+
paper_count,
266353
len(self.state.get_all_discovered()),
267354
had_prior_state,
268355
)
@@ -330,37 +417,7 @@ async def poll_once(self) -> PollResult:
330417
self._publish_health_snapshot()
331418
return result
332419

333-
previous = dict(self._previous_papers)
334-
335-
if self.cfg.enable_bulk_wg21:
336-
await self.index.refresh()
337-
log.info("INDEX-LOAD papers=%d", len(self.index.papers))
338-
339-
diff = diff_snapshots(previous, self.index.papers)
340-
self._previous_papers = dict(self.index.papers)
341-
342-
for paper in diff.new_papers:
343-
log.info(
344-
"INDEX-NEW id=%-14s author=%-20s date=%s title=%r",
345-
paper.id,
346-
paper.author or "?",
347-
paper.date or "?",
348-
(paper.title or "")[:80],
349-
)
350-
for paper in diff.updated_papers:
351-
log.debug(
352-
"INDEX-UPD id=%-14s author=%-20s date=%s",
353-
paper.id,
354-
paper.author or "?",
355-
paper.date or "?",
356-
)
357-
358-
probe_hits: list[ProbeHit] = []
359-
if self.cfg.enable_iso_probe:
360-
cycle = await self.prober.run_cycle()
361-
probe_hits = self._probe_hits_from_cycle(cycle)
362-
self._record_probe_cycle_completion()
363-
420+
diff, probe_hits = await self._poll_sources()
364421
recent_hits = [h for h in probe_hits if h.is_recent]
365422
old_hits = [h for h in probe_hits if not h.is_recent]
366423

src/paperscout/protocols.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""Structural typing contracts for pluggable data sources.
2+
3+
Known ``source_id`` values:
4+
5+
- ``"wg21_index"`` — :class:`~paperscout.sources.WG21Index`
6+
- ``"iso_probe"`` — :class:`~paperscout.sources.ISOProber`
7+
- ``"open_std"`` — :class:`~paperscout.sources.OpenStdSource`
8+
"""
9+
10+
from __future__ import annotations
11+
12+
from typing import Any, Protocol, runtime_checkable
13+
14+
# Well-known source identifiers (stable across releases).
15+
SOURCE_WG21_INDEX = "wg21_index"
16+
SOURCE_ISO_PROBE = "iso_probe"
17+
SOURCE_OPEN_STD = "open_std"
18+
19+
20+
@runtime_checkable
21+
class DataSource(Protocol):
22+
"""Contract for fetch/parse/diff data sources polled by :class:`~paperscout.monitor.Scheduler`."""
23+
24+
@property
25+
def source_id(self) -> str: ...
26+
27+
async def fetch(self) -> Any:
28+
"""Fetch the latest snapshot from this source."""
29+
...
30+
31+
def diff(self, previous: Any, current: Any) -> Any:
32+
"""Compare *previous* and *current* snapshots; return source-specific diff."""
33+
...

0 commit comments

Comments
 (0)