Skip to content

Commit 2ed3c4b

Browse files
committed
Simplify source and route shutdown ownership
1 parent 2116de5 commit 2ed3c4b

8 files changed

Lines changed: 105 additions & 91 deletions

File tree

core_engine/src/outputs/asr_sink.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,10 @@ impl InputState {
214214
let input_channels = MASTER_FORMAT.channels.max(1) as usize;
215215
let ready_input_samples = self.drained_master.len() / input_channels * input_channels;
216216
if ready_input_samples > 0 {
217-
self.converter
218-
.convert(&self.drained_master[..ready_input_samples], &mut self.converted_output)?;
217+
self.converter.convert(
218+
&self.drained_master[..ready_input_samples],
219+
&mut self.converted_output,
220+
)?;
219221
if !self.converted_output.is_empty() {
220222
self.pending_output
221223
.extend_from_slice(&self.converted_output);
@@ -452,6 +454,14 @@ impl AsrSink {
452454
Err(_) => Err(AsrSinkError::ThreadPanic),
453455
}
454456
}
457+
458+
fn abort_on_drop(&mut self) {
459+
self.stop.store(true, Ordering::Relaxed);
460+
let Some(handle) = self.handle.take() else {
461+
return;
462+
};
463+
let _ = handle.join();
464+
}
455465
}
456466

457467
fn duration_to_u32_us(duration: Duration) -> u32 {
@@ -460,7 +470,7 @@ fn duration_to_u32_us(duration: Duration) -> u32 {
460470

461471
impl Drop for AsrSink {
462472
fn drop(&mut self) {
463-
let _ = self.stop();
473+
self.abort_on_drop();
464474
}
465475
}
466476

core_engine/src/outputs/wav_file.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,7 @@ impl WavFileOutput {
146146
}
147147
}
148148

149-
let ready_samples = input_buffers
150-
.iter()
151-
.map(VecDeque::len)
152-
.min()
153-
.unwrap_or(0)
149+
let ready_samples = input_buffers.iter().map(VecDeque::len).min().unwrap_or(0)
154150
/ frame_channels
155151
* frame_channels;
156152

@@ -312,6 +308,14 @@ impl WavFileOutput {
312308
Err(_) => Err(WavOutputError::ThreadPanic),
313309
}
314310
}
311+
312+
fn abort_on_drop(&mut self) {
313+
self.stop.store(true, Ordering::Relaxed);
314+
let Some(handle) = self.handle.take() else {
315+
return;
316+
};
317+
let _ = handle.join();
318+
}
315319
}
316320

317321
fn duration_to_u32_us(duration: Duration) -> u32 {
@@ -320,7 +324,7 @@ fn duration_to_u32_us(duration: Duration) -> u32 {
320324

321325
impl Drop for WavFileOutput {
322326
fn drop(&mut self) {
323-
let _ = self.stop();
327+
self.abort_on_drop();
324328
}
325329
}
326330

macloop/__init__.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,6 @@ def __init__(self) -> None:
283283
self._streams: dict[str, StreamHandle] = {}
284284
self._routes: dict[str, RouteHandle] = {}
285285
self._processors: dict[str, ProcessorHandle] = {}
286-
self._claimed_routes: set[str] = set()
287286
self._sink_refs: list[weakref.ReferenceType[_Closable]] = []
288287
self._closed = False
289288

@@ -384,7 +383,6 @@ def close(self) -> None:
384383
sink_err = exc
385384

386385
self._sink_refs.clear()
387-
self._claimed_routes.clear()
388386

389387
if backend_err is not None:
390388
raise backend_err
@@ -430,17 +428,12 @@ def _ensure_routes_available(self, routes: Sequence[RouteHandle]) -> None:
430428
if not routes:
431429
raise ValueError("routes must not be empty")
432430

431+
route_ids = []
433432
for route in routes:
434433
self._ensure_route_handle(route)
435-
if route.id in self._claimed_routes:
436-
raise ValueError(f"route '{route.id}' is already in use")
434+
route_ids.append(route.id)
437435

438-
def _claim_routes(self, route_ids: Sequence[str]) -> None:
439-
self._claimed_routes.update(route_ids)
440-
441-
def _release_routes(self, route_ids: Sequence[str]) -> None:
442-
for route_id in route_ids:
443-
self._claimed_routes.discard(route_id)
436+
self._backend.assert_routes_available(route_ids)
444437

445438
def _register_sink(self, sink: _Closable) -> None:
446439
self._sink_refs.append(weakref.ref(sink))
@@ -472,7 +465,6 @@ def __init__(
472465

473466
self.id = sink_id
474467
self._engine_ref = weakref.ref(engine)
475-
self._route_ids = tuple(route_ids)
476468
self._queue = out_queue
477469
self._queue_maxsize = max_queue_size
478470
self._async_queue: Optional[asyncio.Queue[object]] = None
@@ -496,7 +488,6 @@ def _callback(route_id: str, frames: int, samples: AudioSamples) -> None:
496488

497489
self._backend = backend
498490

499-
engine._claim_routes(route_ids)
500491
engine._register_sink(self)
501492

502493
def chunks(self) -> Iterator[AudioChunk]:
@@ -531,8 +522,6 @@ def close(self) -> None:
531522
err = exc
532523
finally:
533524
self._closed = True
534-
if engine is not None:
535-
engine._release_routes(self._route_ids)
536525
_drop_oldest_put(self._queue, _STOP)
537526
async_loop = self._async_loop
538527
async_queue = self._async_queue
@@ -641,10 +630,8 @@ def __init__(
641630
self.id = sink_id
642631
self._backend = backend
643632
self._engine_ref = weakref.ref(engine)
644-
self._route_ids = tuple(route_ids)
645633
self._closed = False
646634

647-
engine._claim_routes(route_ids)
648635
engine._register_sink(self)
649636

650637
def close(self) -> None:
@@ -663,8 +650,6 @@ def close(self) -> None:
663650
err = exc
664651
finally:
665652
self._closed = True
666-
if engine is not None:
667-
engine._release_routes(self._route_ids)
668653

669654
if err is not None:
670655
raise err

python_ffi/src/lib.rs

Lines changed: 45 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1336,6 +1336,11 @@ impl PyAudioEngineBackend {
13361336
Ok(out)
13371337
}
13381338

1339+
fn assert_routes_available(&self, route_ids: Vec<String>) -> PyResult<()> {
1340+
self.ensure_open()?;
1341+
self.ensure_route_consumers_available(&route_ids)
1342+
}
1343+
13391344
fn close(&mut self, py: Python<'_>) -> PyResult<()> {
13401345
// `closed` means no new public operations are allowed. If native cleanup timed out earlier,
13411346
// repeated close attempts are still allowed so drop/explicit close can keep retrying the
@@ -1469,6 +1474,44 @@ impl PyAudioEngineBackend {
14691474
Ok(())
14701475
}
14711476

1477+
fn start_stream_states_for_routes(
1478+
&mut self,
1479+
py: Python<'_>,
1480+
route_ids: &[String],
1481+
) -> PyResult<()> {
1482+
let stream_ids = self.stream_ids_for_routes(route_ids)?;
1483+
let stream_states = self.take_stream_states(&stream_ids)?;
1484+
let started_states = match py.detach(move || {
1485+
start_stream_states_with_timeout(stream_states, NATIVE_SOURCE_START_TIMEOUT)
1486+
}) {
1487+
Ok(states) => states,
1488+
Err(StartStreamsError {
1489+
message,
1490+
states,
1491+
poison,
1492+
cleanup,
1493+
}) => {
1494+
self.restore_stream_states(states);
1495+
self.store_pending_cleanups(cleanup);
1496+
self.poisoned = poison;
1497+
return Err(PyRuntimeError::new_err(message));
1498+
}
1499+
};
1500+
self.restore_stream_states(started_states);
1501+
Ok(())
1502+
}
1503+
1504+
fn prepare_route_consumers_for_sink(
1505+
&mut self,
1506+
py: Python<'_>,
1507+
route_ids: &[String],
1508+
) -> PyResult<Vec<(String, RouteConsumer)>> {
1509+
self.ensure_open()?;
1510+
self.ensure_route_consumers_available(route_ids)?;
1511+
self.start_stream_states_for_routes(py, route_ids)?;
1512+
self.take_route_consumers(route_ids)
1513+
}
1514+
14721515
fn stream_ids_for_routes(&self, route_ids: &[String]) -> PyResult<Vec<String>> {
14731516
let mut stream_ids = Vec::<String>::new();
14741517

@@ -1554,9 +1597,6 @@ impl PyAudioEngineBackend {
15541597
chunk_frames: usize,
15551598
callback: Py<PyAny>,
15561599
) -> PyResult<PyAsrSinkBackend> {
1557-
self.ensure_open()?;
1558-
self.ensure_route_consumers_available(&route_ids)?;
1559-
15601600
let format = StreamFormat::with_sample_format(
15611601
sample_rate,
15621602
channels,
@@ -1568,27 +1608,7 @@ impl PyAudioEngineBackend {
15681608
})
15691609
.map_err(|err| PyValueError::new_err(err.to_string()))?;
15701610

1571-
let stream_ids = self.stream_ids_for_routes(&route_ids)?;
1572-
let stream_states = self.take_stream_states(&stream_ids)?;
1573-
let started_states = match py.detach(move || {
1574-
start_stream_states_with_timeout(stream_states, NATIVE_SOURCE_START_TIMEOUT)
1575-
}) {
1576-
Ok(states) => states,
1577-
Err(StartStreamsError {
1578-
message,
1579-
states,
1580-
poison,
1581-
cleanup,
1582-
}) => {
1583-
self.restore_stream_states(states);
1584-
self.store_pending_cleanups(cleanup);
1585-
self.poisoned = poison;
1586-
return Err(PyRuntimeError::new_err(message));
1587-
}
1588-
};
1589-
self.restore_stream_states(started_states);
1590-
1591-
let route_consumers = self.take_route_consumers(&route_ids)?;
1611+
let route_consumers = self.prepare_route_consumers_for_sink(py, &route_ids)?;
15921612
let detached_result: DetachedAsrStartResult = py.detach(move || {
15931613
let inputs = route_consumers
15941614
.into_iter()
@@ -1640,32 +1660,9 @@ impl PyAudioEngineBackend {
16401660
fd: i32,
16411661
mix_gain: f32,
16421662
) -> PyResult<PyWavSinkBackend> {
1643-
self.ensure_open()?;
1644-
self.ensure_route_consumers_available(&route_ids)?;
1645-
16461663
let file = duplicate_file_descriptor(fd)
16471664
.map_err(|e| PyOSError::new_err(format!("failed to duplicate file descriptor: {e}")))?;
1648-
let stream_ids = self.stream_ids_for_routes(&route_ids)?;
1649-
let stream_states = self.take_stream_states(&stream_ids)?;
1650-
let started_states = match py.detach(move || {
1651-
start_stream_states_with_timeout(stream_states, NATIVE_SOURCE_START_TIMEOUT)
1652-
}) {
1653-
Ok(states) => states,
1654-
Err(StartStreamsError {
1655-
message,
1656-
states,
1657-
poison,
1658-
cleanup,
1659-
}) => {
1660-
self.restore_stream_states(states);
1661-
self.store_pending_cleanups(cleanup);
1662-
self.poisoned = poison;
1663-
return Err(PyRuntimeError::new_err(message));
1664-
}
1665-
};
1666-
self.restore_stream_states(started_states);
1667-
1668-
let route_consumers = self.take_route_consumers(&route_ids)?;
1665+
let route_consumers = self.prepare_route_consumers_for_sink(py, &route_ids)?;
16691666
let route_ids_for_sink = route_ids.clone();
16701667
let master_format = self.controller.master_format();
16711668
let detached_result: DetachedWavStartResult = py.detach(move || {

tests/conftest.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,23 +116,26 @@ def __init__(self, write_calls, samples_written, frames_written, write, finalize
116116

117117

118118
class _FakeAsrSinkBackend:
119-
def __init__(self) -> None:
119+
def __init__(self, route_ids: tuple[str, ...]) -> None:
120120
self.closed = False
121121
self._stats = {}
122+
self._route_ids = route_ids
122123

123124
def stats(self):
124125
return self._stats
125126

126127
def close(self) -> None:
127128
self.closed = True
128129

129-
def close_for_engine(self, _engine) -> None:
130+
def close_for_engine(self, engine) -> None:
130131
self.close()
132+
engine.available_routes.update(self._route_ids)
131133

132134

133135
class _FakeWavSinkBackend:
134-
def __init__(self) -> None:
136+
def __init__(self, route_ids: tuple[str, ...]) -> None:
135137
self.closed = False
138+
self._route_ids = route_ids
136139
write_latency = _FakeLatencyStats(
137140
last_us=18,
138141
max_us=30,
@@ -169,14 +172,16 @@ def stats(self):
169172
def close(self) -> None:
170173
self.closed = True
171174

172-
def close_for_engine(self, _engine) -> None:
175+
def close_for_engine(self, engine) -> None:
173176
self.close()
177+
engine.available_routes.update(self._route_ids)
174178

175179

176180
class _FakeAudioEngineBackend:
177181
def __init__(self) -> None:
178182
self.calls: list[tuple[object, ...]] = []
179183
self.closed = False
184+
self.available_routes: set[str] = set()
180185
pipeline_latency = _FakeLatencyStats(
181186
last_us=10,
182187
max_us=20,
@@ -216,10 +221,16 @@ def add_processor(self, stream_id, processor_id, processor_kind, config) -> None
216221

217222
def route(self, route_id, stream_id) -> None:
218223
self.calls.append(("route", route_id, stream_id))
224+
self.available_routes.add(route_id)
219225

220226
def get_stats(self):
221227
return self.stats
222228

229+
def assert_routes_available(self, route_ids) -> None:
230+
for route_id in route_ids:
231+
if route_id not in self.available_routes:
232+
raise ValueError(f"route '{route_id}' is not available")
233+
223234
def close(self) -> None:
224235
self.closed = True
225236

@@ -267,7 +278,9 @@ def _fake_create_asr_sink(
267278
p95_us=64,
268279
p99_us=64,
269280
)
270-
backend = _FakeAsrSinkBackend()
281+
for route_id in route_ids:
282+
engine.available_routes.discard(route_id)
283+
backend = _FakeAsrSinkBackend(tuple(route_ids))
271284
backend._stats = {
272285
route_ids[0]: _FakeAsrInputStats(
273286
chunks_emitted=2,
@@ -284,7 +297,9 @@ def _fake_create_asr_sink(
284297

285298
def _fake_create_wav_sink(engine, sink_id, route_ids, fd, mix_gain):
286299
engine.calls.append(("create_wav_sink", sink_id, tuple(route_ids), fd, mix_gain))
287-
return _FakeWavSinkBackend()
300+
for route_id in route_ids:
301+
engine.available_routes.discard(route_id)
302+
return _FakeWavSinkBackend(tuple(route_ids))
288303

289304

290305
@pytest.fixture

0 commit comments

Comments
 (0)