|
1 | 1 | import asyncio |
| 2 | +import hashlib |
| 3 | +import json |
2 | 4 | import sys |
3 | 5 | import typing as T |
4 | 6 | from collections import deque |
@@ -49,6 +51,8 @@ class _StreamState: |
49 | 51 | task_failures: list[str] = field(default_factory=list) |
50 | 52 | seen_message_ids: set[str] = field(default_factory=set) |
51 | 53 | seen_message_order: deque[str] = field(default_factory=deque) |
| 54 | + # Fallback tracking for backends that omit message ids in values events. |
| 55 | + no_id_message_fingerprints: dict[int, str] = field(default_factory=dict) |
52 | 56 | baseline_initialized: bool = False |
53 | 57 | run_values_messages: list[dict[str, T.Any]] = field(default_factory=list) |
54 | 58 | timed_out: bool = False |
@@ -267,16 +271,38 @@ def _extract_new_messages_from_values( |
267 | 271 | state: _StreamState, |
268 | 272 | ) -> list[dict[str, T.Any]]: |
269 | 273 | new_messages: list[dict[str, T.Any]] = [] |
270 | | - for msg in values_messages: |
| 274 | + no_id_indexes_seen: set[int] = set() |
| 275 | + for idx, msg in enumerate(values_messages): |
271 | 276 | if not isinstance(msg, dict): |
272 | 277 | continue |
273 | 278 | msg_id = get_message_id(msg) |
274 | | - if not msg_id or msg_id in state.seen_message_ids: |
| 279 | + if msg_id: |
| 280 | + if msg_id in state.seen_message_ids: |
| 281 | + continue |
| 282 | + self._remember_seen_message_id(state, msg_id) |
| 283 | + new_messages.append(msg) |
275 | 284 | continue |
276 | | - self._remember_seen_message_id(state, msg_id) |
| 285 | + |
| 286 | + no_id_indexes_seen.add(idx) |
| 287 | + msg_fingerprint = self._fingerprint_message(msg) |
| 288 | + if state.no_id_message_fingerprints.get(idx) == msg_fingerprint: |
| 289 | + continue |
| 290 | + state.no_id_message_fingerprints[idx] = msg_fingerprint |
277 | 291 | new_messages.append(msg) |
| 292 | + |
| 293 | + # Keep no-id index state aligned with latest values payload shape. |
| 294 | + for idx in list(state.no_id_message_fingerprints.keys()): |
| 295 | + if idx not in no_id_indexes_seen: |
| 296 | + state.no_id_message_fingerprints.pop(idx, None) |
278 | 297 | return new_messages |
279 | 298 |
|
| 299 | + def _fingerprint_message(self, message: dict[str, T.Any]) -> str: |
| 300 | + try: |
| 301 | + raw = json.dumps(message, sort_keys=True, ensure_ascii=False, default=str) |
| 302 | + except (TypeError, ValueError): |
| 303 | + raw = repr(message) |
| 304 | + return hashlib.sha1(raw.encode("utf-8", errors="ignore")).hexdigest() |
| 305 | + |
280 | 306 | def _remember_seen_message_id(self, state: _StreamState, msg_id: str) -> None: |
281 | 307 | if not msg_id or msg_id in state.seen_message_ids: |
282 | 308 | return |
@@ -422,9 +448,14 @@ def _handle_values_event( |
422 | 448 |
|
423 | 449 | if not state.baseline_initialized: |
424 | 450 | state.baseline_initialized = True |
425 | | - for msg in values_messages: |
| 451 | + for idx, msg in enumerate(values_messages): |
| 452 | + if not isinstance(msg, dict): |
| 453 | + continue |
426 | 454 | msg_id = get_message_id(msg) |
427 | | - self._remember_seen_message_id(state, msg_id) |
| 455 | + if msg_id: |
| 456 | + self._remember_seen_message_id(state, msg_id) |
| 457 | + continue |
| 458 | + state.no_id_message_fingerprints[idx] = self._fingerprint_message(msg) |
428 | 459 | return responses |
429 | 460 |
|
430 | 461 | new_messages = self._extract_new_messages_from_values( |
|
0 commit comments