diff --git a/crates/aec/src/onnx/mod.rs b/crates/aec/src/onnx/mod.rs index 5c82223cf3..0ff87a381c 100644 --- a/crates/aec/src/onnx/mod.rs +++ b/crates/aec/src/onnx/mod.rs @@ -1,15 +1,99 @@ use realfft::{ComplexToReal, RealFftPlanner, RealToComplex, num_complex::Complex}; +use std::cell::Cell; use std::sync::Arc; -use hypr_onnx::{ - ndarray::{Array3, Array4}, - ort::{session::Session, value::TensorRef}, +use hypr_onnx::ort::{ + logging::LogLevel, + memory::Allocator, + session::{Session, builder::GraphOptimizationLevel}, + value::Tensor, }; use super::CircularBuffer; pub(crate) mod model; +// Holds only the 4 Tensor objects and their cached raw pointers. +// Stored separately from ProcessingContext and cached in a thread-local +// so ProcessingContext can remain stack-local in the hot loop. +struct TensorCache { + in_mag: Tensor, + in_mag_ptr: *mut f32, + in_mag_len: usize, + lpb_mag: Tensor, + lpb_mag_ptr: *mut f32, + lpb_mag_len: usize, + estimated_block: Tensor, + estimated_block_ptr: *mut f32, + estimated_block_len: usize, + in_lpb: Tensor, + in_lpb_ptr: *mut f32, + in_lpb_len: usize, +} + +// SAFETY: the raw pointers alias the buffers of the Tensor fields stored in this same struct, +// so they move together with it. NOTE(review): "single-threaded use" alone does not justify +// Send — confirm the ORT-allocated tensor memory is stable and may be touched from another thread.
+unsafe impl Send for TensorCache {} + +impl TensorCache { + fn new(block_len: usize) -> Self { + let allocator = Allocator::default(); + let freq_len = block_len / 2 + 1; + + let mut in_mag = Tensor::::new(&allocator, [1usize, 1, freq_len]) + .expect("failed to allocate in_mag tensor"); + let (_, in_mag_slice) = in_mag + .try_extract_tensor_mut::() + .expect("failed to extract in_mag pointer"); + let in_mag_ptr = in_mag_slice.as_mut_ptr(); + let in_mag_len = in_mag_slice.len(); + + let mut lpb_mag = Tensor::::new(&allocator, [1usize, 1, freq_len]) + .expect("failed to allocate lpb_mag tensor"); + let (_, lpb_mag_slice) = lpb_mag + .try_extract_tensor_mut::() + .expect("failed to extract lpb_mag pointer"); + let lpb_mag_ptr = lpb_mag_slice.as_mut_ptr(); + let lpb_mag_len = lpb_mag_slice.len(); + + let mut estimated_block = Tensor::::new(&allocator, [1usize, 1, block_len]) + .expect("failed to allocate estimated_block tensor"); + let (_, est_slice) = estimated_block + .try_extract_tensor_mut::() + .expect("failed to extract estimated_block pointer"); + let estimated_block_ptr = est_slice.as_mut_ptr(); + let estimated_block_len = est_slice.len(); + + let mut in_lpb = Tensor::::new(&allocator, [1usize, 1, block_len]) + .expect("failed to allocate in_lpb tensor"); + let (_, in_lpb_slice) = in_lpb + .try_extract_tensor_mut::() + .expect("failed to extract in_lpb pointer"); + let in_lpb_ptr = in_lpb_slice.as_mut_ptr(); + let in_lpb_len = in_lpb_slice.len(); + + Self { + in_mag, + in_mag_ptr, + in_mag_len, + lpb_mag, + lpb_mag_ptr, + lpb_mag_len, + estimated_block, + estimated_block_ptr, + estimated_block_len, + in_lpb, + in_lpb_ptr, + in_lpb_len, + } + } +} + +// Thread-local cache for TensorCache (4 Tensor + raw pointers). +// Cell<*mut> has zero locking overhead. ProcessingContext remains stack-local. +// NOTE(review): the Cell holds a bare pointer with no destructor, so a cache parked here is not +// freed when its thread exits — confirm that leak is acceptable. +thread_local! 
{ + static TENSOR_CACHE: Cell<*mut TensorCache> = Cell::new(std::ptr::null_mut()); +} + struct ProcessingContext { scratch: Vec>, ifft_scratch: Vec>, @@ -18,19 +102,29 @@ struct ProcessingContext { lpb_buffer_fft: Vec, lpb_block_fft: Vec>, estimated_block_vec: Vec, - in_mag: Array3, - lpb_mag: Array3, - estimated_block: Array3, - in_lpb: Array3, + in_mag_ptr: *mut f32, + in_mag_len: usize, + lpb_mag_ptr: *mut f32, + lpb_mag_len: usize, + estimated_block_ptr: *mut f32, + estimated_block_len: usize, + in_lpb_ptr: *mut f32, + in_lpb_len: usize, out_mask: Vec, out_block: Vec, } +// SAFETY: ProcessingContext is used exclusively from a single thread. +// The raw pointers point into ORT-allocated CPU tensor memory which is stable +// for the lifetime of TensorCache and never reallocated. +unsafe impl Send for ProcessingContext {} + impl ProcessingContext { fn new( block_len: usize, fft: &Arc>, ifft: &Arc>, + tc: &TensorCache, ) -> Self { Self { scratch: vec![Complex::new(0.0f32, 0.0f32); fft.get_scratch_len()], @@ -40,10 +134,14 @@ impl ProcessingContext { lpb_buffer_fft: vec![0.0f32; block_len], lpb_block_fft: vec![Complex::new(0.0f32, 0.0f32); block_len / 2 + 1], estimated_block_vec: vec![0.0f32; block_len], - in_mag: Array3::::zeros((1, 1, block_len / 2 + 1)), - lpb_mag: Array3::::zeros((1, 1, block_len / 2 + 1)), - estimated_block: Array3::::zeros((1, 1, block_len)), - in_lpb: Array3::::zeros((1, 1, block_len)), + in_mag_ptr: tc.in_mag_ptr, + in_mag_len: tc.in_mag_len, + lpb_mag_ptr: tc.lpb_mag_ptr, + lpb_mag_len: tc.lpb_mag_len, + estimated_block_ptr: tc.estimated_block_ptr, + estimated_block_len: tc.estimated_block_len, + in_lpb_ptr: tc.in_lpb_ptr, + in_lpb_len: tc.in_lpb_len, out_mask: vec![0.0f32; block_len / 2 + 1], out_block: vec![0.0f32; block_len], } @@ -58,8 +156,8 @@ pub struct AEC { fft: Arc>, ifft: Arc>, // streaming state - states_1: Array4, - states_2: Array4, + states_1: Tensor, + states_2: Tensor, in_buffer: CircularBuffer, in_buffer_lpb: 
CircularBuffer, out_buffer: CircularBuffer, @@ -74,10 +172,23 @@ impl AEC { let fft = fft_planner.plan_fft_forward(block_len); let ifft = fft_planner.plan_fft_inverse(block_len); - let session_1 = hypr_onnx::load_model_from_bytes(model::BYTES_1)?; - let session_2 = hypr_onnx::load_model_from_bytes(model::BYTES_2)?; + let session_1 = Session::builder()? + .with_intra_threads(2)? + .with_inter_threads(1)? + .with_optimization_level(GraphOptimizationLevel::Level3)? + .with_independent_thread_pool()? + .with_log_level(LogLevel::Fatal)? + .commit_from_memory(model::BYTES_1)?; + let session_2 = Session::builder()? + .with_intra_threads(4)? + .with_inter_threads(1)? + .with_optimization_level(GraphOptimizationLevel::Level3)? + .with_independent_thread_pool()? + .with_log_level(LogLevel::Fatal)? + .commit_from_memory(model::BYTES_2)?; let state_size = model::STATE_SIZE; + let allocator = Allocator::default(); Ok(AEC { session_1, @@ -86,8 +197,10 @@ impl AEC { block_shift, fft, ifft, - states_1: Array4::::zeros((1, 2, state_size, 2)), - states_2: Array4::::zeros((1, 2, state_size, 2)), + states_1: Tensor::::new(&allocator, [1usize, 2, state_size, 2]) + .expect("failed to allocate states_1 tensor"), + states_2: Tensor::::new(&allocator, [1usize, 2, state_size, 2]) + .expect("failed to allocate states_2 tensor"), in_buffer: CircularBuffer::new(block_len, block_shift), in_buffer_lpb: CircularBuffer::new(block_len, block_shift), out_buffer: CircularBuffer::new(block_len, block_shift), @@ -96,9 +209,12 @@ impl AEC { } pub fn reset(&mut self) { - let state_size = model::STATE_SIZE; - self.states_1 = Array4::::zeros((1, 2, state_size, 2)); - self.states_2 = Array4::::zeros((1, 2, state_size, 2)); + if let Ok((_, s)) = self.states_1.try_extract_tensor_mut::() { + s.fill(0.0); + } + if let Ok((_, s)) = self.states_2.try_extract_tensor_mut::() { + s.fill(0.0); + } self.in_buffer.clear(); self.in_buffer_lpb.clear(); self.out_buffer.clear(); @@ -111,91 +227,74 @@ impl AEC { fft_buffer: 
&mut [f32], fft_result: &mut [Complex], scratch: &mut [Complex], - magnitude: &mut Array3, + magnitude: &mut [f32], ) -> Result<(), crate::Error> { fft_buffer.copy_from_slice(input); self.fft .process_with_scratch(fft_buffer, fft_result, scratch)?; - for (i, &c) in fft_result.iter().enumerate() { - magnitude[[0, 0, i]] = c.norm(); + for (m, &c) in magnitude.iter_mut().zip(fft_result.iter()) { + *m = c.norm(); } Ok(()) } - fn run_model_1(&mut self, ctx: &mut ProcessingContext) -> Result<(), crate::Error> { - let mut outputs = self.session_1.run(hypr_onnx::ort::inputs![ - TensorRef::from_array_view(ctx.in_mag.view())?, - TensorRef::from_array_view(self.states_1.view())?, - TensorRef::from_array_view(ctx.lpb_mag.view())? - ])?; + fn run_model_1( + &mut self, + ctx: &mut ProcessingContext, + tc: &TensorCache, + ) -> Result<(), crate::Error> { + let states_view = self.states_1.view().try_upgrade().map_err(|_| { + crate::Error::ShapeError(hypr_onnx::ndarray::ShapeError::from_kind( + hypr_onnx::ndarray::ErrorKind::IncompatibleLayout, + )) + })?; + let inputs = hypr_onnx::ort::inputs![&tc.in_mag, states_view, &tc.lpb_mag]; + let mut outputs = self.session_1.run(inputs)?; let out_mask = outputs .remove("Identity") .ok_or_else(|| crate::Error::MissingOutput("Identity".to_string()))?; - let out_mask_view = out_mask.try_extract_array::()?; - ctx.out_mask - .copy_from_slice(out_mask_view.view().as_slice().ok_or_else(|| { - crate::Error::ShapeError(hypr_onnx::ndarray::ShapeError::from_kind( - hypr_onnx::ndarray::ErrorKind::IncompatibleLayout, - )) - })?); + let (_, out_mask_data) = out_mask.try_extract_tensor::()?; + ctx.out_mask.copy_from_slice(out_mask_data); let new_states = outputs .remove("Identity_1") .ok_or_else(|| crate::Error::MissingOutput("Identity_1".to_string()))?; - let new_states_view = new_states.try_extract_array::()?; - self.states_1 - .as_slice_mut() - .ok_or_else(|| { - crate::Error::ShapeError(hypr_onnx::ndarray::ShapeError::from_kind( - 
hypr_onnx::ndarray::ErrorKind::IncompatibleLayout, - )) - })? - .copy_from_slice(new_states_view.view().as_slice().ok_or_else(|| { - crate::Error::ShapeError(hypr_onnx::ndarray::ShapeError::from_kind( - hypr_onnx::ndarray::ErrorKind::IncompatibleLayout, - )) - })?); + let (_, new_states_data) = new_states.try_extract_tensor::()?; + let (_, states_buf) = self.states_1.try_extract_tensor_mut::()?; + states_buf.copy_from_slice(new_states_data); Ok(()) } - fn run_model_2(&mut self, ctx: &mut ProcessingContext) -> Result<(), crate::Error> { - let mut outputs = self.session_2.run(hypr_onnx::ort::inputs![ - TensorRef::from_array_view(ctx.estimated_block.view())?, - TensorRef::from_array_view(self.states_2.view())?, - TensorRef::from_array_view(ctx.in_lpb.view())? - ])?; + fn run_model_2( + &mut self, + ctx: &mut ProcessingContext, + tc: &TensorCache, + ) -> Result<(), crate::Error> { + let states_view = self.states_2.view().try_upgrade().map_err(|_| { + crate::Error::ShapeError(hypr_onnx::ndarray::ShapeError::from_kind( + hypr_onnx::ndarray::ErrorKind::IncompatibleLayout, + )) + })?; + let inputs = + hypr_onnx::ort::inputs![&tc.estimated_block, states_view, &tc.in_lpb]; + let mut outputs = self.session_2.run(inputs)?; let out_block = outputs .remove("Identity") .ok_or_else(|| crate::Error::MissingOutput("Identity".into()))?; - let out_block_view = out_block.try_extract_array::()?; - ctx.out_block - .copy_from_slice(out_block_view.view().as_slice().ok_or_else(|| { - crate::Error::ShapeError(hypr_onnx::ndarray::ShapeError::from_kind( - hypr_onnx::ndarray::ErrorKind::IncompatibleLayout, - )) - })?); + let (_, out_block_data) = out_block.try_extract_tensor::()?; + ctx.out_block.copy_from_slice(out_block_data); let new_states = outputs .remove("Identity_1") .ok_or_else(|| crate::Error::MissingOutput("Identity_1".into()))?; - let new_states_view = new_states.try_extract_array::()?; - self.states_2 - .as_slice_mut() - .ok_or_else(|| { - 
crate::Error::ShapeError(hypr_onnx::ndarray::ShapeError::from_kind( - hypr_onnx::ndarray::ErrorKind::IncompatibleLayout, - )) - })? - .copy_from_slice(new_states_view.view().as_slice().ok_or_else(|| { - crate::Error::ShapeError(hypr_onnx::ndarray::ShapeError::from_kind( - hypr_onnx::ndarray::ErrorKind::IncompatibleLayout, - )) - })?); + let (_, new_states_data) = new_states.try_extract_tensor::()?; + let (_, states_buf) = self.states_2.try_extract_tensor_mut::()?; + states_buf.copy_from_slice(new_states_data); Ok(()) } @@ -265,8 +364,19 @@ impl AEC { }; let num_blocks = effective_len / self.block_shift; - // Create processing context with all buffers - let mut ctx = ProcessingContext::new(self.block_len, &self.fft, &self.ifft); + // Take or create TensorCache from thread-local: eliminates 4 Tensor + // allocations + 16 C FFI calls per call. Cell swap has zero locking overhead. + // ProcessingContext remains stack-local to keep raw pointer fields in registers. + // NOTE(review): a reused cache keeps the tensor lengths of the block_len it was first built + // with; nothing here compares them with self.block_len — confirm block_len never varies per thread. + let tc_raw = TENSOR_CACHE.with(|c| c.replace(std::ptr::null_mut())); + let tc: Box = if tc_raw.is_null() { + Box::new(TensorCache::new(self.block_len)) + } else { + // SAFETY: non-null pointer was created by Box::into_raw; we own it now. + unsafe { Box::from_raw(tc_raw) } + }; + + // Create processing context with all buffers; pointers from tc + let mut ctx = ProcessingContext::new(self.block_len, &self.fft, &self.ifft, &tc); for idx in 0..num_blocks { // Shift values and write to buffer of the input audio @@ -280,24 +390,37 @@ impl AEC { } // Calculate FFT of input block - self.calculate_fft_magnitude( - self.in_buffer.data(), - &mut ctx.in_buffer_fft, - &mut ctx.in_block_fft, - &mut ctx.scratch, - &mut ctx.in_mag, - )?; + // SAFETY: in_mag_ptr points into stable ORT CPU tensor memory allocated in + // TensorCache::new() and owned by `tc`; ProcessingContext::new() only copies the pointer and + // length. The tensor is never resized or reallocated. + // No concurrent access occurs — single-threaded hot path. 
+ { + let in_mag_slice = unsafe { + std::slice::from_raw_parts_mut(ctx.in_mag_ptr, ctx.in_mag_len) + }; + self.calculate_fft_magnitude( + self.in_buffer.data(), + &mut ctx.in_buffer_fft, + &mut ctx.in_block_fft, + &mut ctx.scratch, + in_mag_slice, + )?; + } // Calculate FFT of lpb block - self.calculate_fft_magnitude( - self.in_buffer_lpb.data(), - &mut ctx.lpb_buffer_fft, - &mut ctx.lpb_block_fft, - &mut ctx.scratch, - &mut ctx.lpb_mag, - )?; + { + let lpb_mag_slice = unsafe { + std::slice::from_raw_parts_mut(ctx.lpb_mag_ptr, ctx.lpb_mag_len) + }; + self.calculate_fft_magnitude( + self.in_buffer_lpb.data(), + &mut ctx.lpb_buffer_fft, + &mut ctx.lpb_block_fft, + &mut ctx.scratch, + lpb_mag_slice, + )?; + } - self.run_model_1(&mut ctx)?; + self.run_model_1(&mut ctx, &tc)?; // Apply mask and calculate IFFT for (c, &m) in ctx.in_block_fft.iter_mut().zip(ctx.out_mask.iter()) { @@ -311,18 +434,27 @@ impl AEC { &mut ctx.ifft_scratch, )?; - // Normalize IFFT result and copy to Array3 for second model (fused) + // Normalize IFFT result and copy to tensor for second model (fused) let norm_factor = 1.0 / self.block_len as f32; - let est_slice = ctx.estimated_block.as_slice_mut().unwrap(); - for (d, &s) in est_slice.iter_mut().zip(ctx.estimated_block_vec.iter()) { - *d = s * norm_factor; + { + let est_slice = unsafe { + std::slice::from_raw_parts_mut( + ctx.estimated_block_ptr, + ctx.estimated_block_len, + ) + }; + for (d, &s) in est_slice.iter_mut().zip(ctx.estimated_block_vec.iter()) { + *d = s * norm_factor; + } + } + { + let in_lpb_slice = unsafe { + std::slice::from_raw_parts_mut(ctx.in_lpb_ptr, ctx.in_lpb_len) + }; + in_lpb_slice.copy_from_slice(self.in_buffer_lpb.data()); } - ctx.in_lpb - .as_slice_mut() - .unwrap() - .copy_from_slice(self.in_buffer_lpb.data()); - self.run_model_2(&mut ctx)?; + self.run_model_2(&mut ctx, &tc)?; self.out_buffer.shift_and_accumulate(&ctx.out_block); @@ -337,6 +469,10 @@ impl AEC { } self.normalize_output(&mut out_file); + + // Return 
TensorCache to thread-local for reuse. + TENSOR_CACHE.with(|c| c.set(Box::into_raw(tc))); + Ok(out_file) }