Skip to content

Commit 57d0923

Browse files
authored
StreamEdge: move track resolution logic to TrackResolver class (#538)
1 parent 0c684e5 commit 57d0923

3 files changed

Lines changed: 591 additions & 180 deletions

File tree

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
import asyncio
2+
3+
import pytest
4+
from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import (
5+
TrackType as StreamTrackType,
6+
)
7+
from vision_agents.plugins.getstream._track_resolver import TrackResolver
8+
9+
10+
@pytest.fixture
11+
def resolver():
12+
return TrackResolver(poll_interval=0.005)
13+
14+
15+
class TestTrackResolver:
16+
async def test_known_track_reuse(self, resolver):
17+
resolver.register(
18+
track_id="t1",
19+
user_id="u1",
20+
session_id="s1",
21+
webrtc_kind="audio",
22+
)
23+
first = await resolver.resolve(
24+
user_id="u1",
25+
session_id="s1",
26+
stream_track_type=StreamTrackType.TRACK_TYPE_AUDIO,
27+
)
28+
29+
resolver.unpublish(
30+
user_id="u1",
31+
session_id="s1",
32+
stream_track_type=StreamTrackType.TRACK_TYPE_AUDIO,
33+
)
34+
35+
second = await resolver.resolve(
36+
user_id="u1",
37+
session_id="s1",
38+
stream_track_type=StreamTrackType.TRACK_TYPE_AUDIO,
39+
)
40+
41+
assert first == "t1"
42+
assert second == "t1"
43+
44+
async def test_session_migration(self, resolver):
45+
resolver.register(
46+
track_id="t1",
47+
user_id="u1",
48+
session_id="s_old",
49+
webrtc_kind="audio",
50+
)
51+
await resolver.resolve(
52+
user_id="u1",
53+
session_id="s_old",
54+
stream_track_type=StreamTrackType.TRACK_TYPE_AUDIO,
55+
)
56+
57+
# New session arrives without a fresh register() — same WebRTC track is reused.
58+
migrated = await resolver.resolve(
59+
user_id="u1",
60+
session_id="s_new",
61+
stream_track_type=StreamTrackType.TRACK_TYPE_AUDIO,
62+
timeout=0.1,
63+
)
64+
assert migrated == "t1"
65+
66+
# The old session entry is gone; the new one owns the track now.
67+
old_unpublish = resolver.unpublish(
68+
user_id="u1",
69+
session_id="s_old",
70+
stream_track_type=StreamTrackType.TRACK_TYPE_AUDIO,
71+
)
72+
new_unpublish = resolver.unpublish(
73+
user_id="u1",
74+
session_id="s_new",
75+
stream_track_type=StreamTrackType.TRACK_TYPE_AUDIO,
76+
)
77+
assert old_unpublish is None
78+
assert new_unpublish == "t1"
79+
80+
async def test_pending_arrives_first(self, resolver):
81+
resolver.register(
82+
track_id="t1",
83+
user_id="u1",
84+
session_id="s1",
85+
webrtc_kind="video",
86+
)
87+
track_id = await resolver.resolve(
88+
user_id="u1",
89+
session_id="s1",
90+
stream_track_type=StreamTrackType.TRACK_TYPE_VIDEO,
91+
)
92+
assert track_id == "t1"
93+
94+
async def test_track_published_arrives_first(self, resolver):
95+
resolve_task = asyncio.create_task(
96+
resolver.resolve(
97+
user_id="u1",
98+
session_id="s1",
99+
stream_track_type=StreamTrackType.TRACK_TYPE_VIDEO,
100+
timeout=1.0,
101+
)
102+
)
103+
await asyncio.sleep(0.02)
104+
resolver.register(
105+
track_id="t1",
106+
user_id="u1",
107+
session_id="s1",
108+
webrtc_kind="video",
109+
)
110+
track_id = await resolve_task
111+
assert track_id == "t1"
112+
113+
async def test_anonymous_fallback_success(self, resolver):
114+
resolver.register(
115+
track_id="t_anon",
116+
user_id=None,
117+
session_id=None,
118+
webrtc_kind="video",
119+
)
120+
track_id = await resolver.resolve(
121+
user_id="u1",
122+
session_id="s1",
123+
stream_track_type=StreamTrackType.TRACK_TYPE_VIDEO,
124+
)
125+
assert track_id == "t_anon"
126+
127+
async def test_anonymous_fallback_ambiguous(self, resolver):
128+
resolver.register(
129+
track_id="t_anon_a",
130+
user_id=None,
131+
session_id=None,
132+
webrtc_kind="video",
133+
)
134+
resolver.register(
135+
track_id="t_anon_b",
136+
user_id=None,
137+
session_id=None,
138+
webrtc_kind="video",
139+
)
140+
with pytest.raises(TimeoutError):
141+
await resolver.resolve(
142+
user_id="u1",
143+
session_id="s1",
144+
stream_track_type=StreamTrackType.TRACK_TYPE_VIDEO,
145+
timeout=0.05,
146+
)
147+
148+
async def test_timeout_no_pending(self, resolver):
149+
with pytest.raises(TimeoutError):
150+
await resolver.resolve(
151+
user_id="u1",
152+
session_id="s1",
153+
stream_track_type=StreamTrackType.TRACK_TYPE_VIDEO,
154+
timeout=0.05,
155+
)
156+
157+
async def test_cancel_during_resolve(self, resolver):
158+
resolve_task = asyncio.create_task(
159+
resolver.resolve(
160+
user_id="u1",
161+
session_id="s1",
162+
stream_track_type=StreamTrackType.TRACK_TYPE_VIDEO,
163+
timeout=10.0,
164+
)
165+
)
166+
await asyncio.sleep(0.02)
167+
168+
resolver.cancel(user_id="u1", session_id="s1")
169+
170+
track_id = await asyncio.wait_for(resolve_task, timeout=0.5)
171+
assert track_id is None
172+
173+
async def test_stale_pending_is_evicted(self):
174+
# Short TTL so we can verify the eviction without long sleeps.
175+
resolver = TrackResolver(poll_interval=0.005, pending_ttl=0.05)
176+
177+
# Stale anonymous video registered first; would normally make the
178+
# fallback ambiguous when a second anonymous video arrives.
179+
resolver.register(
180+
track_id="t_stale",
181+
user_id=None,
182+
session_id=None,
183+
webrtc_kind="video",
184+
)
185+
await asyncio.sleep(0.08)
186+
187+
resolver.register(
188+
track_id="t_fresh",
189+
user_id=None,
190+
session_id=None,
191+
webrtc_kind="video",
192+
)
193+
track_id = await resolver.resolve(
194+
user_id="u1",
195+
session_id="s1",
196+
stream_track_type=StreamTrackType.TRACK_TYPE_VIDEO,
197+
timeout=0.5,
198+
)
199+
assert track_id == "t_fresh"
200+
201+
async def test_concurrent_resolves_serialized(self, resolver):
202+
# Duplicate TrackPublishedEvent (e.g. from republish_tracks) starts two
203+
# resolves for the same key. Both must succeed with the same track_id.
204+
task_a = asyncio.create_task(
205+
resolver.resolve(
206+
user_id="u1",
207+
session_id="s1",
208+
stream_track_type=StreamTrackType.TRACK_TYPE_AUDIO,
209+
timeout=0.5,
210+
)
211+
)
212+
task_b = asyncio.create_task(
213+
resolver.resolve(
214+
user_id="u1",
215+
session_id="s1",
216+
stream_track_type=StreamTrackType.TRACK_TYPE_AUDIO,
217+
timeout=0.5,
218+
)
219+
)
220+
await asyncio.sleep(0.02)
221+
resolver.register(
222+
track_id="t1",
223+
user_id="u1",
224+
session_id="s1",
225+
webrtc_kind="audio",
226+
)
227+
results = await asyncio.gather(task_a, task_b)
228+
assert results == ["t1", "t1"]
229+
230+
async def test_cancel_drops_named_pending(self, resolver):
231+
# Named pending was registered (track_added fired), participant leaves
232+
# before TrackPublishedEvent. cancel() should drop the orphan.
233+
resolver.register(
234+
track_id="t_orphan",
235+
user_id="u1",
236+
session_id="s1",
237+
webrtc_kind="video",
238+
)
239+
resolver.cancel(user_id="u1", session_id="s1")
240+
241+
# If the orphan were still around, the next anonymous video would be
242+
# ambiguous against an exact-tuple lookup attempt — but here we just
243+
# verify a fresh resolve for the same key times out (no stale match).
244+
with pytest.raises(TimeoutError):
245+
await resolver.resolve(
246+
user_id="u1",
247+
session_id="s1",
248+
stream_track_type=StreamTrackType.TRACK_TYPE_VIDEO,
249+
timeout=0.05,
250+
)
251+
252+
async def test_cancel_before_resolve_is_noop(self, resolver):
253+
# Cancel arrives first; nothing in flight, no-op.
254+
resolver.cancel(user_id="u1", session_id="s1")
255+
256+
# Subsequent resolve runs normally and finds the matching pending.
257+
resolver.register(
258+
track_id="t1",
259+
user_id="u1",
260+
session_id="s1",
261+
webrtc_kind="video",
262+
)
263+
track_id = await resolver.resolve(
264+
user_id="u1",
265+
session_id="s1",
266+
stream_track_type=StreamTrackType.TRACK_TYPE_VIDEO,
267+
timeout=0.5,
268+
)
269+
assert track_id == "t1"

0 commit comments

Comments
 (0)