Skip to content

Commit 0ae6fa2

Browse files
committed
fix: re-add connection logic
1 parent da7eee8 commit 0ae6fa2

7 files changed

Lines changed: 308 additions & 102 deletions

File tree

examples/stt_deepgram_transcription/main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,17 @@ async def on_audio(pcm: PcmData, user):
8585
@stt.on("transcript")
8686
async def on_transcript(text: str, user: any, metadata: dict):
8787
timestamp = time.strftime("%H:%M:%S")
88-
user_info = user if user else "unknown"
88+
user_info = user.name if user and hasattr(user, "name") else "unknown"
8989
print(f"[{timestamp}] {user_info}: {text}")
9090
if metadata.get("confidence"):
9191
print(f" └─ confidence: {metadata['confidence']:.2%}")
9292

9393
@stt.on("partial_transcript")
9494
async def on_partial_transcript(text: str, user: any, metadata: dict):
9595
if text.strip(): # Only show non-empty partial transcripts
96-
user_info = user if user else "unknown"
96+
user_info = (
97+
user.name if user and hasattr(user, "name") else "unknown"
98+
)
9799
print(
98100
f" {user_info} (partial): {text}", end="\r"
99101
) # Overwrite line

getstream/video/rtc/connection_manager.py

Lines changed: 135 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from twirp.context import Context
1010

1111
from getstream.utils import StreamAsyncIOEventEmitter
12-
from getstream.video.rtc.coordinator import StreamAPIWS
1312
from getstream.video.rtc.pb.stream.video.sfu.models import models_pb2
13+
from getstream.video.rtc.pb.stream.video.sfu.signal_rpc import signal_pb2
1414
from getstream.video.rtc.twirp_client_wrapper import SignalClient
1515

1616
from getstream.video.call import Call
@@ -39,7 +39,7 @@ async def _log_event(event_type: str, data: Any):
3939

4040
class ConnectionManager(StreamAsyncIOEventEmitter):
4141
"""Main connection manager facade for video streaming."""
42-
42+
4343
def __init__(
4444
self,
4545
call: Call,
@@ -59,19 +59,22 @@ def __init__(
5959
self.session_id: str = str(uuid.uuid4())
6060
self.join_response: Optional[JoinCallResponse] = None
6161
self.local_sfu: bool = False # Local SFU flag for development
62-
62+
6363
# Private attributes
6464
self._connection_state: ConnectionState = ConnectionState.IDLE
6565
self._stop_event: asyncio.Event = asyncio.Event()
6666
self._connection_options: ConnectionOptions = ConnectionOptions()
6767
self._ws_client = None
68-
68+
self._coordinator_ws_client = None
69+
6970
# Initialize private managers
7071
self._participants_state: ParticipantsState = ParticipantsState()
7172
self._recording_manager: RecordingManager = RecordingManager()
7273
self._network_monitor: NetworkMonitor = NetworkMonitor(self)
7374
self._reconnector: ReconnectionManager = ReconnectionManager(self)
74-
self._subscription_manager: SubscriptionManager = SubscriptionManager(self, subscription_config)
75+
self._subscription_manager: SubscriptionManager = SubscriptionManager(
76+
self, subscription_config
77+
)
7578
self._peer_manager: PeerConnectionManager = PeerConnectionManager(self)
7679

7780
self.recording_manager = self._recording_manager # type: ignore
@@ -93,11 +96,8 @@ def connection_state(self, state: ConnectionState):
9396
old_state = self._connection_state
9497
self._connection_state = state
9598
# Schedule the emit as a background task since property setters cannot be async
96-
self.emit('connection.state_changed', {
97-
'old': old_state,
98-
'new': state
99-
})
100-
99+
self.emit("connection.state_changed", {"old": old_state, "new": state})
100+
101101
async def _on_ice_trickle(self, event):
102102
"""Handle ICE trickle from SFU."""
103103
logger.debug(f"Received ICE trickle for peer type {event.peer_type}")
@@ -124,6 +124,75 @@ async def _on_ice_trickle(self, event):
124124
except Exception as e:
125125
logger.debug(f"Error handling ICE trickle: {e}")
126126

127+
async def _on_subscriber_offer(self, event):
128+
"""Handle subscriber offer from SFU."""
129+
logger.info(f"Received subscriber offer: ice_restart={event.ice_restart}")
130+
131+
try:
132+
# Ensure we have a subscriber peer connection
133+
if not self.subscriber_pc:
134+
await self._peer_manager.setup_subscriber()
135+
136+
# Parse SDP to extract track-to-stream mapping
137+
self._extract_track_stream_mapping(event.sdp)
138+
139+
# Handle ICE restart if needed
140+
if event.ice_restart:
141+
logger.info("Restarting ICE for subscriber")
142+
await self.subscriber_pc.restartIce()
143+
144+
# Set remote description with the SFU's offer
145+
remote_description = aiortc.RTCSessionDescription(
146+
type="offer", sdp=event.sdp
147+
)
148+
await self.subscriber_pc.setRemoteDescription(remote_description)
149+
150+
# Create and set local answer
151+
answer = await self.subscriber_pc.createAnswer()
152+
await self.subscriber_pc.setLocalDescription(answer)
153+
154+
# Send answer back to SFU
155+
response = await self.twirp_signaling_client.SendAnswer(
156+
ctx=self.twirp_context,
157+
request=signal_pb2.SendAnswerRequest(
158+
session_id=self.session_id,
159+
peer_type=models_pb2.PEER_TYPE_SUBSCRIBER,
160+
sdp=self.subscriber_pc.localDescription.sdp,
161+
),
162+
server_path_prefix="",
163+
)
164+
logger.info(f"Sent subscriber answer: {response}")
165+
166+
except Exception as e:
167+
logger.error(f"Error handling subscriber offer: {e}")
168+
raise
169+
170+
def _extract_track_stream_mapping(self, sdp: str):
171+
"""Extract track-to-stream mapping from SDP."""
172+
track_mapping = {}
173+
174+
# Parse SDP to find track-to-stream mapping
175+
# SDP format includes lines like:
176+
# a=msid:<stream_id> <track_id>
177+
# a=mid:<media_id>
178+
for line in sdp.split("\n"):
179+
line = line.strip()
180+
if line.startswith("a=msid:"):
181+
# Extract msid line: a=msid:<stream_id> <track_id>
182+
parts = line.split(" ")
183+
if len(parts) >= 3:
184+
stream_id = parts[1]
185+
track_id = parts[2]
186+
track_mapping[track_id] = stream_id
187+
logger.debug(f"Extracted track mapping: {track_id} -> {stream_id}")
188+
189+
# Set the mapping in participants state
190+
if track_mapping:
191+
logger.info(f"Setting track stream mapping: {track_mapping}")
192+
self.participants_state.set_track_stream_mapping(track_mapping)
193+
else:
194+
logger.warning("No track-to-stream mapping found in SDP")
195+
127196
async def _connect_internal(
128197
self,
129198
region: Optional[str] = None,
@@ -139,10 +208,10 @@ async def _connect_internal(
139208
ws_url: Optional WebSocket URL to connect to
140209
token: Optional authentication token
141210
session_id: Optional session ID
142-
211+
143212
Raises:
144213
SfuConnectionError: If connection fails
145-
"""
214+
"""
146215
self.connection_state = ConnectionState.JOINING
147216

148217
# Step 1: Determine region
@@ -158,16 +227,21 @@ async def _connect_internal(
158227
# Step 2: Join call via coordinator
159228
if not (ws_url or token):
160229
join_response = await join_call(
161-
self.call, self.user_id, location, self.create, self.local_sfu, **self.kwargs
230+
self.call,
231+
self.user_id,
232+
location,
233+
self.create,
234+
self.local_sfu,
235+
**self.kwargs,
162236
)
163237
ws_url = join_response.data.credentials.server.ws_endpoint
164238
token = join_response.data.credentials.token
165239
self.join_response = join_response
166240
logger.debug(f"coordinator join response: {join_response.data}")
167-
241+
168242
# Use provided session_id or current one
169243
current_session_id = session_id or self.session_id
170-
244+
171245
# Step 3: Connect to WebSocket
172246
try:
173247
self._ws_client, sfu_event = await connect_websocket(
@@ -180,6 +254,17 @@ async def _connect_internal(
180254
self._ws_client.on_wildcard("*", _log_event)
181255
self._ws_client.on_event("ice_trickle", self._on_ice_trickle)
182256

257+
# Connect track subscription events to subscription manager
258+
self._ws_client.on_event(
259+
"track_published", self._subscription_manager.handle_track_published
260+
)
261+
self._ws_client.on_event(
262+
"track_unpublished", self._subscription_manager.handle_track_unpublished
263+
)
264+
265+
# Connect subscriber offer event to handle SDP negotiation
266+
self._ws_client.on_event("subscriber_offer", self._on_subscriber_offer)
267+
183268
if hasattr(sfu_event, "join_response"):
184269
logger.debug(f"sfu join response: {sfu_event.join_response}")
185270
# Populate participants state with existing participants
@@ -193,7 +278,7 @@ async def _connect_internal(
193278
)
194279
else:
195280
logger.warning(f"No join response from WebSocket: {sfu_event}")
196-
281+
197282
logger.debug(f"WebSocket connected successfully to {ws_url}")
198283
except Exception as e:
199284
logger.error(f"Failed to connect WebSocket to {ws_url}: {e}")
@@ -204,15 +289,21 @@ async def _connect_internal(
204289
self.twirp_signaling_client = SignalClient(address=twirp_server_url)
205290
self.twirp_context = Context(headers={"authorization": token})
206291

207-
# Step 5: Create coordinator websocket
208-
user_token = self.call.client.stream.create_token(user_id=self.user_id)
209-
self._coordinator_ws_client = StreamAPIWS(
210-
api_key=self.call.client.stream.api_key,
211-
token=user_token,
212-
user_details={"id": self.user_id},
213-
)
214-
self._coordinator_ws_client.on_wildcard("*", _log_event)
215-
await self._coordinator_ws_client.connect()
292+
# Step 5: Create coordinator websocket (temporarily disabled to test)
293+
# user_token = self.call.client.stream.create_token(user_id=self.user_id)
294+
# self._coordinator_ws_client = StreamAPIWS(
295+
# api_key=self.call.client.stream.api_key,
296+
# token=user_token,
297+
# user_details={"id": self.user_id},
298+
# healthcheck_interval=15.0, # Send heartbeat every 15 seconds instead of 25
299+
# healthcheck_timeout=20.0, # Expect server messages within 20 seconds instead of 30
300+
# )
301+
# self._coordinator_ws_client.on_wildcard("*", _log_event)
302+
# await self._coordinator_ws_client.connect()
303+
self._coordinator_ws_client = None # Temporarily disable coordinator connection
304+
305+
# Step 6: Setup subscriber peer connection to receive incoming tracks
306+
await self._peer_manager.setup_subscriber()
216307

217308
# Mark as connected
218309
self.running = True
@@ -255,6 +346,9 @@ async def leave(self):
255346
if self._ws_client:
256347
await self._ws_client.close()
257348
self._ws_client = None
349+
if self._coordinator_ws_client:
350+
await self._coordinator_ws_client.disconnect()
351+
self._coordinator_ws_client = None
258352

259353
self.connection_state = ConnectionState.LEFT
260354

@@ -264,7 +358,7 @@ async def __aenter__(self):
264358
"""Async context manager entry."""
265359
# Register network event handlers
266360
self._network_monitor.register_event_handlers()
267-
361+
268362
# Connect with retry
269363
await self.connect()
270364

@@ -287,10 +381,14 @@ async def addTrack(self, track, track_info=None):
287381
else:
288382
await self.add_tracks(audio=track)
289383

290-
async def start_recording(self, recording_types, user_ids=None, output_dir="recordings"):
384+
async def start_recording(
385+
self, recording_types, user_ids=None, output_dir="recordings"
386+
):
291387
"""Start recording."""
292388
logger.info("Starting recording")
293-
await self._recording_manager.start_recording(recording_types, user_ids, output_dir)
389+
await self._recording_manager.start_recording(
390+
recording_types, user_ids, output_dir
391+
)
294392

295393
async def stop_recording(self, recording_types=None, user_ids=None):
296394
"""Stop recording."""
@@ -306,15 +404,19 @@ def get_recording_status(self) -> dict:
306404
"""Get current recording status."""
307405
return self._recording_manager.get_recording_status()
308406

309-
async def subscribe_to_track(self, track_id: str, config: Optional[SubscriptionConfig] = None):
407+
async def subscribe_to_track(
408+
self, track_id: str, config: Optional[SubscriptionConfig] = None
409+
):
310410
"""Subscribe to a specific track."""
311411
await self._subscription_manager.subscribe_to_track(track_id, config)
312412

313413
async def unsubscribe_from_track(self, track_id: str):
314414
"""Unsubscribe from a specific track."""
315415
await self._subscription_manager.unsubscribe_from_track(track_id)
316416

317-
async def update_track_subscription(self, track_id: str, config: SubscriptionConfig):
417+
async def update_track_subscription(
418+
self, track_id: str, config: SubscriptionConfig
419+
):
318420
"""Update subscription configuration for a track."""
319421
await self._subscription_manager.update_track_subscription(track_id, config)
320422

@@ -410,7 +512,9 @@ def subscriber_negotiation_lock(self):
410512
# Internal cleanup & restoration helpers referenced by other modules
411513
# ------------------------------------------------------------------
412514

413-
async def _cleanup_connections(self, ws_client=None, publisher_pc=None, subscriber_pc=None):
515+
async def _cleanup_connections(
516+
self, ws_client=None, publisher_pc=None, subscriber_pc=None
517+
):
414518
"""Close provided connections safely; used by ReconnectionManager."""
415519
try:
416520
# Close peer connections (async)
@@ -438,4 +542,3 @@ async def _restore_published_tracks(self):
438542
await self._peer_manager.restore_published_tracks()
439543
except Exception as e:
440544
logger.error("Failed to restore published tracks", exc_info=e)
441-

getstream/video/rtc/coordinator/ws.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,6 @@ def _build_auth_payload(self) -> dict:
9494
"""
9595
Build the authentication payload to send after connection.
9696
97-
For initial connections, includes user_details if provided.
98-
For reconnections, user_details are excluded.
99-
10097
Returns:
10198
Authentication payload as a dictionary
10299
"""
@@ -105,8 +102,8 @@ def _build_auth_payload(self) -> dict:
105102
"products": ["video"],
106103
}
107104

108-
# Only include user_details on initial connection, not on reconnects
109-
if self._initial_connection and self.user_details:
105+
# Include user_details if available (both for initial connection and reconnections)
106+
if self.user_details:
110107
payload["user_details"] = self.user_details
111108

112109
return payload

getstream/video/rtc/participants.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,21 @@ def __init__(self):
2323
# ------------------------------------------------------------------
2424

2525
def get_user_from_track_id(self, track_id: str) -> Optional[models_pb2.Participant]:
26+
# Track IDs have format: participant_id:track_type:...
27+
# We can extract the participant prefix directly from the track ID
28+
if ":" in track_id:
29+
# Extract the participant prefix from the track ID
30+
prefix = track_id.split(":")[0]
31+
user = self._participant_by_prefix.get(prefix)
32+
if user:
33+
return user
34+
35+
# Fallback to the old mapping approach if it exists
2636
stream_id = self._track_stream_mapping.get(track_id)
2737
if stream_id:
2838
prefix = stream_id.split(":")[0]
2939
return self._participant_by_prefix.get(prefix)
40+
3041
return None
3142

3243
def get_stream_id_from_track_id(self, track_id: str) -> Optional[str]:
@@ -57,4 +68,4 @@ async def _on_participant_joined(self, event: events_pb2.ParticipantJoined):
5768

5869
async def _on_participant_left(self, event: events_pb2.ParticipantLeft):
5970
self.remove_participant(event.participant)
60-
self.emit("participant_left", event.participant)
71+
self.emit("participant_left", event.participant)

0 commit comments

Comments
 (0)