Skip to content
This repository was archived by the owner on Jun 10, 2025. It is now read-only.

Commit 1d91138

Browse files
author
Mathieu Gascon-Lefebvre
committed
Batching: Keep ids and tags in BatchingTopic and BatchingInventory.
1 parent d3d0013 commit 1d91138

4 files changed

Lines changed: 71 additions & 2 deletions

File tree

src/saturn_engine/worker/inventories/batching.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,8 @@ async def iterate(self, after: Optional[str] = None) -> AsyncIterator[Item]:
4141
return
4242

4343
after = batch[-1].id
44-
yield Item(id=after, args={"batch": [item.args for item in batch]})
44+
yield Item(
45+
id=after,
46+
args={"batch": [item.args for item in batch]},
47+
tags={"batched_ids": ", ".join([item.id for item in batch])},
48+
)

src/saturn_engine/worker/topics/batching.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from typing import AsyncContextManager
22
from typing import AsyncIterator
3+
from typing import DefaultDict
34

45
import asyncio
56
import contextlib
67
import dataclasses
8+
from collections import defaultdict
79
from collections.abc import AsyncGenerator
810
from contextlib import asynccontextmanager
911
from datetime import datetime
@@ -109,6 +111,8 @@ async def message_context(
109111
) -> AsyncIterator[TopicMessage]:
110112
context = contextlib.AsyncExitStack()
111113
message_args: list[dict] = []
114+
ids: list[str] = []
115+
tag_lists: DefaultDict[str, list[str]] = defaultdict(list)
112116

113117
for message_context in batch:
114118
message: TopicMessage
@@ -118,5 +122,16 @@ async def message_context(
118122
message = message_context
119123
message_args.append(message.args)
120124

125+
ids.append(message.id)
126+
for tag, value in message.tags.items():
127+
tag_lists[tag].append(value)
128+
129+
tags = {"batched_ids": ", ".join(ids)} | {
130+
tag: ", ".join(values) for tag, values in tag_lists.items()
131+
}
132+
121133
async with context:
122-
yield TopicMessage(args={"batch": message_args})
134+
yield TopicMessage(
135+
args={"batch": message_args},
136+
tags=tags,
137+
)

tests/worker/inventories/test_batching_inventory.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,21 @@ async def test_batching_inventory() -> None:
3232
},
3333
),
3434
]
35+
assert [i.tags for i in items] == [
36+
{"batched_ids": "0, 1, 2"},
37+
{"batched_ids": "3, 4, 5"},
38+
{"batched_ids": "6, 7, 8"},
39+
{"batched_ids": "9"},
40+
]
3541

3642
items = await alib.list(inventory.iterate(after="4"))
3743

3844
assert [(i.id, i.args) for i in items] == [
3945
("7", {"batch": [{"a": "5"}, {"a": "6"}, {"a": "7"}]}),
4046
("9", {"batch": [{"a": "8"}, {"a": "9"}]}),
4147
]
48+
49+
assert [i.tags for i in items] == [
50+
{"batched_ids": "5, 6, 7"},
51+
{"batched_ids": "8, 9"},
52+
]

tests/worker/topics/test_batching_topic.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,42 @@ async def test_batching_topic_context_manager(
149149

150150
assert batch_number == 2
151151
await topic.close()
152+
153+
154+
@pytest.mark.asyncio
155+
async def test_batching_topic_tags() -> None:
156+
BATCH_SIZE = 5
157+
158+
topic = BatchingTopic(
159+
options=BatchingTopic.Options(
160+
topic=TopicItem(
161+
name="static-topic-with-tags",
162+
type="StaticTopic",
163+
options={
164+
"messages": [
165+
{"id": "1", "args": {}, "tags": {"hello": "1"}},
166+
{"id": "2", "args": {}, "tags": {"hello": "2", "hi": "a"}},
167+
{"id": "3", "args": {}, "tags": {"hello": "3"}},
168+
{"id": "4", "args": {}, "tags": {"hello": "4"}},
169+
{"id": "5", "args": {}, "tags": {"hello": "5", "hi": "b"}},
170+
],
171+
},
172+
),
173+
batch_size=BATCH_SIZE,
174+
),
175+
services=ServicesNamespace(strict=False),
176+
)
177+
178+
async with alib.scoped_iter(topic.run()) as scoped_topic_iter:
179+
context = await scoped_topic_iter.__anext__()
180+
assert isinstance(context, AsyncContextManager)
181+
async with context as message:
182+
...
183+
184+
await topic.close()
185+
186+
assert message.tags == {
187+
"batched_ids": "1, 2, 3, 4, 5",
188+
"hello": "1, 2, 3, 4, 5",
189+
"hi": "a, b",
190+
}

0 commit comments

Comments
 (0)