Skip to content

Commit d69e4f0

Browse files
authored
fix: incorrect tracking of limits after client-side filtering (#34)
1 parent 52b1d37 commit d69e4f0

4 files changed

Lines changed: 24 additions & 24 deletions

File tree

src/s2_sdk/_mappers.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -340,21 +340,8 @@ def append_ack_from_proto(ack: pb.AppendAck) -> AppendAck:
340340
)
341341

342342

343-
def read_batch_from_proto(
344-
batch: pb.ReadBatch, ignore_command_records: bool = False
345-
) -> ReadBatch:
346-
records = []
347-
for sr in batch.records:
348-
if ignore_command_records and _is_command_record(sr):
349-
continue
350-
records.append(
351-
SequencedRecord(
352-
seq_num=sr.seq_num,
353-
body=sr.body,
354-
headers=[(h.name, h.value) for h in sr.headers],
355-
timestamp=sr.timestamp,
356-
)
357-
)
343+
def read_batch_from_proto(batch: pb.ReadBatch) -> ReadBatch:
344+
records = [sequenced_record_from_proto(sr) for sr in batch.records]
358345
tail = None
359346
if batch.HasField("tail"):
360347
tail = StreamPosition(
@@ -373,12 +360,6 @@ def sequenced_record_from_proto(sr: pb.SequencedRecord) -> SequencedRecord:
373360
)
374361

375362

376-
def _is_command_record(sr: pb.SequencedRecord) -> bool:
377-
if len(sr.headers) == 1 and sr.headers[0].name == b"":
378-
return True
379-
return False
380-
381-
382363
def read_start_params(start: _ReadStart) -> dict[str, Any]:
383364
if isinstance(start, SeqNum):
384365
return {"seq_num": start.value}

src/s2_sdk/_ops.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -976,7 +976,13 @@ async def read(
976976

977977
proto_batch = pb.ReadBatch()
978978
proto_batch.ParseFromString(response.content)
979-
return read_batch_from_proto(proto_batch, ignore_command_records)
979+
batch = read_batch_from_proto(proto_batch)
980+
if ignore_command_records:
981+
batch = types.ReadBatch(
982+
records=[r for r in batch.records if not r.is_command_record()],
983+
tail=batch.tail,
984+
)
985+
return batch
980986

981987
@fallible
982988
async def read_session(

src/s2_sdk/_s2s/_read_session.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ async def run_read_session(
8181

8282
proto_batch = pb.ReadBatch()
8383
proto_batch.ParseFromString(message_body)
84-
batch = read_batch_from_proto(proto_batch, ignore_command_records)
84+
batch = read_batch_from_proto(proto_batch)
8585

8686
if batch.tail is not None:
8787
last_tail_at = time.monotonic()
@@ -103,7 +103,16 @@ async def run_read_session(
103103
)
104104
params["bytes"] = remaining_bytes
105105

106-
yield batch
106+
if ignore_command_records:
107+
batch = ReadBatch(
108+
records=[
109+
r for r in batch.records if not r.is_command_record()
110+
],
111+
tail=batch.tail,
112+
)
113+
114+
if batch.records:
115+
yield batch
107116

108117
return
109118
except Exception as e:

src/s2_sdk/_types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,10 @@ class SequencedRecord:
242242
timestamp: int
243243
"""Timestamp for this record."""
244244

245+
def is_command_record(self) -> bool:
246+
"""Check if this is a command record."""
247+
return len(self.headers) == 1 and self.headers[0][0] == b""
248+
245249

246250
@dataclass(slots=True)
247251
class ReadBatch:

0 commit comments

Comments
 (0)