22
33import asyncio
44from abc import ABC , abstractmethod
5- from collections .abc import AsyncIterator
5+ from collections .abc import AsyncIterator , Iterable
66from typing import Generic , TypeVar , Union
77
88from typing_extensions import override
99
1010import livekit .rtc as rtc
11+ from livekit .rtc ._proto .track_pb2 import AudioTrackFeature
1112
1213from ...log import logger
1314from ...utils import aio , log_exceptions
1415from ..io import AudioInput , VideoInput
16+ from ._pre_connect_audio import PreConnectAudioHandler
1517
1618T = TypeVar ("T" , bound = Union [rtc .AudioFrame , rtc .VideoFrame ])
1719
@@ -126,14 +128,15 @@ async def _forward_task(
126128 self ,
127129 old_task : asyncio .Task | None ,
128130 stream : rtc .VideoStream | rtc .AudioStream ,
129- track_source : rtc .TrackSource .ValueType ,
131+ publication : rtc .RemoteTrackPublication ,
132+ participant : rtc .RemoteParticipant ,
130133 ) -> None :
131134 if old_task :
132135 await aio .cancel_and_wait (old_task )
133136
134137 extra = {
135- "participant" : self . _participant_identity ,
136- "source" : rtc .TrackSource .Name (track_source ),
138+ "participant" : participant . identity ,
139+ "source" : rtc .TrackSource .Name (publication . source ),
137140 }
138141 logger .debug ("start reading stream" , extra = extra )
139142 async for event in stream :
@@ -172,7 +175,7 @@ def _on_track_available(
172175 self ._stream = self ._create_stream (track )
173176 self ._publication = publication
174177 self ._forward_atask = asyncio .create_task (
175- self ._forward_task (self ._forward_atask , self ._stream , publication . source )
178+ self ._forward_task (self ._forward_atask , self ._stream , publication , participant )
176179 )
177180 return True
178181
@@ -202,13 +205,15 @@ def __init__(
202205 sample_rate : int ,
203206 num_channels : int ,
204207 noise_cancellation : rtc .NoiseCancellationOptions | None ,
208+ pre_connect_audio_handler : PreConnectAudioHandler | None ,
205209 ) -> None :
206210 _ParticipantInputStream .__init__ (
207211 self , room = room , track_source = rtc .TrackSource .SOURCE_MICROPHONE
208212 )
209213 self ._sample_rate = sample_rate
210214 self ._num_channels = num_channels
211215 self ._noise_cancellation = noise_cancellation
216+ self ._pre_connect_audio_handler = pre_connect_audio_handler
212217
213218 @override
214219 def _create_stream (self , track : rtc .Track ) -> rtc .AudioStream :
@@ -219,6 +224,78 @@ def _create_stream(self, track: rtc.Track) -> rtc.AudioStream:
219224 noise_cancellation = self ._noise_cancellation ,
220225 )
221226
@override
async def _forward_task(
    self,
    old_task: asyncio.Task | None,
    stream: rtc.AudioStream,
    publication: rtc.RemoteTrackPublication,
    participant: rtc.RemoteParticipant,
) -> None:
    """Push any pre-connect audio buffer, then forward the live stream.

    When a pre-connect audio handler is configured, the publication has a
    track, and the publication advertises ``TF_PRECONNECT_BUFFER``, the
    buffered frames for that track are awaited from the handler, resampled
    via ``_resample_frames`` and sent to ``self._data_ch`` (only while the
    input is attached). Failures while doing so are logged and swallowed so
    that forwarding of the live stream always starts afterwards.

    Args:
        old_task: Previous forward task, handed to the base implementation.
        stream: Live audio stream to forward once the buffer is handled.
        publication: Publication the stream was created from.
        participant: Participant that owns the publication.
    """
    # ``publication.track`` is optional; read it once so the value checked
    # here is the same one used after the awaits below.
    track = publication.track
    if (
        self._pre_connect_audio_handler
        and track
        and AudioTrackFeature.TF_PRECONNECT_BUFFER in publication.audio_features
    ):
        # Snapshot the identifiers up front: re-reading ``publication.track``
        # inside an exception handler could itself fail if the track went
        # away while we were waiting.
        log_extra = {
            "track_id": track.sid,
            "participant": participant.identity,
        }
        # Total duration of the buffered frames actually sent downstream.
        duration = 0.0
        try:
            frames = await self._pre_connect_audio_handler.wait_for_data(track.sid)
            for frame in self._resample_frames(frames):
                if self._attached:
                    await self._data_ch.send(frame)
                    duration += frame.duration
            if frames:
                logger.debug(
                    "pre-connect audio buffer pushed",
                    extra={"duration": duration, **log_extra},
                )

        except asyncio.TimeoutError:
            # Presumably raised by ``wait_for_data`` when no buffer arrives
            # in time; not fatal, the live stream is still forwarded.
            logger.warning(
                "timeout waiting for pre-connect audio buffer",
                extra={"duration": duration, **log_extra},
            )

        except Exception as e:
            # Best-effort feature: log with the traceback and carry on.
            logger.error(
                "error reading pre-connect audio buffer",
                extra={"error": e, **log_extra},
                exc_info=e,
            )

    await super()._forward_task(old_task, stream, publication, participant)
278+
def _resample_frames(self, frames: Iterable[rtc.AudioFrame]) -> Iterable[rtc.AudioFrame]:
    """Lazily yield ``frames`` converted to ``self._sample_rate``.

    Frames are passed through untouched until one is seen whose sample rate
    differs from ``self._sample_rate``. At that point a single resampler is
    created from that frame's rate and channel count, and that frame and
    every later one are routed through it. If ``self._sample_rate`` is
    ``None`` no resampling happens at all.

    Args:
        frames: Audio frames to convert, in playback order.

    Yields:
        The original frames, or the resampler's output frames once
        resampling has started, followed by whatever the resampler still
        holds when the input is exhausted.
    """
    resampler: rtc.AudioResampler | None = None
    for frame in frames:
        if (
            resampler is None
            and self._sample_rate is not None
            and frame.sample_rate != self._sample_rate
        ):
            # Pass the frame's channel count explicitly; otherwise the
            # resampler falls back to its own default channel count, which
            # is only right for frames that happen to match it.
            resampler = rtc.AudioResampler(
                input_rate=frame.sample_rate,
                output_rate=self._sample_rate,
                num_channels=frame.num_channels,
            )

        if resampler is not None:
            yield from resampler.push(frame)
        else:
            yield frame

    # Drain samples the resampler is still buffering internally.
    if resampler is not None:
        yield from resampler.flush()
298+
222299
223300class _ParticipantVideoInputStream (_ParticipantInputStream [rtc .VideoFrame ], VideoInput ):
224301 def __init__ (self , room : rtc .Room ) -> None :
0 commit comments