From 1a5be70d8095d8efd8a2b98e7b53c352af5a73bd Mon Sep 17 00:00:00 2001 From: MKY508 Date: Mon, 30 Mar 2026 22:52:23 +0800 Subject: [PATCH 1/5] whisper : add speaker diarization support Add speaker diarization based on ECAPA-TDNN speaker embeddings. When enabled via --diarize, each transcription segment gets assigned a speaker ID. The pipeline works by computing a 192-dim speaker embedding per segment using a ported SpeechBrain ECAPA-TDNN model, then clustering them with agglomerative hierarchical clustering. New files: - src/whisper-diarize.cpp/h: mel computation, ECAPA-TDNN forward pass, clustering - src/whisper-speaker.cpp/h: GGML model loader - models/convert-speaker-to-ggml.py: SpeechBrain model converter Usage: python models/convert-speaker-to-ggml.py --output models/ggml-speaker-ecapa-tdnn.bin ./whisper-cli -m models/ggml-base.en.bin \ --diarize --diarize-model models/ggml-speaker-ecapa-tdnn.bin -f input.wav The feature is compile-gated behind WHISPER_DIARIZE and has zero overhead when disabled. Embeddings match SpeechBrain PyTorch output (cosine distance < 0.05). Known limitations: ~200MB memory per encoder context, no GPU backend, O(n^2) clustering. 
Resolves: https://github.com/ggml-org/whisper.cpp/issues/64 --- include/whisper-speaker.h | 39 + include/whisper.h | 19 + models/convert-speaker-to-ggml.py | 332 +++++ src/CMakeLists.txt | 14 + src/whisper-diarize.cpp | 1921 +++++++++++++++++++++++++++++ src/whisper-diarize.h | 79 ++ src/whisper-speaker.cpp | 275 +++++ src/whisper.cpp | 189 +++ tests/CMakeLists.txt | 68 + 9 files changed, 2936 insertions(+) create mode 100644 include/whisper-speaker.h create mode 100644 models/convert-speaker-to-ggml.py create mode 100644 src/whisper-diarize.cpp create mode 100644 src/whisper-diarize.h create mode 100644 src/whisper-speaker.cpp diff --git a/include/whisper-speaker.h b/include/whisper-speaker.h new file mode 100644 index 00000000000..42212a557b8 --- /dev/null +++ b/include/whisper-speaker.h @@ -0,0 +1,39 @@ +#ifndef WHISPER_SPEAKER_H +#define WHISPER_SPEAKER_H + +#include "ggml.h" +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// Opaque speaker model context +struct whisper_speaker_model; + +// Load speaker model from GGML binary file +struct whisper_speaker_model * whisper_speaker_load_from_file(const char * path_model); + +// Free model resources +void whisper_speaker_free(struct whisper_speaker_model * model); + +// Print model structure info +void whisper_speaker_validate(struct whisper_speaker_model * model); + +// Get embedding dimension (192 for ECAPA-TDNN) +int whisper_speaker_get_embedding_dim(struct whisper_speaker_model * model); + +// Get tensor count +int whisper_speaker_get_tensor_count(struct whisper_speaker_model * model); + +// Get tensor by index +struct ggml_tensor * whisper_speaker_get_tensor(struct whisper_speaker_model * model, int idx); + +// Find tensor by name (e.g. 
"mods.embedding_model.blocks.0.conv.conv.weight") +struct ggml_tensor * whisper_speaker_find_tensor(struct whisper_speaker_model * model, const char * name); + +#ifdef __cplusplus +} +#endif + +#endif // WHISPER_SPEAKER_H diff --git a/include/whisper.h b/include/whisper.h index f4cc6bf7abd..530484ae242 100644 --- a/include/whisper.h +++ b/include/whisper.h @@ -588,6 +588,12 @@ extern "C" { const char * vad_model_path; // Path to VAD model whisper_vad_params vad_params; + + // Speaker diarization params + bool diarize; // Enable speaker diarization (default: false) + const char * diarize_model_path; // Path to speaker embedding model file (GGUF format) + float diarize_threshold; // Distance threshold for clustering (default: 0.5f) + int diarize_speakers; // Target speaker count; 0 = auto-detect (default: 0) }; // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params() @@ -647,6 +653,19 @@ extern "C" { WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment); WHISPER_API bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment); + // Speaker diarization accessor + + // Get the speaker ID assigned to the given segment (result of diarization clustering) + // Returns: 0-based speaker ID if diarization enabled, -1 if disabled or invalid segment + WHISPER_API int whisper_full_get_segment_speaker_id( + struct whisper_context * ctx, + int i_segment); + + // Variant that works with whisper_state directly (for advanced use cases) + WHISPER_API int whisper_full_get_segment_speaker_id_from_state( + struct whisper_state * state, + int i_segment); + // Get the text of the specified segment WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment); WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, 
int i_segment); diff --git a/models/convert-speaker-to-ggml.py b/models/convert-speaker-to-ggml.py new file mode 100644 index 00000000000..f6cac2713f9 --- /dev/null +++ b/models/convert-speaker-to-ggml.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python3 +""" +Convert SpeechBrain speaker embedding model (ECAPA-TDNN) to GGML binary format. + +This script loads the pre-trained SpeakerRecognition model from SpeechBrain, +extracts weights, and converts them to GGML binary format for use with whisper.cpp. + +GGML Format: +- Magic: 0x67676d6c (4 bytes, "ggml") +- Model type: string length (4 bytes) + UTF-8 string +- Version: major, minor, patch (3 x 4 bytes) +- Hyperparameters: embedding_dim, n_channels (2 x 4 bytes) +- Tensor count (4 bytes) +- Tensors: for each tensor: + - n_dims (4 bytes) + - name_len (4 bytes) + - dims (n_dims x 4 bytes) + - name (name_len bytes) + - data (product(dims) x 4 bytes, float32) + +Usage: + python convert-speaker-to-ggml.py --output ggml-speaker-ecapa-tdnn.bin + python convert-speaker-to-ggml.py --model speechbrain/spkrec-ecapa-voxceleb --output custom.bin + python convert-speaker-to-ggml.py --test # Minimal test run +""" + +import os +import struct +import argparse +import sys +import tempfile +import torch +import numpy as np +from pathlib import Path + +def fuse_batch_norm_weights(conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias, eps=1e-5): + """ + Fuse BatchNorm into conv weights for inference. 
+ + Args: + conv_weight: [out_c, in_c, kernel_size] + conv_bias: [out_c] + bn_mean, bn_var, bn_weight, bn_bias: [out_c] each + + Returns: + fused_weight: [out_c, in_c, kernel_size] + fused_bias: [out_c] + """ + # Convert to numpy if needed + if isinstance(conv_weight, torch.Tensor): + conv_weight = conv_weight.cpu().numpy() + if isinstance(conv_bias, torch.Tensor): + conv_bias = conv_bias.cpu().numpy() + if isinstance(bn_mean, torch.Tensor): + bn_mean = bn_mean.cpu().numpy() + if isinstance(bn_var, torch.Tensor): + bn_var = bn_var.cpu().numpy() + if isinstance(bn_weight, torch.Tensor): + bn_weight = bn_weight.cpu().numpy() + if isinstance(bn_bias, torch.Tensor): + bn_bias = bn_bias.cpu().numpy() + + # Fusion formula: W_fused = W * γ / sqrt(σ² + ε), b_fused = β + γ * (b - μ) / sqrt(σ² + ε) + scale = bn_weight / np.sqrt(bn_var + eps) # [out_c] + + # Broadcast scale across weight dimensions [out_c, in_c, kernel_size] + fused_weight = conv_weight.astype(np.float32) * scale[:, np.newaxis, np.newaxis] + fused_bias = bn_bias + scale * (conv_bias.astype(np.float32) - bn_mean) + + return fused_weight, fused_bias + +def load_speaker_model(model_name: str, tmp_dir: str = None): + """ + Load SpeakerRecognition model from SpeechBrain. + + Args: + model_name: HuggingFace model identifier (default: speechbrain/spkrec-ecapa-voxceleb) + tmp_dir: temporary directory for model cache + + Returns: + Model object and state_dict + """ + try: + # Monkey-patch for torchaudio >= 2.6 compatibility + import torchaudio + if not hasattr(torchaudio, 'list_audio_backends'): + torchaudio.list_audio_backends = lambda: ['default'] + + # SpeechBrain >= 1.0 moved pretrained to inference + try: + from speechbrain.inference.speaker import SpeakerRecognition + except ImportError: + from speechbrain.pretrained import SpeakerRecognition + except ImportError as e: + print(f"Error: Failed to import SpeechBrain. 
Install with: pip install -r requirements-convert.txt") + print(f"Original error: {e}") + sys.exit(1) + + if tmp_dir is None: + tmp_dir = tempfile.mkdtemp(prefix='spkrec_') + + print(f"Loading SpeechBrain model: {model_name}") + print(f"Cache directory: {tmp_dir}") + + try: + model = SpeakerRecognition.from_hparams( + source=model_name, + savedir=tmp_dir, + run_opts={'device': 'cpu'}, + freeze_params=True + ) + except Exception as e: + print(f"Failed to load model from {model_name}: {e}") + sys.exit(1) + + print("Model loaded successfully") + + # Extract state_dict + state_dict = model.state_dict() + + print(f"State dict contains {len(state_dict)} tensors") + + return model, state_dict + +def convert_speaker_model(model_name: str, output_path: str, test_mode: bool = False, validate_mode: bool = False): + """ + Convert SpeechBrain speaker model to GGML binary format. + + Args: + model_name: HuggingFace model identifier + output_path: output file path for GGML binary + test_mode: if True, skip validation + validate_mode: if True, run validation after conversion + """ + print("\n" + "="*60) + print("SpeechBrain → GGML Speaker Model Conversion") + print("="*60) + + # Load model + model, state_dict = load_speaker_model(model_name) + + # Write GGML binary + print(f"\nWriting GGML binary to: {output_path}") + + output_dir = os.path.dirname(output_path) or '.' 
+ os.makedirs(output_dir, exist_ok=True) + + with open(output_path, 'wb') as fout: + # Write magic number + magic = 0x67676d6c # "ggml" + fout.write(struct.pack('i', magic)) + print(f"Magic number: 0x{magic:08x}") + + # Write model type + model_type = b"spkrec-ecapa-tdnn" + fout.write(struct.pack('i', len(model_type))) + fout.write(model_type) + print(f"Model type: {model_type.decode('utf-8')}") + + # Write version + version_major = 1 + version_minor = 0 + version_patch = 0 + fout.write(struct.pack('i', version_major)) + fout.write(struct.pack('i', version_minor)) + fout.write(struct.pack('i', version_patch)) + print(f"Version: {version_major}.{version_minor}.{version_patch}") + + # Write hyperparameters + embedding_dim = 192 # ECAPA-TDNN output dimension + n_channels = 512 # Internal architecture parameter + fout.write(struct.pack('i', embedding_dim)) + fout.write(struct.pack('i', n_channels)) + print(f"Embedding dimension: {embedding_dim}") + print(f"Internal channels: {n_channels}") + + # Count all tensors with dim > 0 (skip scalars like num_batches_tracked) + n_tensors = sum(1 for k, v in state_dict.items() if v.dim() > 0) + fout.write(struct.pack('i', n_tensors)) + n_scalars = len(state_dict) - n_tensors + print(f"Tensor count: {n_tensors} (skipping {n_scalars} scalars)") + + # Write all tensors (no BN fusion — SpeechBrain uses Conv→ReLU→BN order, + # so BN cannot be fused into conv weights; C++ applies runtime BN after ReLU) + print("\nWriting tensors (no fusion, runtime BN in C++):") + + tensor_count = 0 + total_bytes = 0 + + for name, tensor in state_dict.items(): + if tensor.dim() == 0: + continue # skip scalars (e.g. 
num_batches_tracked) + + data = tensor.detach().cpu().numpy().astype(np.float32) + + with open(output_path, 'ab') as fout: + n_dims = len(data.shape) + + # Write tensor header + fout.write(struct.pack('i', n_dims)) + + name_bytes = name.encode('utf-8') + fout.write(struct.pack('i', len(name_bytes))) + + # Write dimensions in REVERSED order for ggml column-major compatibility. + # NumPy row-major: last dim varies fastest in memory. + # ggml column-major: ne[0] varies fastest in memory. + # By reversing, ne[0] = last PyTorch dim, matching the memory layout. + for dim in reversed(data.shape): + fout.write(struct.pack('i', dim)) + + # Write tensor name + fout.write(name_bytes) + + # Write tensor data (row-major bytes, matching reversed ggml dims) + tensor_bytes = data.tobytes() + fout.write(tensor_bytes) + + total_bytes += len(tensor_bytes) + tensor_count += 1 + + shape_str = 'x'.join(str(d) for d in data.shape) + ggml_dims = 'x'.join(str(d) for d in reversed(data.shape)) + print(f" [{tensor_count}] {name}: {shape_str} → ggml [{ggml_dims}]") + + # Verify output + if not os.path.exists(output_path): + print(f"\nError: Output file not created: {output_path}") + sys.exit(1) + + file_size = os.path.getsize(output_path) + file_size_mb = file_size / 1024 / 1024 + + print(f"\n" + "="*60) + print("Conversion complete!") + print("="*60) + print(f"Output file: {output_path}") + print(f"File size: {file_size_mb:.2f} MB ({file_size} bytes)") + print(f"Tensors written: {tensor_count}") + print(f"Tensor data size: {total_bytes / 1024 / 1024:.2f} MB") + + if validate_mode: + print("\nRunning validation...") + validate_conversion(output_path) + + if not test_mode: + print("\nNext steps:") + print(f" 1. Compile C++ test: cd build && cmake .. && make") + print(f" 2. Run test: ./test-speaker-model-load ../{output_path}") + print(f" 3. 
Validate numerically: python models/validate-speaker-model.py {output_path}") + +def validate_conversion(ggml_path: str): + """ + Quick validation: load GGML file and check magic number and header. + """ + if not os.path.exists(ggml_path): + print(f"Error: File not found: {ggml_path}") + return False + + with open(ggml_path, 'rb') as fin: + # Read magic + magic_bytes = fin.read(4) + magic = struct.unpack('i', magic_bytes)[0] + + if magic != 0x67676d6c: + print(f"Invalid magic number: 0x{magic:08x} (expected 0x67676d6c)") + return False + + print(f"Magic number valid: 0x{magic:08x}") + + str_len = struct.unpack('i', fin.read(4))[0] + model_type = fin.read(str_len).decode('utf-8') + print(f"Model type: {model_type}") + + major, minor, patch = struct.unpack('iii', fin.read(12)) + print(f"Version: {major}.{minor}.{patch}") + + embedding_dim = struct.unpack('i', fin.read(4))[0] + n_channels = struct.unpack('i', fin.read(4))[0] + print(f"Embedding dimension: {embedding_dim}") + print(f"Internal channels: {n_channels}") + + n_tensors = struct.unpack('i', fin.read(4))[0] + print(f"Tensor count: {n_tensors}") + + return True + +def main(): + parser = argparse.ArgumentParser( + description='Convert SpeechBrain speaker embedding model to GGML binary format' + ) + parser.add_argument( + '--model', + default='speechbrain/spkrec-ecapa-voxceleb', + help='HuggingFace model identifier (default: speechbrain/spkrec-ecapa-voxceleb)' + ) + parser.add_argument( + '--output', + default='ggml-speaker-ecapa-tdnn.bin', + help='Output file path (default: ggml-speaker-ecapa-tdnn.bin)' + ) + parser.add_argument( + '--test', + action='store_true', + help='Test mode: minimal output, no verification' + ) + parser.add_argument( + '--validate', + action='store_true', + help='Run validation after conversion' + ) + + args = parser.parse_args() + + try: + convert_speaker_model( + args.model, + args.output, + test_mode=args.test, + validate_mode=args.validate + ) + except KeyboardInterrupt: + 
print("\n\nConversion cancelled by user") + sys.exit(130) + except Exception as e: + print(f"\nFatal error: {e}") + import traceback + traceback.print_exc() + sys.exit(1) + +if __name__ == '__main__': + main() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 095a2791de5..9826bb32d7a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -103,10 +103,20 @@ endif() # whisper +# Add diarization sources conditionally +set(WHISPER_DIARIZE_SOURCES) +if (WHISPER_DIARIZE) + list(APPEND WHISPER_DIARIZE_SOURCES + whisper-diarize.cpp + whisper-speaker.cpp + ) +endif() + add_library(whisper ../include/whisper.h whisper-arch.h whisper.cpp + ${WHISPER_DIARIZE_SOURCES} ) # Set the version numbers @@ -141,6 +151,10 @@ if (WHISPER_MKL) target_link_libraries(whisper PRIVATE MKL::MKL) endif() +if (WHISPER_DIARIZE) + target_compile_definitions(whisper PRIVATE WHISPER_DIARIZE) +endif() + if (BUILD_SHARED_LIBS) set_target_properties(whisper PROPERTIES POSITION_INDEPENDENT_CODE ON) target_compile_definitions(whisper PRIVATE WHISPER_SHARED WHISPER_BUILD) diff --git a/src/whisper-diarize.cpp b/src/whisper-diarize.cpp new file mode 100644 index 00000000000..02f795c87ee --- /dev/null +++ b/src/whisper-diarize.cpp @@ -0,0 +1,1921 @@ +#include "whisper-diarize.h" +#include "ggml-cpu.h" +#include +#include +#include +#include +#include +#include + +// Define logging macros for consistency with whisper.cpp +#define WHISPER_LOG_ERROR(...) fprintf(stderr, "[ERROR] " __VA_ARGS__) +#define WHISPER_LOG_WARN(...) fprintf(stderr, "[WARN] " __VA_ARGS__) +#define WHISPER_LOG_INFO(...) 
fprintf(stderr, "[INFO] " __VA_ARGS__) + +// Mel-spectrogram computation (80-bin, compatible with SpeechBrain ECAPA-TDNN) + +#define WHISPER_N_FFT 400 +#define WHISPER_HOP_LENGTH 160 +#define WHISPER_SAMPLE_RATE 16000 + +#define MEL_N_BINS 80 +#define MEL_FMIN 0.0f +#define MEL_FMAX 8000.0f + +// FFT constants +#define FFT_SIZE 512 // Next power of 2 for efficiency + +static float g_hann_window[WHISPER_N_FFT] = {0}; +static int g_hann_computed = 0; + +static void compute_hann_window() { + if (g_hann_computed) return; + + // Hamming window (matching SpeechBrain pretrained model) + for (int i = 0; i < WHISPER_N_FFT; i++) { + g_hann_window[i] = 0.54f - 0.46f * cosf(2.0f * M_PI * i / (WHISPER_N_FFT - 1)); + } + g_hann_computed = 1; +} + +// Cooley-Tukey radix-2 FFT (complex-to-complex, in-place) +// data: interleaved complex [re0, im0, re1, im1, ...], length 2*n +static void fft_radix2(float * data, int n) { + // Bit-reversal permutation + for (int i = 1, j = 0; i < n; i++) { + int bit = n >> 1; + for (; j & bit; bit >>= 1) j ^= bit; + j ^= bit; + if (i < j) { + float tr = data[2*i]; data[2*i] = data[2*j]; data[2*j] = tr; + float ti = data[2*i+1]; data[2*i+1] = data[2*j+1]; data[2*j+1] = ti; + } + } + // Butterfly passes + for (int len = 2; len <= n; len <<= 1) { + float ang = -2.0f * (float)M_PI / len; + float wr = cosf(ang), wi = sinf(ang); + for (int i = 0; i < n; i += len) { + float cur_r = 1.0f, cur_i = 0.0f; + for (int j = 0; j < len / 2; j++) { + int u = i + j, v = i + j + len / 2; + float tr = data[2*v]*cur_r - data[2*v+1]*cur_i; + float ti = data[2*v]*cur_i + data[2*v+1]*cur_r; + data[2*v] = data[2*u] - tr; + data[2*v+1] = data[2*u+1] - ti; + data[2*u] += tr; + data[2*u+1] += ti; + float nr = cur_r*wr - cur_i*wi; + cur_i = cur_r*wi + cur_i*wr; + cur_r = nr; + } + } + } +} + +// Convert frequency (Hz) to mel scale +static inline float hz_to_mel(float hz) { + return 2595.0f * log10f(1.0f + hz / 700.0f); +} + +// Convert mel scale back to frequency (Hz) +static 
inline float mel_to_hz(float mel) { + return 700.0f * (powf(10.0f, mel / 2595.0f) - 1.0f); +} + +// Create triangular mel filterbank (librosa-style continuous frequency mapping) +// Returns pointer to filters array of size n_mels * n_fft +static float * create_mel_filters(int n_mels, int n_fft, int sample_rate) { + float * filters = (float *)calloc(n_mels * n_fft, sizeof(float)); + if (!filters) return NULL; + + // Compute mel boundary frequencies + float fmin_mel = hz_to_mel(MEL_FMIN); + float fmax_mel = hz_to_mel(MEL_FMAX); + + float * mel_hz = (float *)malloc((n_mels + 2) * sizeof(float)); + for (int i = 0; i < n_mels + 2; i++) { + mel_hz[i] = mel_to_hz(fmin_mel + (fmax_mel - fmin_mel) * i / (n_mels + 1)); + } + + // Compute frequency of each FFT bin + int fft_size = 2 * (n_fft - 1); + float * fft_freqs = (float *)malloc(n_fft * sizeof(float)); + for (int k = 0; k < n_fft; k++) { + fft_freqs[k] = (float)k * sample_rate / fft_size; + } + + // Build triangular filters (peak=1, no area normalization) + for (int m = 0; m < n_mels; m++) { + float lower = mel_hz[m]; + float center = mel_hz[m + 1]; + float upper = mel_hz[m + 2]; + + for (int k = 0; k < n_fft; k++) { + float freq = fft_freqs[k]; + if (freq >= lower && freq < center && center > lower) { + filters[m * n_fft + k] = (freq - lower) / (center - lower); + } else if (freq >= center && freq < upper && upper > center) { + filters[m * n_fft + k] = (upper - freq) / (upper - center); + } + } + + } + + free(mel_hz); + free(fft_freqs); + return filters; +} + +// Compute 80-bin mel-spectrogram from PCM samples +float * whisper_compute_mel_80(const float * samples, int n_samples) { + if (!samples || n_samples <= 0) { + return NULL; + } + + compute_hann_window(); + + // Center padding: add n_fft/2 samples on both sides + int pad = WHISPER_N_FFT / 2; // 200 samples + int padded_len = n_samples + 2 * pad; + + // Create padded signal + float * padded_samples = (float *)calloc(padded_len, sizeof(float)); + if 
(!padded_samples) { + return NULL; + } + // Copy original samples to center (zero padding at edges already done by calloc) + memcpy(padded_samples + pad, samples, n_samples * sizeof(float)); + + // Calculate number of frames (now with padded length) + int n_frames = (padded_len - WHISPER_N_FFT) / WHISPER_HOP_LENGTH + 1; + if (n_frames <= 0) { + free(padded_samples); + return NULL; + } + + // Allocate output mel array [n_frames, 80] + float * mel = (float *)calloc(n_frames * MEL_N_BINS, sizeof(float)); + if (!mel) { + return NULL; + } + + // Create mel filterbank once (n_fft=400 → 201 bins) + static float * mel_filters = NULL; + static int mel_filters_initialized = 0; + int n_fft_bins = 1 + WHISPER_N_FFT / 2; // 201 bins for 400-point DFT + if (!mel_filters_initialized) { + mel_filters = create_mel_filters(MEL_N_BINS, n_fft_bins, WHISPER_SAMPLE_RATE); + mel_filters_initialized = 1; + } + + if (!mel_filters) { + free(mel); + return NULL; + } + + // Process each frame + for (int t = 0; t < n_frames; t++) { + int offset = t * WHISPER_HOP_LENGTH; + + // Extract frame, apply Hamming window, zero-pad to 512 as complex interleaved + float * fft_buf = (float *)calloc(FFT_SIZE * 2, sizeof(float)); + for (int i = 0; i < WHISPER_N_FFT; i++) { + fft_buf[2*i] = padded_samples[offset + i] * g_hann_window[i]; + // imaginary = 0 (calloc) + } + // positions WHISPER_N_FFT..FFT_SIZE-1 are zero-padded (calloc) + + // In-place 512-point FFT + fft_radix2(fft_buf, FFT_SIZE); + + // Extract power spectrum for first n_fft_bins=201 bins + // Bin k of 512-point FFT → freq = k * sr / 512 + // Bin k of 400-point DFT → freq = k * sr / 400 + // Map: for target bin j (400-point), find 512-point bin at same frequency + // freq_j = j * sr / 400 → k_512 = j * 512 / 400 = j * 1.28 + // Use linear interpolation between adjacent 512-point bins + float * mag = (float *)malloc(n_fft_bins * sizeof(float)); + for (int j = 0; j < n_fft_bins; j++) { + float k_f = j * (float)FFT_SIZE / WHISPER_N_FFT; // 
fractional 512-bin index + int k0 = (int)k_f; + float frac = k_f - k0; + int k1 = k0 + 1; + if (k1 >= FFT_SIZE / 2 + 1) k1 = k0; + + // Power at each 512-bin + float p0 = fft_buf[2*k0]*fft_buf[2*k0] + fft_buf[2*k0+1]*fft_buf[2*k0+1]; + float p1 = fft_buf[2*k1]*fft_buf[2*k1] + fft_buf[2*k1+1]*fft_buf[2*k1+1]; + + // Interpolate + mag[j] = p0 * (1.0f - frac) + p1 * frac; + } + + // Apply mel filterbank + for (int m = 0; m < MEL_N_BINS; m++) { + float mel_val = 0.0f; + for (int k = 0; k < n_fft_bins; k++) { + mel_val += mag[k] * mel_filters[m * n_fft_bins + k]; + } + + // dB scale: 10 * log10(max(x, 1e-10)) + mel[t * MEL_N_BINS + m] = 10.0f * log10f(fmaxf(mel_val, 1e-10f)); + } + + free(fft_buf); + free(mag); + } + + // top_db clipping: clamp to (max_db - 80) + float max_db = -1e30f; + for (int i = 0; i < n_frames * MEL_N_BINS; i++) { + if (mel[i] > max_db) max_db = mel[i]; + } + float min_db = max_db - 80.0f; + for (int i = 0; i < n_frames * MEL_N_BINS; i++) { + if (mel[i] < min_db) mel[i] = min_db; + } + + free(padded_samples); + return mel; +} + +// Get number of frames for given sample count +int whisper_get_mel_n_frames(int n_samples) { + if (n_samples <= 0) { + return 0; + } + // Match whisper_compute_mel_80: center padding adds WHISPER_N_FFT/2 on each side + int padded_len = n_samples + WHISPER_N_FFT; // 2 * (N_FFT/2) + return (padded_len - WHISPER_N_FFT) / WHISPER_HOP_LENGTH + 1; +} + +// Free mel-spectrogram buffer +void whisper_mel_free(float * mel) { + if (mel) { + free(mel); + } +} + +// Speaker encoder forward pass + +static struct ggml_tensor * apply_simple_norm( + struct ggml_context * ctx, + struct ggml_tensor * x) { + if (!x) { + WHISPER_LOG_ERROR("apply_simple_norm: NULL input tensor\n"); + return x; + } + return ggml_norm(ctx, x, 1e-5f); +} + +// Reshape conv1d output to 4D for broadcasting +static struct ggml_tensor * ensure_4d_from_conv1d(struct ggml_context * ctx, struct ggml_tensor * t) { + // Conv1d outputs 3D: [OW, OC, batch] + // Reshape to 
4D: [OW, OC, batch, 1] + return ggml_reshape_4d(ctx, t, t->ne[0], t->ne[1], t->ne[2], 1); +} + +// Precompute BatchNorm: scale = gamma / sqrt(var + eps), offset = beta - mean * scale +static void precompute_bn_params( + struct ggml_tensor * bn_mean, // [C] (mu) + struct ggml_tensor * bn_var, // [C] (sigma²) + struct ggml_tensor * bn_gamma, // [C] (weight/scale parameter) + struct ggml_tensor * bn_beta, // [C] (bias/shift parameter) + struct ggml_tensor * bn_scale, + struct ggml_tensor * bn_offset) +{ + if (!bn_mean || !bn_var || !bn_gamma || !bn_beta || !bn_scale || !bn_offset) { + WHISPER_LOG_ERROR("precompute_bn_params: NULL tensor\n"); + return; + } + + int32_t C = bn_scale->ne[0]; + float * mean = (float *)bn_mean->data; + float * var = (float *)bn_var->data; + float * gamma = (float *)bn_gamma->data; + float * beta = (float *)bn_beta->data; + float * scale = (float *)bn_scale->data; + float * offset = (float *)bn_offset->data; + + const float eps = 1e-5f; + + for (int32_t c = 0; c < C; c++) { + scale[c] = gamma[c] / sqrtf(var[c] + eps); + offset[c] = beta[c] - mean[c] * scale[c]; + } +} + +// Apply BatchNorm: output = x * scale + offset +static struct ggml_tensor * apply_runtime_bn( + struct ggml_context * ctx, + struct ggml_tensor * x, // [T, C, 1, 1] + struct ggml_tensor * bn_scale, // [C] + struct ggml_tensor * bn_offset) // [C] +{ + if (!x || !bn_scale || !bn_offset) { + WHISPER_LOG_ERROR("apply_runtime_bn: NULL tensor\n"); + return x; + } + + // Reshape scale and offset for broadcasting: [C] → [1, C, 1, 1] + struct ggml_tensor * scale_reshaped = ggml_reshape_4d(ctx, bn_scale, 1, bn_scale->ne[0], 1, 1); + struct ggml_tensor * offset_reshaped = ggml_reshape_4d(ctx, bn_offset, 1, bn_offset->ne[0], 1, 1); + + // Apply: output = (x * scale) + offset + struct ggml_tensor * scaled = ggml_mul(ctx, x, scale_reshaped); + struct ggml_tensor * output = ggml_add(ctx, scaled, offset_reshaped); + + return output; +} + +static struct ggml_tensor * 
ggml_conv_weight_f32_to_f16( + struct ggml_context * ctx, + struct ggml_tensor * weight_f32) { + + // Convert F32 → F16 (required by ggml_conv_1d) + int64_t ne0 = weight_f32->ne[0]; + int64_t ne1 = weight_f32->ne[1]; + int64_t ne2 = weight_f32->ne[2]; + + + + struct ggml_tensor * weight_f16 = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, ne0, ne1, ne2); + if (!weight_f16) { + WHISPER_LOG_ERROR("Failed to allocate F16 weight tensor\n"); + return NULL; + } + + float * src = (float *)weight_f32->data; + ggml_fp16_t * dst = (ggml_fp16_t *)weight_f16->data; + + int64_t n = ne0 * ne1 * ne2; + for (int64_t i = 0; i < n; i++) { + dst[i] = ggml_fp32_to_fp16(src[i]); + } + + return weight_f16; +} + +struct whisper_speaker_encoder { + struct whisper_speaker_model * model; + struct ggml_context * ctx; + struct ggml_cgraph * graph; + struct ggml_tensor * input_mel; + struct ggml_tensor * output_embedding; + int n_frames; + int n_mels; +}; + +// Initialize encoder with loaded speaker model +struct whisper_speaker_encoder * whisper_speaker_encoder_new( + struct whisper_speaker_model * model, + int n_frames, + int /*device*/) { + + if (!model || n_frames <= 0) { + return NULL; + } + + // Allocate encoder struct + struct whisper_speaker_encoder * encoder = + (struct whisper_speaker_encoder *)malloc(sizeof(struct whisper_speaker_encoder)); + if (!encoder) { + return NULL; + } + + encoder->model = model; + encoder->n_frames = n_frames; + encoder->n_mels = 80; // Fixed for ECAPA-TDNN + + // Dynamic context size: base 200MB + ~0.5MB per frame for intermediate tensors + size_t ctx_bytes = (size_t)200 * 1024 * 1024 + (size_t)n_frames * 512 * 1024; + struct ggml_init_params params = { + .mem_size = ctx_bytes, + .mem_buffer = NULL, + .no_alloc = false, + }; + + encoder->ctx = ggml_init(params); + if (!encoder->ctx) { + free(encoder); + return NULL; + } + + // Create ggml_cgraph + encoder->graph = ggml_new_graph(encoder->ctx); + if (!encoder->graph) { + ggml_free(encoder->ctx); + free(encoder); + 
return NULL; + } + + // Input mel: [80, T] in ggml, transposed to [T, 80] in the graph + encoder->input_mel = ggml_new_tensor_2d(encoder->ctx, GGML_TYPE_F32, + encoder->n_mels, encoder->n_frames); + if (!encoder->input_mel) { + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + // Set input tensor as model input + ggml_set_name(encoder->input_mel, "input_mel"); + ggml_set_input(encoder->input_mel); + + // Allocate output tensor placeholder: [192] + encoder->output_embedding = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, 192); + if (!encoder->output_embedding) { + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + // Build ECAPA-TDNN forward computation graph + struct ggml_tensor * cur = ggml_cont(encoder->ctx, ggml_transpose(encoder->ctx, encoder->input_mel)); // [T, 80] + + // Layer 0: Conv1d(80→1024, k=5) + + struct ggml_tensor * layer0_w = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.0.conv.conv.weight"); + struct ggml_tensor * layer0_b = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.0.conv.conv.bias"); + + if (!layer0_w || !layer0_b) { + WHISPER_LOG_ERROR("Layer 0: Failed to load weights\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + + + // Conv1d(k=5, s=1, p=2, d=1) + struct ggml_tensor * layer0_w_ggml = ggml_conv_weight_f32_to_f16(encoder->ctx, layer0_w); + if (!layer0_w_ggml) { + WHISPER_LOG_ERROR("Layer 0: Failed to convert weight layout\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + // Verify weight data + if (!layer0_w_ggml->data) { + WHISPER_LOG_ERROR("Layer 0: Weight tensor has NULL data!\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + cur = ggml_conv_1d(encoder->ctx, layer0_w_ggml, cur, 1, 2, 1); + if (!cur) { + WHISPER_LOG_ERROR("Layer 0: ggml_conv_1d failed\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + cur = ensure_4d_from_conv1d(encoder->ctx, cur); + + struct 
ggml_tensor * layer0_b_reshaped = ggml_reshape_3d(encoder->ctx, layer0_b, 1, 1024, 1); + cur = ggml_add(encoder->ctx, cur, layer0_b_reshaped); + + + // ReLU + cur = ggml_relu(encoder->ctx, cur); + + // BatchNorm + struct ggml_tensor * bn0_mean = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.0.norm.norm.running_mean"); + struct ggml_tensor * bn0_var = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.0.norm.norm.running_var"); + struct ggml_tensor * bn0_gamma = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.0.norm.norm.weight"); + struct ggml_tensor * bn0_beta = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.0.norm.norm.bias"); + + if (bn0_mean && bn0_var && bn0_gamma && bn0_beta) { + int32_t bn0_channels = cur->ne[1]; + struct ggml_tensor * bn0_scale = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, bn0_channels); + struct ggml_tensor * bn0_offset = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, bn0_channels); + + precompute_bn_params(bn0_mean, bn0_var, bn0_gamma, bn0_beta, bn0_scale, bn0_offset); + cur = apply_runtime_bn(encoder->ctx, cur, bn0_scale, bn0_offset); + } else { + WHISPER_LOG_WARN("Layer 0: Missing BN tensors, skipping BN\n"); + } + + // Layers 1-3: SE-Res2Net blocks + struct ggml_tensor * layer1_tdnn1_w = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.1.tdnn1.conv.conv.weight"); + struct ggml_tensor * layer1_tdnn1_b = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.1.tdnn1.conv.conv.bias"); + + // Res2Net branches (7 branches, each with dilation 1,2,3,4,1,2,3) + struct ggml_tensor * layer1_res2net[7] = {NULL}; + struct ggml_tensor * layer1_res2net_b[7] = {NULL}; + const char * res2net_names[] = { + "mods.embedding_model.blocks.1.res2net_block.blocks.0.conv.conv.weight", + "mods.embedding_model.blocks.1.res2net_block.blocks.1.conv.conv.weight", + 
"mods.embedding_model.blocks.1.res2net_block.blocks.2.conv.conv.weight", + "mods.embedding_model.blocks.1.res2net_block.blocks.3.conv.conv.weight", + "mods.embedding_model.blocks.1.res2net_block.blocks.4.conv.conv.weight", + "mods.embedding_model.blocks.1.res2net_block.blocks.5.conv.conv.weight", + "mods.embedding_model.blocks.1.res2net_block.blocks.6.conv.conv.weight", + }; + const char * res2net_bias_names[] = { + "mods.embedding_model.blocks.1.res2net_block.blocks.0.conv.conv.bias", + "mods.embedding_model.blocks.1.res2net_block.blocks.1.conv.conv.bias", + "mods.embedding_model.blocks.1.res2net_block.blocks.2.conv.conv.bias", + "mods.embedding_model.blocks.1.res2net_block.blocks.3.conv.conv.bias", + "mods.embedding_model.blocks.1.res2net_block.blocks.4.conv.conv.bias", + "mods.embedding_model.blocks.1.res2net_block.blocks.5.conv.conv.bias", + "mods.embedding_model.blocks.1.res2net_block.blocks.6.conv.conv.bias", + }; + + for (int i = 0; i < 7; i++) { + layer1_res2net[i] = whisper_speaker_find_tensor(encoder->model, res2net_names[i]); + layer1_res2net_b[i] = whisper_speaker_find_tensor(encoder->model, res2net_bias_names[i]); + if (!layer1_res2net[i] || !layer1_res2net_b[i]) { + WHISPER_LOG_ERROR("Layer 1: Failed to load Res2Net branch %d\n", i); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + } + + // SE block weights + struct ggml_tensor * layer1_se_fc1_w = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.1.se_block.conv1.conv.weight"); + struct ggml_tensor * layer1_se_fc1_b = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.1.se_block.conv1.conv.bias"); + struct ggml_tensor * layer1_se_fc2_w = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.1.se_block.conv2.conv.weight"); + struct ggml_tensor * layer1_se_fc2_b = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.1.se_block.conv2.conv.bias"); + + // TDNN2 + struct ggml_tensor * layer1_tdnn2_w = 
whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.1.tdnn2.conv.conv.weight"); + struct ggml_tensor * layer1_tdnn2_b = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.1.tdnn2.conv.conv.bias"); + + if (!layer1_tdnn1_w || !layer1_tdnn1_b || !layer1_se_fc1_w || !layer1_se_fc1_b || + !layer1_se_fc2_w || !layer1_se_fc2_b || !layer1_tdnn2_w || !layer1_tdnn2_b) { + WHISPER_LOG_ERROR("Layer 1: Failed to load SE block weights\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + + // Current input + struct ggml_tensor * layer1_input = cur; // [T, 1024] + + // TDNN1: [T, 1024] → [T, 1024] + struct ggml_tensor * layer1_tdnn1_w_ggml = ggml_conv_weight_f32_to_f16(encoder->ctx, layer1_tdnn1_w); + if (!layer1_tdnn1_w_ggml) { + WHISPER_LOG_ERROR("Layer 1: Failed to convert TDNN1 weight layout\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + struct ggml_tensor * layer1_tdnn1_out = ggml_conv_1d(encoder->ctx, layer1_tdnn1_w_ggml, layer1_input, 1, 0, 1); + layer1_tdnn1_out = ggml_reshape_4d(encoder->ctx, layer1_tdnn1_out, layer1_tdnn1_out->ne[0], layer1_tdnn1_out->ne[1], layer1_tdnn1_out->ne[2], 1); + struct ggml_tensor * layer1_tdnn1_b_reshaped = ggml_reshape_3d(encoder->ctx, layer1_tdnn1_b, 1, 1024, 1); + layer1_tdnn1_out = ggml_add(encoder->ctx, layer1_tdnn1_out, layer1_tdnn1_b_reshaped); + layer1_tdnn1_out = ggml_relu(encoder->ctx, layer1_tdnn1_out); + + // Layer 1 TDNN1: Runtime BatchNorm + { + struct ggml_tensor * bn1_tdnn1_mean = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.1.tdnn1.norm.norm.running_mean"); + struct ggml_tensor * bn1_tdnn1_var = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.1.tdnn1.norm.norm.running_var"); + struct ggml_tensor * bn1_tdnn1_gamma = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.1.tdnn1.norm.norm.weight"); + struct ggml_tensor * bn1_tdnn1_beta = 
whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.1.tdnn1.norm.norm.bias"); + + if (bn1_tdnn1_mean && bn1_tdnn1_var && bn1_tdnn1_gamma && bn1_tdnn1_beta) { + int32_t bn1_channels = layer1_tdnn1_out->ne[1]; + struct ggml_tensor * bn1_scale = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, bn1_channels); + struct ggml_tensor * bn1_offset = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, bn1_channels); + + precompute_bn_params(bn1_tdnn1_mean, bn1_tdnn1_var, bn1_tdnn1_gamma, bn1_tdnn1_beta, bn1_scale, bn1_offset); + + layer1_tdnn1_out = apply_runtime_bn(encoder->ctx, layer1_tdnn1_out, bn1_scale, bn1_offset); + } + } + + // Res2Net: Split [T, 1024] into 8 groups of [T, 128], apply 7 dilated convs + 1 identity, concatenate back + // In ggml column-major: ne[0]=T varies fastest, so channel stride = nb[1] = T * elem_size + struct ggml_tensor * layer1_res2net_splits[8]; + const int32_t group_channels = 128; + + size_t chan_stride = layer1_tdnn1_out->nb[1]; // bytes per channel = T * elem_size + + for (int g = 0; g < 8; g++) { + size_t offset = g * group_channels * chan_stride; + layer1_res2net_splits[g] = ggml_view_2d(encoder->ctx, + layer1_tdnn1_out, + layer1_tdnn1_out->ne[0], // T frames + group_channels, // 128 channels + chan_stride, // stride per channel (T * elem_size) + offset); + ggml_set_name(layer1_res2net_splits[g], "res2net_split"); + } + + // Res2Net: chunk[0]=identity, chunk[i]=conv(chunk[i]+y[i-1]) for i>=2 + // All 7 blocks use dilation=2 for Layer 1 + struct ggml_tensor * layer1_res2net_branches[8]; + + // Chunk 0: identity + layer1_res2net_branches[0] = layer1_res2net_splits[0]; + + // Chunks 1-7: apply blocks[0..6] + for (int i = 1; i < 8; i++) { + int b = i - 1; // block index + + struct ggml_tensor * branch_w_ggml = ggml_conv_weight_f32_to_f16( + encoder->ctx, layer1_res2net[b]); + if (!branch_w_ggml) { + WHISPER_LOG_ERROR("Layer 1: Failed to convert Res2Net branch %d weight\n", b); + ggml_free(encoder->ctx); + free(encoder); + 
return NULL; + } + + // Cumulative: input = chunk[i] + y[i-1] for i >= 2 + struct ggml_tensor * conv_input = layer1_res2net_splits[i]; + if (i >= 2) { + conv_input = ggml_add(encoder->ctx, conv_input, layer1_res2net_branches[i - 1]); + } + + // Conv1d(128→128, k=3, dilation=2, padding=2) + struct ggml_tensor * branch_conv = ggml_conv_1d(encoder->ctx, + branch_w_ggml, conv_input, 1, 2, 2); + + branch_conv = ensure_4d_from_conv1d(encoder->ctx, branch_conv); + + struct ggml_tensor * branch_b_reshaped = ggml_reshape_3d(encoder->ctx, + layer1_res2net_b[b], 1, 128, 1); + branch_conv = ggml_add(encoder->ctx, branch_conv, branch_b_reshaped); + branch_conv = ggml_relu(encoder->ctx, branch_conv); + + // Runtime BatchNorm + { + char bn_name[256]; + snprintf(bn_name, sizeof(bn_name), "mods.embedding_model.blocks.1.res2net_block.blocks.%d.norm.norm.running_mean", b); + struct ggml_tensor * bn_mean = whisper_speaker_find_tensor(encoder->model, bn_name); + snprintf(bn_name, sizeof(bn_name), "mods.embedding_model.blocks.1.res2net_block.blocks.%d.norm.norm.running_var", b); + struct ggml_tensor * bn_var = whisper_speaker_find_tensor(encoder->model, bn_name); + snprintf(bn_name, sizeof(bn_name), "mods.embedding_model.blocks.1.res2net_block.blocks.%d.norm.norm.weight", b); + struct ggml_tensor * bn_gamma = whisper_speaker_find_tensor(encoder->model, bn_name); + snprintf(bn_name, sizeof(bn_name), "mods.embedding_model.blocks.1.res2net_block.blocks.%d.norm.norm.bias", b); + struct ggml_tensor * bn_beta = whisper_speaker_find_tensor(encoder->model, bn_name); + + if (bn_mean && bn_var && bn_gamma && bn_beta) { + struct ggml_tensor * bn_scale = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, 128); + struct ggml_tensor * bn_offset = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, 128); + precompute_bn_params(bn_mean, bn_var, bn_gamma, bn_beta, bn_scale, bn_offset); + branch_conv = apply_runtime_bn(encoder->ctx, branch_conv, bn_scale, bn_offset); + } + } + + layer1_res2net_branches[i] = 
ggml_reshape_2d(encoder->ctx, branch_conv, + branch_conv->ne[0], branch_conv->ne[1]); + } + + // Step 3: Concatenate 8 branches back to [T, 1024] + struct ggml_tensor * res2net_concat = layer1_res2net_branches[0]; + for (int g = 1; g < 8; g++) { + res2net_concat = ggml_concat(encoder->ctx, res2net_concat, + layer1_res2net_branches[g], 1); + if (!res2net_concat) { + WHISPER_LOG_ERROR("Layer 1: Failed to concatenate Res2Net branch %d\n", g); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + } + + if (res2net_concat->ne[1] != 1024) { + WHISPER_LOG_ERROR("Layer 1 Res2Net: ERROR - concat output is %lld, expected 1024\n", + res2net_concat->ne[1]); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + // TDNN2 + + // TDNN2: [T, 1024] → [T, 1024] + struct ggml_tensor * layer1_tdnn2_w_ggml = ggml_conv_weight_f32_to_f16(encoder->ctx, layer1_tdnn2_w); + if (!layer1_tdnn2_w_ggml) { + WHISPER_LOG_ERROR("Layer 1: Failed to convert TDNN2 weight layout\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + struct ggml_tensor * layer1_tdnn2_out = ggml_conv_1d(encoder->ctx, layer1_tdnn2_w_ggml, res2net_concat, 1, 0, 1); + layer1_tdnn2_out = ggml_reshape_4d(encoder->ctx, layer1_tdnn2_out, layer1_tdnn2_out->ne[0], layer1_tdnn2_out->ne[1], layer1_tdnn2_out->ne[2], 1); + struct ggml_tensor * layer1_tdnn2_b_reshaped = ggml_reshape_3d(encoder->ctx, layer1_tdnn2_b, 1, 1024, 1); + layer1_tdnn2_out = ggml_add(encoder->ctx, layer1_tdnn2_out, layer1_tdnn2_b_reshaped); + layer1_tdnn2_out = ggml_relu(encoder->ctx, layer1_tdnn2_out); + + // Layer 1 TDNN2: Runtime BatchNorm + { + struct ggml_tensor * bn1_tdnn2_mean = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.1.tdnn2.norm.norm.running_mean"); + struct ggml_tensor * bn1_tdnn2_var = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.1.tdnn2.norm.norm.running_var"); + struct ggml_tensor * bn1_tdnn2_gamma = whisper_speaker_find_tensor(encoder->model, 
"mods.embedding_model.blocks.1.tdnn2.norm.norm.weight"); + struct ggml_tensor * bn1_tdnn2_beta = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.1.tdnn2.norm.norm.bias"); + + if (bn1_tdnn2_mean && bn1_tdnn2_var && bn1_tdnn2_gamma && bn1_tdnn2_beta) { + int32_t bn1_channels = layer1_tdnn2_out->ne[1]; + struct ggml_tensor * bn1_scale = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, bn1_channels); + struct ggml_tensor * bn1_offset = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, bn1_channels); + + precompute_bn_params(bn1_tdnn2_mean, bn1_tdnn2_var, bn1_tdnn2_gamma, bn1_tdnn2_beta, bn1_scale, bn1_offset); + + layer1_tdnn2_out = apply_runtime_bn(encoder->ctx, layer1_tdnn2_out, bn1_scale, bn1_offset); + } + } + + // SE Block: GlobalAvgPool → FC1 → ReLU → FC2 → Sigmoid → Scale + + // Global average pooling: [T, 1024] → [1, 1024] + struct ggml_tensor * se_gap = ggml_pool_1d(encoder->ctx, + layer1_tdnn2_out, + GGML_OP_POOL_AVG, + (int)layer1_tdnn2_out->ne[0], // kernel = full seq length + (int)layer1_tdnn2_out->ne[0], // stride = full seq length + 0); + + // Reshape gap to 1D: [1024] + struct ggml_tensor * se_gap_1d = ggml_reshape_1d(encoder->ctx, se_gap, 1024); + + // FC1: [1024] → [128] with ReLU + struct ggml_tensor * se_fc1_w_ggml = ggml_conv_weight_f32_to_f16( + encoder->ctx, layer1_se_fc1_w); + if (!se_fc1_w_ggml) { + WHISPER_LOG_ERROR("Layer 1 SE: Failed to convert FC1 weight\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + struct ggml_tensor * se_fc1 = ggml_mul_mat(encoder->ctx, + ggml_reshape_2d(encoder->ctx, se_fc1_w_ggml, 1024, 128), + ggml_reshape_2d(encoder->ctx, se_gap_1d, 1024, 1)); + se_fc1 = ggml_reshape_1d(encoder->ctx, se_fc1, 128); + se_fc1 = ggml_add(encoder->ctx, se_fc1, layer1_se_fc1_b); + se_fc1 = ggml_relu(encoder->ctx, se_fc1); + + // FC2: [128] → [1024] with Sigmoid + struct ggml_tensor * se_fc2_w_ggml = ggml_conv_weight_f32_to_f16( + encoder->ctx, layer1_se_fc2_w); + if (!se_fc2_w_ggml) { + 
WHISPER_LOG_ERROR("Layer 1 SE: Failed to convert FC2 weight\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + struct ggml_tensor * se_fc2 = ggml_mul_mat(encoder->ctx, + ggml_reshape_2d(encoder->ctx, se_fc2_w_ggml, 128, 1024), + ggml_reshape_2d(encoder->ctx, se_fc1, 128, 1)); + se_fc2 = ggml_reshape_1d(encoder->ctx, se_fc2, 1024); + se_fc2 = ggml_add(encoder->ctx, se_fc2, layer1_se_fc2_b); + struct ggml_tensor * se_gates = ggml_sigmoid(encoder->ctx, se_fc2); + + // Scale: [T, 1024] × [1024] element-wise + struct ggml_tensor * se_gates_reshaped = ggml_reshape_3d(encoder->ctx, + se_gates, 1, 1024, 1); + + struct ggml_tensor * layer1_se_out = ggml_mul(encoder->ctx, + layer1_tdnn2_out, se_gates_reshaped); + + // Residual connection + cur = ggml_add(encoder->ctx, layer1_se_out, layer1_input); + + struct ggml_tensor * layer1_out = cur; // [n_frames, 1024] + + // Layer 2 + struct ggml_tensor * layer2_tdnn1_w = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.2.tdnn1.conv.conv.weight"); + struct ggml_tensor * layer2_tdnn1_b = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.2.tdnn1.conv.conv.bias"); + + // Res2Net branches (7 branches) + struct ggml_tensor * layer2_res2net[7] = {NULL}; + struct ggml_tensor * layer2_res2net_b[7] = {NULL}; + const char * res2net_names_2[] = { + "mods.embedding_model.blocks.2.res2net_block.blocks.0.conv.conv.weight", + "mods.embedding_model.blocks.2.res2net_block.blocks.1.conv.conv.weight", + "mods.embedding_model.blocks.2.res2net_block.blocks.2.conv.conv.weight", + "mods.embedding_model.blocks.2.res2net_block.blocks.3.conv.conv.weight", + "mods.embedding_model.blocks.2.res2net_block.blocks.4.conv.conv.weight", + "mods.embedding_model.blocks.2.res2net_block.blocks.5.conv.conv.weight", + "mods.embedding_model.blocks.2.res2net_block.blocks.6.conv.conv.weight", + }; + const char * res2net_bias_names_2[] = { + 
"mods.embedding_model.blocks.2.res2net_block.blocks.0.conv.conv.bias", + "mods.embedding_model.blocks.2.res2net_block.blocks.1.conv.conv.bias", + "mods.embedding_model.blocks.2.res2net_block.blocks.2.conv.conv.bias", + "mods.embedding_model.blocks.2.res2net_block.blocks.3.conv.conv.bias", + "mods.embedding_model.blocks.2.res2net_block.blocks.4.conv.conv.bias", + "mods.embedding_model.blocks.2.res2net_block.blocks.5.conv.conv.bias", + "mods.embedding_model.blocks.2.res2net_block.blocks.6.conv.conv.bias", + }; + + for (int i = 0; i < 7; i++) { + layer2_res2net[i] = whisper_speaker_find_tensor(encoder->model, res2net_names_2[i]); + layer2_res2net_b[i] = whisper_speaker_find_tensor(encoder->model, res2net_bias_names_2[i]); + if (!layer2_res2net[i] || !layer2_res2net_b[i]) { + WHISPER_LOG_ERROR("Layer 2: Failed to load Res2Net branch %d\n", i); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + } + + // SE block weights + struct ggml_tensor * layer2_se_fc1_w = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.2.se_block.conv1.conv.weight"); + struct ggml_tensor * layer2_se_fc1_b = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.2.se_block.conv1.conv.bias"); + struct ggml_tensor * layer2_se_fc2_w = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.2.se_block.conv2.conv.weight"); + struct ggml_tensor * layer2_se_fc2_b = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.2.se_block.conv2.conv.bias"); + + // TDNN2 + struct ggml_tensor * layer2_tdnn2_w = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.2.tdnn2.conv.conv.weight"); + struct ggml_tensor * layer2_tdnn2_b = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.2.tdnn2.conv.conv.bias"); + + if (!layer2_tdnn1_w || !layer2_tdnn1_b || !layer2_se_fc1_w || !layer2_se_fc1_b || + !layer2_se_fc2_w || !layer2_se_fc2_b || !layer2_tdnn2_w || !layer2_tdnn2_b) { + 
WHISPER_LOG_ERROR("Layer 2: Failed to load SE block weights\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + // Current input + struct ggml_tensor * layer2_input = cur; // [T, 1024] + + // TDNN1: [T, 1024] → [T, 1024] + struct ggml_tensor * layer2_tdnn1_w_ggml = ggml_conv_weight_f32_to_f16(encoder->ctx, layer2_tdnn1_w); + if (!layer2_tdnn1_w_ggml) { + WHISPER_LOG_ERROR("Layer 2: Failed to convert TDNN1 weight layout\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + struct ggml_tensor * layer2_tdnn1_out = ggml_conv_1d(encoder->ctx, layer2_tdnn1_w_ggml, layer2_input, 1, 0, 1); + layer2_tdnn1_out = ggml_reshape_4d(encoder->ctx, layer2_tdnn1_out, layer2_tdnn1_out->ne[0], layer2_tdnn1_out->ne[1], layer2_tdnn1_out->ne[2], 1); + struct ggml_tensor * layer2_tdnn1_b_reshaped = ggml_reshape_3d(encoder->ctx, layer2_tdnn1_b, 1, 1024, 1); + layer2_tdnn1_out = ggml_add(encoder->ctx, layer2_tdnn1_out, layer2_tdnn1_b_reshaped); + layer2_tdnn1_out = ggml_relu(encoder->ctx, layer2_tdnn1_out); + + // Layer 2 TDNN1: Runtime BatchNorm + { + struct ggml_tensor * bn2_tdnn1_mean = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.2.tdnn1.norm.norm.running_mean"); + struct ggml_tensor * bn2_tdnn1_var = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.2.tdnn1.norm.norm.running_var"); + struct ggml_tensor * bn2_tdnn1_gamma = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.2.tdnn1.norm.norm.weight"); + struct ggml_tensor * bn2_tdnn1_beta = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.2.tdnn1.norm.norm.bias"); + + if (bn2_tdnn1_mean && bn2_tdnn1_var && bn2_tdnn1_gamma && bn2_tdnn1_beta) { + int32_t bn2_channels = layer2_tdnn1_out->ne[1]; + struct ggml_tensor * bn2_scale = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, bn2_channels); + struct ggml_tensor * bn2_offset = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, bn2_channels); + + 
precompute_bn_params(bn2_tdnn1_mean, bn2_tdnn1_var, bn2_tdnn1_gamma, bn2_tdnn1_beta, bn2_scale, bn2_offset); + + layer2_tdnn1_out = apply_runtime_bn(encoder->ctx, layer2_tdnn1_out, bn2_scale, bn2_offset); + } + } + + // Res2Net: Split [T, 1024] into 8 groups of [T, 128] + struct ggml_tensor * layer2_res2net_splits[8]; + size_t chan_stride_l2 = layer2_tdnn1_out->nb[1]; + + for (int g = 0; g < 8; g++) { + size_t offset = g * group_channels * chan_stride_l2; + layer2_res2net_splits[g] = ggml_view_2d(encoder->ctx, + layer2_tdnn1_out, + layer2_tdnn1_out->ne[0], + group_channels, + chan_stride_l2, + offset); + ggml_set_name(layer2_res2net_splits[g], "layer2_res2net_split"); + } + + // Res2Net: chunk[0]=identity, chunk[i]=conv(chunk[i]+y[i-1]) for i>=2 + // All 7 blocks use dilation=3 for Layer 2 + struct ggml_tensor * layer2_res2net_branches[8]; + + // Chunk 0: identity + layer2_res2net_branches[0] = layer2_res2net_splits[0]; + + // Chunks 1-7: apply blocks[0..6] + for (int i = 1; i < 8; i++) { + int b = i - 1; // block index + + struct ggml_tensor * branch_w_ggml = ggml_conv_weight_f32_to_f16( + encoder->ctx, layer2_res2net[b]); + if (!branch_w_ggml) { + WHISPER_LOG_ERROR("Layer 2: Failed to convert Res2Net branch %d weight\n", b); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + // Cumulative: input = chunk[i] + y[i-1] for i >= 2 + struct ggml_tensor * conv_input_l2 = layer2_res2net_splits[i]; + if (i >= 2) { + conv_input_l2 = ggml_add(encoder->ctx, conv_input_l2, layer2_res2net_branches[i - 1]); + } + + // Conv1d(128→128, k=3, dilation=3, padding=3) + struct ggml_tensor * branch_conv = ggml_conv_1d(encoder->ctx, + branch_w_ggml, conv_input_l2, 1, 3, 3); + + branch_conv = ensure_4d_from_conv1d(encoder->ctx, branch_conv); + + struct ggml_tensor * branch_b_reshaped = ggml_reshape_3d(encoder->ctx, + layer2_res2net_b[b], 1, 128, 1); + branch_conv = ggml_add(encoder->ctx, branch_conv, branch_b_reshaped); + branch_conv = ggml_relu(encoder->ctx, 
branch_conv); + + // Layer 2 Res2Net branch runtime BatchNorm + { + char bn_name[256]; + snprintf(bn_name, sizeof(bn_name), "mods.embedding_model.blocks.2.res2net_block.blocks.%d.norm.norm.running_mean", b); + struct ggml_tensor * bn_mean = whisper_speaker_find_tensor(encoder->model, bn_name); + + snprintf(bn_name, sizeof(bn_name), "mods.embedding_model.blocks.2.res2net_block.blocks.%d.norm.norm.running_var", b); + struct ggml_tensor * bn_var = whisper_speaker_find_tensor(encoder->model, bn_name); + + snprintf(bn_name, sizeof(bn_name), "mods.embedding_model.blocks.2.res2net_block.blocks.%d.norm.norm.weight", b); + struct ggml_tensor * bn_gamma = whisper_speaker_find_tensor(encoder->model, bn_name); + + snprintf(bn_name, sizeof(bn_name), "mods.embedding_model.blocks.2.res2net_block.blocks.%d.norm.norm.bias", b); + struct ggml_tensor * bn_beta = whisper_speaker_find_tensor(encoder->model, bn_name); + + if (bn_mean && bn_var && bn_gamma && bn_beta) { + struct ggml_tensor * bn_scale = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, 128); + struct ggml_tensor * bn_offset = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, 128); + + precompute_bn_params(bn_mean, bn_var, bn_gamma, bn_beta, bn_scale, bn_offset); + + branch_conv = apply_runtime_bn(encoder->ctx, branch_conv, bn_scale, bn_offset); + } + } + + layer2_res2net_branches[i] = ggml_reshape_2d(encoder->ctx, branch_conv, + branch_conv->ne[0], branch_conv->ne[1]); + } + + // Step 3: Concatenate 8 branches back to [T, 1024] + struct ggml_tensor * layer2_res2net_concat = layer2_res2net_branches[0]; + for (int g = 1; g < 8; g++) { + layer2_res2net_concat = ggml_concat(encoder->ctx, layer2_res2net_concat, + layer2_res2net_branches[g], 1); + if (!layer2_res2net_concat) { + WHISPER_LOG_ERROR("Layer 2: Failed to concatenate Res2Net branch %d\n", g); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + } + + if (layer2_res2net_concat->ne[1] != 1024) { + WHISPER_LOG_ERROR("Layer 2 Res2Net: ERROR - concat output is 
%lld, expected 1024\n", + layer2_res2net_concat->ne[1]); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + // TDNN2 + + // TDNN2: [T, 1024] → [T, 1024] + struct ggml_tensor * layer2_tdnn2_w_ggml = ggml_conv_weight_f32_to_f16(encoder->ctx, layer2_tdnn2_w); + if (!layer2_tdnn2_w_ggml) { + WHISPER_LOG_ERROR("Layer 2: Failed to convert TDNN2 weight layout\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + struct ggml_tensor * layer2_tdnn2_out = ggml_conv_1d(encoder->ctx, layer2_tdnn2_w_ggml, layer2_res2net_concat, 1, 0, 1); + layer2_tdnn2_out = ggml_reshape_4d(encoder->ctx, layer2_tdnn2_out, layer2_tdnn2_out->ne[0], layer2_tdnn2_out->ne[1], layer2_tdnn2_out->ne[2], 1); + struct ggml_tensor * layer2_tdnn2_b_reshaped = ggml_reshape_3d(encoder->ctx, layer2_tdnn2_b, 1, 1024, 1); + layer2_tdnn2_out = ggml_add(encoder->ctx, layer2_tdnn2_out, layer2_tdnn2_b_reshaped); + layer2_tdnn2_out = ggml_relu(encoder->ctx, layer2_tdnn2_out); + + // Layer 2 TDNN2: Runtime BatchNorm + { + struct ggml_tensor * bn2_tdnn2_mean = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.2.tdnn2.norm.norm.running_mean"); + struct ggml_tensor * bn2_tdnn2_var = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.2.tdnn2.norm.norm.running_var"); + struct ggml_tensor * bn2_tdnn2_gamma = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.2.tdnn2.norm.norm.weight"); + struct ggml_tensor * bn2_tdnn2_beta = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.2.tdnn2.norm.norm.bias"); + + if (bn2_tdnn2_mean && bn2_tdnn2_var && bn2_tdnn2_gamma && bn2_tdnn2_beta) { + int32_t bn2_channels = layer2_tdnn2_out->ne[1]; + struct ggml_tensor * bn2_scale = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, bn2_channels); + struct ggml_tensor * bn2_offset = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, bn2_channels); + + precompute_bn_params(bn2_tdnn2_mean, bn2_tdnn2_var, bn2_tdnn2_gamma, 
bn2_tdnn2_beta, bn2_scale, bn2_offset); + + layer2_tdnn2_out = apply_runtime_bn(encoder->ctx, layer2_tdnn2_out, bn2_scale, bn2_offset); + } + } + + // SE Block: GlobalAvgPool → FC1 → ReLU → FC2 → Sigmoid → Scale + + struct ggml_tensor * layer2_se_gap = ggml_pool_1d(encoder->ctx, + layer2_tdnn2_out, + GGML_OP_POOL_AVG, + (int)layer2_tdnn2_out->ne[0], + (int)layer2_tdnn2_out->ne[0], + 0); + + struct ggml_tensor * layer2_se_gap_1d = ggml_reshape_1d(encoder->ctx, layer2_se_gap, 1024); + + struct ggml_tensor * layer2_se_fc1_w_ggml = ggml_conv_weight_f32_to_f16( + encoder->ctx, layer2_se_fc1_w); + if (!layer2_se_fc1_w_ggml) { + WHISPER_LOG_ERROR("Layer 2 SE: Failed to convert FC1 weight\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + struct ggml_tensor * layer2_se_fc1 = ggml_mul_mat(encoder->ctx, + ggml_reshape_2d(encoder->ctx, layer2_se_fc1_w_ggml, 1024, 128), + ggml_reshape_2d(encoder->ctx, layer2_se_gap_1d, 1024, 1)); + layer2_se_fc1 = ggml_reshape_1d(encoder->ctx, layer2_se_fc1, 128); + layer2_se_fc1 = ggml_add(encoder->ctx, layer2_se_fc1, layer2_se_fc1_b); + layer2_se_fc1 = ggml_relu(encoder->ctx, layer2_se_fc1); + + struct ggml_tensor * layer2_se_fc2_w_ggml = ggml_conv_weight_f32_to_f16( + encoder->ctx, layer2_se_fc2_w); + if (!layer2_se_fc2_w_ggml) { + WHISPER_LOG_ERROR("Layer 2 SE: Failed to convert FC2 weight\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + struct ggml_tensor * layer2_se_fc2 = ggml_mul_mat(encoder->ctx, + ggml_reshape_2d(encoder->ctx, layer2_se_fc2_w_ggml, 128, 1024), + ggml_reshape_2d(encoder->ctx, layer2_se_fc1, 128, 1)); + layer2_se_fc2 = ggml_reshape_1d(encoder->ctx, layer2_se_fc2, 1024); + layer2_se_fc2 = ggml_add(encoder->ctx, layer2_se_fc2, layer2_se_fc2_b); + struct ggml_tensor * layer2_se_gates = ggml_sigmoid(encoder->ctx, layer2_se_fc2); + + struct ggml_tensor * layer2_se_gates_reshaped = ggml_reshape_3d(encoder->ctx, + layer2_se_gates, 1, 1024, 1); + + struct ggml_tensor * layer2_se_out = 
ggml_mul(encoder->ctx, + layer2_tdnn2_out, layer2_se_gates_reshaped); + + // Residual connection + cur = ggml_add(encoder->ctx, layer2_se_out, layer2_input); + + struct ggml_tensor * layer2_out = cur; // [n_frames, 1024] + + // Layer 3 + struct ggml_tensor * layer3_tdnn1_w = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.3.tdnn1.conv.conv.weight"); + struct ggml_tensor * layer3_tdnn1_b = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.3.tdnn1.conv.conv.bias"); + + // Res2Net branches (7 branches) + struct ggml_tensor * layer3_res2net[7] = {NULL}; + struct ggml_tensor * layer3_res2net_b[7] = {NULL}; + const char * res2net_names_3[] = { + "mods.embedding_model.blocks.3.res2net_block.blocks.0.conv.conv.weight", + "mods.embedding_model.blocks.3.res2net_block.blocks.1.conv.conv.weight", + "mods.embedding_model.blocks.3.res2net_block.blocks.2.conv.conv.weight", + "mods.embedding_model.blocks.3.res2net_block.blocks.3.conv.conv.weight", + "mods.embedding_model.blocks.3.res2net_block.blocks.4.conv.conv.weight", + "mods.embedding_model.blocks.3.res2net_block.blocks.5.conv.conv.weight", + "mods.embedding_model.blocks.3.res2net_block.blocks.6.conv.conv.weight", + }; + const char * res2net_bias_names_3[] = { + "mods.embedding_model.blocks.3.res2net_block.blocks.0.conv.conv.bias", + "mods.embedding_model.blocks.3.res2net_block.blocks.1.conv.conv.bias", + "mods.embedding_model.blocks.3.res2net_block.blocks.2.conv.conv.bias", + "mods.embedding_model.blocks.3.res2net_block.blocks.3.conv.conv.bias", + "mods.embedding_model.blocks.3.res2net_block.blocks.4.conv.conv.bias", + "mods.embedding_model.blocks.3.res2net_block.blocks.5.conv.conv.bias", + "mods.embedding_model.blocks.3.res2net_block.blocks.6.conv.conv.bias", + }; + + for (int i = 0; i < 7; i++) { + layer3_res2net[i] = whisper_speaker_find_tensor(encoder->model, res2net_names_3[i]); + layer3_res2net_b[i] = whisper_speaker_find_tensor(encoder->model, res2net_bias_names_3[i]); + 
if (!layer3_res2net[i] || !layer3_res2net_b[i]) { + WHISPER_LOG_ERROR("Layer 3: Failed to load Res2Net branch %d\n", i); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + } + + // SE block weights + struct ggml_tensor * layer3_se_fc1_w = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.3.se_block.conv1.conv.weight"); + struct ggml_tensor * layer3_se_fc1_b = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.3.se_block.conv1.conv.bias"); + struct ggml_tensor * layer3_se_fc2_w = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.3.se_block.conv2.conv.weight"); + struct ggml_tensor * layer3_se_fc2_b = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.3.se_block.conv2.conv.bias"); + + // TDNN2 + struct ggml_tensor * layer3_tdnn2_w = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.3.tdnn2.conv.conv.weight"); + struct ggml_tensor * layer3_tdnn2_b = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.3.tdnn2.conv.conv.bias"); + + if (!layer3_tdnn1_w || !layer3_tdnn1_b || !layer3_se_fc1_w || !layer3_se_fc1_b || + !layer3_se_fc2_w || !layer3_se_fc2_b || !layer3_tdnn2_w || !layer3_tdnn2_b) { + WHISPER_LOG_ERROR("Layer 3: Failed to load SE block weights\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + // Current input + struct ggml_tensor * layer3_input = cur; // [T, 1024] + + // TDNN1: [T, 1024] → [T, 1024] + struct ggml_tensor * layer3_tdnn1_w_ggml = ggml_conv_weight_f32_to_f16(encoder->ctx, layer3_tdnn1_w); + if (!layer3_tdnn1_w_ggml) { + WHISPER_LOG_ERROR("Layer 3: Failed to convert TDNN1 weight layout\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + struct ggml_tensor * layer3_tdnn1_out = ggml_conv_1d(encoder->ctx, layer3_tdnn1_w_ggml, layer3_input, 1, 0, 1); + layer3_tdnn1_out = ggml_reshape_4d(encoder->ctx, layer3_tdnn1_out, layer3_tdnn1_out->ne[0], layer3_tdnn1_out->ne[1], 
layer3_tdnn1_out->ne[2], 1); + struct ggml_tensor * layer3_tdnn1_b_reshaped = ggml_reshape_3d(encoder->ctx, layer3_tdnn1_b, 1, 1024, 1); + layer3_tdnn1_out = ggml_add(encoder->ctx, layer3_tdnn1_out, layer3_tdnn1_b_reshaped); + layer3_tdnn1_out = ggml_relu(encoder->ctx, layer3_tdnn1_out); + + // Layer 3 TDNN1: Runtime BatchNorm + { + struct ggml_tensor * bn3_tdnn1_mean = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.3.tdnn1.norm.norm.running_mean"); + struct ggml_tensor * bn3_tdnn1_var = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.3.tdnn1.norm.norm.running_var"); + struct ggml_tensor * bn3_tdnn1_gamma = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.3.tdnn1.norm.norm.weight"); + struct ggml_tensor * bn3_tdnn1_beta = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.3.tdnn1.norm.norm.bias"); + + if (bn3_tdnn1_mean && bn3_tdnn1_var && bn3_tdnn1_gamma && bn3_tdnn1_beta) { + int32_t bn3_channels = layer3_tdnn1_out->ne[1]; + struct ggml_tensor * bn3_scale = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, bn3_channels); + struct ggml_tensor * bn3_offset = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, bn3_channels); + + precompute_bn_params(bn3_tdnn1_mean, bn3_tdnn1_var, bn3_tdnn1_gamma, bn3_tdnn1_beta, bn3_scale, bn3_offset); + + layer3_tdnn1_out = apply_runtime_bn(encoder->ctx, layer3_tdnn1_out, bn3_scale, bn3_offset); + } + } + + // Res2Net: Apply 7 branches with different dilations + // Res2Net: Split [T, 1024] into 8 groups of [T, 128] + struct ggml_tensor * layer3_res2net_splits[8]; + size_t chan_stride_l3 = layer3_tdnn1_out->nb[1]; + + for (int g = 0; g < 8; g++) { + size_t offset = g * group_channels * chan_stride_l3; + layer3_res2net_splits[g] = ggml_view_2d(encoder->ctx, + layer3_tdnn1_out, + layer3_tdnn1_out->ne[0], + group_channels, + chan_stride_l3, + offset); + ggml_set_name(layer3_res2net_splits[g], "layer3_res2net_split"); + } + + // Res2Net: 
chunk[0]=identity, chunk[i]=conv(chunk[i]+y[i-1]) for i>=2 + // All 7 blocks use dilation=4 for Layer 3 + struct ggml_tensor * layer3_res2net_branches[8]; + + // Chunk 0: identity + layer3_res2net_branches[0] = layer3_res2net_splits[0]; + + // Chunks 1-7: apply blocks[0..6] + for (int i = 1; i < 8; i++) { + int b = i - 1; // block index + + struct ggml_tensor * branch_w_ggml = ggml_conv_weight_f32_to_f16( + encoder->ctx, layer3_res2net[b]); + if (!branch_w_ggml) { + WHISPER_LOG_ERROR("Layer 3: Failed to convert Res2Net branch %d weight\n", b); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + // Cumulative: input = chunk[i] + y[i-1] for i >= 2 + struct ggml_tensor * conv_input_l3 = layer3_res2net_splits[i]; + if (i >= 2) { + conv_input_l3 = ggml_add(encoder->ctx, conv_input_l3, layer3_res2net_branches[i - 1]); + } + + // Conv1d(128→128, k=3, dilation=4, padding=4) + struct ggml_tensor * branch_conv = ggml_conv_1d(encoder->ctx, + branch_w_ggml, conv_input_l3, 1, 4, 4); + + branch_conv = ensure_4d_from_conv1d(encoder->ctx, branch_conv); + + struct ggml_tensor * branch_b_reshaped = ggml_reshape_3d(encoder->ctx, + layer3_res2net_b[b], 1, 128, 1); + branch_conv = ggml_add(encoder->ctx, branch_conv, branch_b_reshaped); + branch_conv = ggml_relu(encoder->ctx, branch_conv); + + // Layer 3 Res2Net branch runtime BatchNorm + { + char bn_name[256]; + snprintf(bn_name, sizeof(bn_name), "mods.embedding_model.blocks.3.res2net_block.blocks.%d.norm.norm.running_mean", b); + struct ggml_tensor * bn_mean = whisper_speaker_find_tensor(encoder->model, bn_name); + + snprintf(bn_name, sizeof(bn_name), "mods.embedding_model.blocks.3.res2net_block.blocks.%d.norm.norm.running_var", b); + struct ggml_tensor * bn_var = whisper_speaker_find_tensor(encoder->model, bn_name); + + snprintf(bn_name, sizeof(bn_name), "mods.embedding_model.blocks.3.res2net_block.blocks.%d.norm.norm.weight", b); + struct ggml_tensor * bn_gamma = whisper_speaker_find_tensor(encoder->model, bn_name); + 
+ snprintf(bn_name, sizeof(bn_name), "mods.embedding_model.blocks.3.res2net_block.blocks.%d.norm.norm.bias", b); + struct ggml_tensor * bn_beta = whisper_speaker_find_tensor(encoder->model, bn_name); + + if (bn_mean && bn_var && bn_gamma && bn_beta) { + struct ggml_tensor * bn_scale = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, 128); + struct ggml_tensor * bn_offset = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, 128); + + precompute_bn_params(bn_mean, bn_var, bn_gamma, bn_beta, bn_scale, bn_offset); + + branch_conv = apply_runtime_bn(encoder->ctx, branch_conv, bn_scale, bn_offset); + } + } + + layer3_res2net_branches[i] = ggml_reshape_2d(encoder->ctx, branch_conv, + branch_conv->ne[0], branch_conv->ne[1]); + } + + // Step 3: Concatenate 8 branches back to [T, 1024] + struct ggml_tensor * layer3_res2net_concat = layer3_res2net_branches[0]; + for (int g = 1; g < 8; g++) { + layer3_res2net_concat = ggml_concat(encoder->ctx, layer3_res2net_concat, + layer3_res2net_branches[g], 1); + if (!layer3_res2net_concat) { + WHISPER_LOG_ERROR("Layer 3: Failed to concatenate Res2Net branch %d\n", g); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + } + + if (layer3_res2net_concat->ne[1] != 1024) { + WHISPER_LOG_ERROR("Layer 3 Res2Net: ERROR - concat output is %lld, expected 1024\n", + layer3_res2net_concat->ne[1]); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + // TDNN2 + + // TDNN2: [T, 1024] → [T, 1024] + struct ggml_tensor * layer3_tdnn2_w_ggml = ggml_conv_weight_f32_to_f16(encoder->ctx, layer3_tdnn2_w); + if (!layer3_tdnn2_w_ggml) { + WHISPER_LOG_ERROR("Layer 3: Failed to convert TDNN2 weight layout\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + struct ggml_tensor * layer3_tdnn2_out = ggml_conv_1d(encoder->ctx, layer3_tdnn2_w_ggml, layer3_res2net_concat, 1, 0, 1); + layer3_tdnn2_out = ggml_reshape_4d(encoder->ctx, layer3_tdnn2_out, layer3_tdnn2_out->ne[0], layer3_tdnn2_out->ne[1], layer3_tdnn2_out->ne[2], 1); + 
struct ggml_tensor * layer3_tdnn2_b_reshaped = ggml_reshape_3d(encoder->ctx, layer3_tdnn2_b, 1, 1024, 1); + layer3_tdnn2_out = ggml_add(encoder->ctx, layer3_tdnn2_out, layer3_tdnn2_b_reshaped); + layer3_tdnn2_out = ggml_relu(encoder->ctx, layer3_tdnn2_out); + + // Layer 3 TDNN2: Runtime BatchNorm + { + struct ggml_tensor * bn3_tdnn2_mean = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.3.tdnn2.norm.norm.running_mean"); + struct ggml_tensor * bn3_tdnn2_var = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.3.tdnn2.norm.norm.running_var"); + struct ggml_tensor * bn3_tdnn2_gamma = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.3.tdnn2.norm.norm.weight"); + struct ggml_tensor * bn3_tdnn2_beta = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.blocks.3.tdnn2.norm.norm.bias"); + + if (bn3_tdnn2_mean && bn3_tdnn2_var && bn3_tdnn2_gamma && bn3_tdnn2_beta) { + int32_t bn3_channels = layer3_tdnn2_out->ne[1]; + struct ggml_tensor * bn3_scale = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, bn3_channels); + struct ggml_tensor * bn3_offset = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, bn3_channels); + + precompute_bn_params(bn3_tdnn2_mean, bn3_tdnn2_var, bn3_tdnn2_gamma, bn3_tdnn2_beta, bn3_scale, bn3_offset); + + layer3_tdnn2_out = apply_runtime_bn(encoder->ctx, layer3_tdnn2_out, bn3_scale, bn3_offset); + } + } + + // SE Block: GlobalAvgPool → FC1 → ReLU → FC2 → Sigmoid → Scale + + struct ggml_tensor * layer3_se_gap = ggml_pool_1d(encoder->ctx, + layer3_tdnn2_out, + GGML_OP_POOL_AVG, + (int)layer3_tdnn2_out->ne[0], + (int)layer3_tdnn2_out->ne[0], + 0); + + struct ggml_tensor * layer3_se_gap_1d = ggml_reshape_1d(encoder->ctx, layer3_se_gap, 1024); + + struct ggml_tensor * layer3_se_fc1_w_ggml = ggml_conv_weight_f32_to_f16( + encoder->ctx, layer3_se_fc1_w); + if (!layer3_se_fc1_w_ggml) { + WHISPER_LOG_ERROR("Layer 3 SE: Failed to convert FC1 weight\n"); + ggml_free(encoder->ctx); + 
free(encoder); + return NULL; + } + + struct ggml_tensor * layer3_se_fc1 = ggml_mul_mat(encoder->ctx, + ggml_reshape_2d(encoder->ctx, layer3_se_fc1_w_ggml, 1024, 128), + ggml_reshape_2d(encoder->ctx, layer3_se_gap_1d, 1024, 1)); + layer3_se_fc1 = ggml_reshape_1d(encoder->ctx, layer3_se_fc1, 128); + layer3_se_fc1 = ggml_add(encoder->ctx, layer3_se_fc1, layer3_se_fc1_b); + layer3_se_fc1 = ggml_relu(encoder->ctx, layer3_se_fc1); + + struct ggml_tensor * layer3_se_fc2_w_ggml = ggml_conv_weight_f32_to_f16( + encoder->ctx, layer3_se_fc2_w); + if (!layer3_se_fc2_w_ggml) { + WHISPER_LOG_ERROR("Layer 3 SE: Failed to convert FC2 weight\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + struct ggml_tensor * layer3_se_fc2 = ggml_mul_mat(encoder->ctx, + ggml_reshape_2d(encoder->ctx, layer3_se_fc2_w_ggml, 128, 1024), + ggml_reshape_2d(encoder->ctx, layer3_se_fc1, 128, 1)); + layer3_se_fc2 = ggml_reshape_1d(encoder->ctx, layer3_se_fc2, 1024); + layer3_se_fc2 = ggml_add(encoder->ctx, layer3_se_fc2, layer3_se_fc2_b); + struct ggml_tensor * layer3_se_gates = ggml_sigmoid(encoder->ctx, layer3_se_fc2); + + struct ggml_tensor * layer3_se_gates_reshaped = ggml_reshape_3d(encoder->ctx, + layer3_se_gates, 1, 1024, 1); + + struct ggml_tensor * layer3_se_out = ggml_mul(encoder->ctx, + layer3_tdnn2_out, layer3_se_gates_reshaped); + + // Residual connection + cur = ggml_add(encoder->ctx, layer3_se_out, layer3_input); + + struct ggml_tensor * layer3_out = cur; // [n_frames, 1024] + + // Layer 4: Multi-layer Feature Aggregation (MFA) + struct ggml_tensor * mfa_w = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.mfa.conv.conv.weight"); + struct ggml_tensor * mfa_b = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.mfa.conv.conv.bias"); + + if (!mfa_w || !mfa_b) { + WHISPER_LOG_ERROR("Layer 4 (MFA): Failed to load weights\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + // Concatenate layers 1-3: [T, 1024] × 3 → [T, 
3072] + + struct ggml_tensor * mfa_input = ggml_concat(encoder->ctx, layer1_out, layer2_out, 1); + if (!mfa_input) { + WHISPER_LOG_ERROR("Layer 4 (MFA): Failed to concatenate layers 1-2\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + mfa_input = ggml_concat(encoder->ctx, mfa_input, layer3_out, 1); + if (!mfa_input) { + WHISPER_LOG_ERROR("Layer 4 (MFA): Failed to concatenate layers 1-3\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + // Verify dimension + if (mfa_input->ne[1] != 3072) { + WHISPER_LOG_WARN("Layer 4 (MFA): WARNING - input dimension is %lld, expected 3072\n", mfa_input->ne[1]); + } + + // Apply MFA Conv1d(3072→3072, k=1, padding=0) + BN + struct ggml_tensor * mfa_w_ggml = ggml_conv_weight_f32_to_f16(encoder->ctx, mfa_w); + if (!mfa_w_ggml) { + WHISPER_LOG_ERROR("Layer 4 (MFA): Failed to convert weight layout\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + // MFA conv: [T, 3072] @ [1, 3072, 3072] → [T, 3072] + struct ggml_tensor * mfa_conv_out = ggml_conv_1d(encoder->ctx, mfa_w_ggml, mfa_input, 1, 0, 1); + mfa_conv_out = ggml_reshape_4d(encoder->ctx, mfa_conv_out, mfa_conv_out->ne[0], mfa_conv_out->ne[1], mfa_conv_out->ne[2], 1); + + // Add bias + struct ggml_tensor * mfa_b_reshaped = ggml_reshape_3d(encoder->ctx, mfa_b, 1, 3072, 1); + mfa_conv_out = ggml_add(encoder->ctx, mfa_conv_out, mfa_b_reshaped); + + // ReLU + mfa_conv_out = ggml_relu(encoder->ctx, mfa_conv_out); + + // Runtime BN for MFA + struct ggml_tensor * bn_mfa_mean = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.mfa.norm.norm.running_mean"); + struct ggml_tensor * bn_mfa_var = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.mfa.norm.norm.running_var"); + struct ggml_tensor * bn_mfa_gamma = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.mfa.norm.norm.weight"); + struct ggml_tensor * bn_mfa_beta = whisper_speaker_find_tensor(encoder->model, 
"mods.embedding_model.mfa.norm.norm.bias"); + + if (bn_mfa_mean && bn_mfa_var && bn_mfa_gamma && bn_mfa_beta) { + struct ggml_tensor * bn_mfa_scale = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, 3072); + struct ggml_tensor * bn_mfa_offset = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, 3072); + + precompute_bn_params(bn_mfa_mean, bn_mfa_var, bn_mfa_gamma, bn_mfa_beta, bn_mfa_scale, bn_mfa_offset); + + mfa_conv_out = apply_runtime_bn(encoder->ctx, mfa_conv_out, bn_mfa_scale, bn_mfa_offset); + } else { + WHISPER_LOG_WARN("Layer 4 (MFA): Missing BN tensors, skipping BN\n"); + } + + struct ggml_tensor * mfa_out = mfa_conv_out; + + // Final check + if (mfa_out->ne[1] != 3072) { + WHISPER_LOG_ERROR("Layer 4 (MFA): ERROR - output dimension is %lld, expected 3072! Aborting.\n", mfa_out->ne[1]); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + cur = mfa_out; // [T, 3072] + + // Layer 5: Attentive Statistical Pooling (ASP) + + struct ggml_tensor * x = cur; + int32_t n_features = x->ne[1]; + if (n_features != 3072) { + WHISPER_LOG_ERROR("Layer 5: input dimension is %d, expected 3072\n", n_features); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + int32_t T = x->ne[0]; + int32_t C = x->ne[1]; + + // Global statistics for attention input + struct ggml_tensor * global_mean = ggml_pool_1d(encoder->ctx, x, GGML_OP_POOL_AVG, T, T, 0); + + struct ggml_tensor * x_minus_mean = ggml_sub(encoder->ctx, x, global_mean); + struct ggml_tensor * sq_dev = ggml_mul(encoder->ctx, x_minus_mean, x_minus_mean); + struct ggml_tensor * var = ggml_pool_1d(encoder->ctx, sq_dev, GGML_OP_POOL_AVG, T, T, 0); + struct ggml_tensor * global_std = ggml_sqrt(encoder->ctx, var); + + // Repeat stats for concatenation + struct ggml_tensor * mean_repeated = ggml_repeat(encoder->ctx, global_mean, x); + struct ggml_tensor * std_repeated = ggml_repeat(encoder->ctx, global_std, x); + + struct ggml_tensor * att_input = ggml_concat(encoder->ctx, x, mean_repeated, 1); + 
att_input = ggml_concat(encoder->ctx, att_input, std_repeated, 1); + + // TDNN projection → attention weights + struct ggml_tensor * tdnn_w = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.asp.tdnn.conv.conv.weight"); + struct ggml_tensor * tdnn_b = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.asp.tdnn.conv.conv.bias"); + + if (!tdnn_w || !tdnn_b) { + WHISPER_LOG_ERROR("Layer 5 ASP: Missing TDNN tensors\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + struct ggml_tensor * tdnn_w_ggml = ggml_conv_weight_f32_to_f16(encoder->ctx, tdnn_w); + if (!tdnn_w_ggml) { + WHISPER_LOG_ERROR("Layer 5 ASP: Failed to convert TDNN weight layout\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + struct ggml_tensor * tdnn_conv = ggml_conv_1d(encoder->ctx, tdnn_w_ggml, att_input, 1, 0, 1); + tdnn_conv = ensure_4d_from_conv1d(encoder->ctx, tdnn_conv); + struct ggml_tensor * tdnn_b_reshaped = ggml_reshape_3d(encoder->ctx, tdnn_b, 1, 128, 1); + tdnn_conv = ggml_add(encoder->ctx, tdnn_conv, tdnn_b_reshaped); + tdnn_conv = ggml_relu(encoder->ctx, tdnn_conv); + + { + struct ggml_tensor * bn_mean = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.asp.tdnn.norm.norm.running_mean"); + struct ggml_tensor * bn_var = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.asp.tdnn.norm.norm.running_var"); + struct ggml_tensor * bn_g = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.asp.tdnn.norm.norm.weight"); + struct ggml_tensor * bn_b = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.asp.tdnn.norm.norm.bias"); + if (bn_mean && bn_var && bn_g && bn_b) { + struct ggml_tensor * s = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, 128); + struct ggml_tensor * o = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, 128); + precompute_bn_params(bn_mean, bn_var, bn_g, bn_b, s, o); + tdnn_conv = apply_runtime_bn(encoder->ctx, tdnn_conv, s, o); + } + } + + // 
Attention conv + struct ggml_tensor * att_conv_w = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.asp.conv.conv.weight"); + struct ggml_tensor * att_conv_b = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.asp.conv.conv.bias"); + if (!att_conv_w || !att_conv_b) { + WHISPER_LOG_ERROR("Layer 5 ASP: Missing attention conv tensors\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + struct ggml_tensor * att_conv_w_f16 = ggml_conv_weight_f32_to_f16(encoder->ctx, att_conv_w); + struct ggml_tensor * att_e = ggml_conv_1d(encoder->ctx, att_conv_w_f16, tdnn_conv, 1, 0, 1); + att_e = ensure_4d_from_conv1d(encoder->ctx, att_e); + struct ggml_tensor * att_conv_b_r = ggml_reshape_3d(encoder->ctx, att_conv_b, 1, C, 1); + att_e = ggml_add(encoder->ctx, att_e, att_conv_b_r); + + // Softmax over time + struct ggml_tensor * att_e_2d = ggml_reshape_2d(encoder->ctx, att_e, T, C); + struct ggml_tensor * alpha = ggml_soft_max(encoder->ctx, att_e_2d); + + // Weighted aggregation + struct ggml_tensor * x_2d = ggml_reshape_2d(encoder->ctx, x, T, C); + + struct ggml_tensor * x_weighted = ggml_mul(encoder->ctx, x_2d, alpha); + struct ggml_tensor * weighted_mean_2d = ggml_sum_rows(encoder->ctx, x_weighted); + struct ggml_tensor * x_minus_wmean = ggml_sub(encoder->ctx, x_2d, weighted_mean_2d); + struct ggml_tensor * sq_diff = ggml_mul(encoder->ctx, x_minus_wmean, x_minus_wmean); + struct ggml_tensor * w_sq_diff = ggml_mul(encoder->ctx, sq_diff, alpha); + struct ggml_tensor * weighted_var_2d = ggml_sum_rows(encoder->ctx, w_sq_diff); + struct ggml_tensor * eps_tensor = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, 1); + ((float *)eps_tensor->data)[0] = 1e-5f; + weighted_var_2d = ggml_add(encoder->ctx, weighted_var_2d, eps_tensor); + struct ggml_tensor * weighted_std_2d = ggml_sqrt(encoder->ctx, weighted_var_2d); + + // Output: [mean; std] → [6144] + weighted_mean_2d = ggml_reshape_4d(encoder->ctx, weighted_mean_2d, 1, C, 1, 1); + 
weighted_std_2d = ggml_reshape_4d(encoder->ctx, weighted_std_2d, 1, C, 1, 1); + + struct ggml_tensor * asp_output = ggml_concat(encoder->ctx, weighted_mean_2d, weighted_std_2d, 1); + asp_output = ggml_reshape_1d(encoder->ctx, asp_output, 2*C); // [6144] + + if (asp_output->ne[0] != 6144) { + WHISPER_LOG_ERROR("Layer 5 ASP: Output dimension mismatch: %lld != 6144\n", asp_output->ne[0]); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + cur = asp_output; + + // ASP BatchNorm + { + struct ggml_tensor * bn_mean = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.asp_bn.norm.running_mean"); + struct ggml_tensor * bn_var = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.asp_bn.norm.running_var"); + struct ggml_tensor * bn_g = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.asp_bn.norm.weight"); + struct ggml_tensor * bn_b = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.asp_bn.norm.bias"); + if (bn_mean && bn_var && bn_g && bn_b) { + struct ggml_tensor * s = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, 6144); + struct ggml_tensor * o = ggml_new_tensor_1d(encoder->ctx, GGML_TYPE_F32, 6144); + precompute_bn_params(bn_mean, bn_var, bn_g, bn_b, s, o); + cur = ggml_mul(encoder->ctx, cur, s); + cur = ggml_add(encoder->ctx, cur, o); + } else { + WHISPER_LOG_WARN("ASP BN: Missing tensors, skipping\n"); + } + } + + // Layer 6: Final FC [6144] → [192] + + struct ggml_tensor * embedding_w = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.fc.conv.weight"); + struct ggml_tensor * embedding_b = whisper_speaker_find_tensor(encoder->model, "mods.embedding_model.fc.conv.bias"); + + if (!embedding_w || !embedding_b) { + WHISPER_LOG_ERROR("Layer 6 (Final FC): Failed to load weights\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + + if (cur->ne[0] != 6144) { + WHISPER_LOG_ERROR("Layer 6 FC: ERROR - input dimension is %lld, expected 6144!\n", cur->ne[0]); + 
ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + struct ggml_tensor * embedding_w_ggml = ggml_conv_weight_f32_to_f16(encoder->ctx, embedding_w); + if (!embedding_w_ggml) { + WHISPER_LOG_ERROR("Layer 6: Failed to convert weight layout\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + struct ggml_tensor * weight_2d = ggml_reshape_2d(encoder->ctx, embedding_w_ggml, 6144, 192); + + struct ggml_tensor * cur_fc_2d = ggml_reshape_2d(encoder->ctx, cur, 6144, 1); + struct ggml_tensor * embedding = ggml_mul_mat(encoder->ctx, weight_2d, cur_fc_2d); + embedding = ggml_reshape_1d(encoder->ctx, embedding, 192); + if (!embedding) { + WHISPER_LOG_ERROR("Layer 6: Failed to compute matrix multiplication\n"); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + embedding = ggml_add(encoder->ctx, embedding, embedding_b); + + if (embedding->ne[0] != 192) { + WHISPER_LOG_ERROR("Layer 6 FC: ERROR - output dimension is %lld, expected 192!\n", embedding->ne[0]); + ggml_free(encoder->ctx); + free(encoder); + return NULL; + } + + ggml_set_name(embedding, "output_embedding"); + ggml_set_output(embedding); + encoder->output_embedding = embedding; + ggml_build_forward_expand(encoder->graph, embedding); + + return encoder; +} + +// Free encoder context +void whisper_speaker_encoder_free(struct whisper_speaker_encoder * encoder) { + if (!encoder) { + return; + } + + if (encoder->ctx) { + ggml_free(encoder->ctx); + } + + free(encoder); +} + +// Run forward pass: mel [T, 80] → embedding [192] +bool whisper_speaker_encoder_compute( + struct whisper_speaker_encoder * encoder, + const float * mel, + float * embedding) { + + if (!encoder || !mel || !embedding) { + WHISPER_LOG_ERROR("encoder_compute: invalid arguments\n"); + return false; + } + + // Validate input tensor dimensions (ne[0]=n_mels, ne[1]=n_frames) + if (encoder->input_mel->ne[0] != encoder->n_mels || + encoder->input_mel->ne[1] != encoder->n_frames) { + 
WHISPER_LOG_ERROR("encoder_compute: input shape mismatch\n"); + return false; + } + + // Copy mel data into input tensor + float * input_data = (float *)encoder->input_mel->data; + int mel_size = encoder->n_frames * encoder->n_mels; + memcpy(input_data, mel, mel_size * sizeof(float)); + + // Mark input/output tensors + ggml_set_input(encoder->input_mel); + ggml_set_output(encoder->output_embedding); + + // Execute forward pass using CPU backend + // Allocate work buffer for graph computation + struct ggml_cplan plan = ggml_graph_plan(encoder->graph, 4, NULL); // 4 threads, no custom threadpool + if (plan.work_size > 0) { + plan.work_data = (uint8_t *)malloc(plan.work_size); + if (!plan.work_data) { + WHISPER_LOG_ERROR("encoder_compute: failed to allocate work buffer (%zu bytes)\n", plan.work_size); + return false; + } + } + + // Execute the graph (forward pass) + enum ggml_status ret = ggml_graph_compute(encoder->graph, &plan); + if (plan.work_data) { + free(plan.work_data); + } + + if (ret != GGML_STATUS_SUCCESS) { + WHISPER_LOG_ERROR("encoder_compute: graph compute failed with status %d\n", (int)ret); + return false; + } + + // Check for NaN/Inf in output (sanity check on computed values) + float * output_data = (float *)encoder->output_embedding->data; + for (int i = 0; i < 192; i++) { + if (isnan(output_data[i]) || isinf(output_data[i])) { + WHISPER_LOG_ERROR("encoder_compute: output contains NaN/Inf at index %d\n", i); + return false; + } + } + + // Copy output embedding to caller's buffer + memcpy(embedding, output_data, 192 * sizeof(float)); + + return true; +} + +// Agglomerative hierarchical clustering + +// Cosine distance (double precision) +static double cosine_distance_f64(const float* a, const float* b, int dim) { + double dot_product = 0.0; + double norm_a = 0.0; + double norm_b = 0.0; + for (int i = 0; i < dim; i++) { + double a_f64 = (double)a[i]; + double b_f64 = (double)b[i]; + dot_product += a_f64 * b_f64; + norm_a += a_f64 * a_f64; + norm_b += 
b_f64 * b_f64; + } + norm_a = std::sqrt(norm_a); + norm_b = std::sqrt(norm_b); + // Handle zero-norm case (protect against division by zero) + if (norm_a < 1e-10 || norm_b < 1e-10) return 1.0; + return 1.0 - (dot_product / (norm_a * norm_b)); +} + +double * compute_distance_matrix( + const float * embeddings, + int num_segments, + int embedding_dim +) { + // Allocate symmetric distance matrix + double * dist_matrix = (double*)malloc(num_segments * num_segments * sizeof(double)); + if (!dist_matrix) return NULL; + + for (int i = 0; i < num_segments; i++) { + dist_matrix[i * num_segments + i] = 0.0; // distance to self is 0 + const float * embedding_i = embeddings + i * embedding_dim; + for (int j = i + 1; j < num_segments; j++) { + const float * embedding_j = embeddings + j * embedding_dim; + double dist = cosine_distance_f64(embedding_i, embedding_j, embedding_dim); + dist_matrix[i * num_segments + j] = dist; + dist_matrix[j * num_segments + i] = dist; // symmetric + } + } + return dist_matrix; +} + +struct whisper_clustering_context * whisper_clustering_context_create(int num_segments) { + if (num_segments <= 0) return NULL; + + struct whisper_clustering_context * ctx = + (struct whisper_clustering_context*)malloc(sizeof(struct whisper_clustering_context)); + if (!ctx) return NULL; + + ctx->num_segments = num_segments; + ctx->embedding_dim = 192; // ECAPA-TDNN output dimension + ctx->distance_matrix = NULL; + ctx->speaker_ids = (int*)malloc(num_segments * sizeof(int)); + ctx->num_speakers = 0; + + if (!ctx->speaker_ids) { + free(ctx); + return NULL; + } + return ctx; +} + +void whisper_clustering_context_free(struct whisper_clustering_context * ctx) { + if (!ctx) return; + if (ctx->distance_matrix) free(ctx->distance_matrix); + if (ctx->speaker_ids) free(ctx->speaker_ids); + free(ctx); +} + +int whisper_clustering_cluster( + struct whisper_clustering_context * ctx, + const float * embeddings, + int target_speakers, + float threshold, + int linkage_type +) { + if 
(!ctx || !embeddings) return -1; + if (ctx->num_segments <= 0) return -1; + + // Compute distance matrix + ctx->distance_matrix = compute_distance_matrix(embeddings, ctx->num_segments, ctx->embedding_dim); + if (!ctx->distance_matrix) { + WHISPER_LOG_ERROR("clustering_cluster: failed to allocate distance matrix\n"); + return -1; + } + + int num_segments = ctx->num_segments; + double * dist = ctx->distance_matrix; + + int * cluster_assignment = (int*)malloc(num_segments * sizeof(int)); + if (!cluster_assignment) { + free(ctx->distance_matrix); + ctx->distance_matrix = NULL; + return -1; + } + + for (int i = 0; i < num_segments; i++) { + cluster_assignment[i] = i; + } + + bool * cluster_active = (bool*)malloc(num_segments * sizeof(bool)); + if (!cluster_active) { + free(cluster_assignment); + free(ctx->distance_matrix); + ctx->distance_matrix = NULL; + return -1; + } + + for (int i = 0; i < num_segments; i++) { + cluster_active[i] = true; + } + + int num_active_clusters = num_segments; + + while (num_active_clusters > 1) { + // Find two closest active clusters + double min_distance = 1e10; + int merge_cluster1 = -1; + int merge_cluster2 = -1; + + for (int i = 0; i < num_segments; i++) { + if (!cluster_active[i]) continue; + for (int j = i + 1; j < num_segments; j++) { + if (!cluster_active[j]) continue; + double d = dist[i * num_segments + j]; + if (d < min_distance) { + min_distance = d; + merge_cluster1 = i; + merge_cluster2 = j; + } + } + } + + if (merge_cluster1 == -1) break; // No clusters to merge + + if (target_speakers > 0) { + if (num_active_clusters <= target_speakers) break; + } else { + if (min_distance > (double)threshold) break; + } + + // Merge cluster2 into cluster1 + for (int i = 0; i < num_segments; i++) { + if (cluster_assignment[i] == merge_cluster2) { + cluster_assignment[i] = merge_cluster1; + } + } + cluster_active[merge_cluster2] = false; + num_active_clusters--; + + // Update distances + for (int k = 0; k < num_segments; k++) { + if 
(!cluster_active[k] || k == merge_cluster1) continue; + + double new_distance = 0.0; + + if (linkage_type == WHISPER_LINKAGE_AVERAGE) { + // Average linkage: mean distance between all pairs + int pairs_count = 0; + new_distance = 0.0; + for (int i = 0; i < num_segments; i++) { + if (cluster_assignment[i] != merge_cluster1) continue; + for (int j = 0; j < num_segments; j++) { + if (cluster_assignment[j] != k) continue; + new_distance += dist[i * num_segments + j]; + pairs_count++; + } + } + if (pairs_count > 0) { + new_distance /= (double)pairs_count; + } + } else { // WHISPER_LINKAGE_COMPLETE + // Complete linkage: max distance between any pair + new_distance = 0.0; + for (int i = 0; i < num_segments; i++) { + if (cluster_assignment[i] != merge_cluster1) continue; + for (int j = 0; j < num_segments; j++) { + if (cluster_assignment[j] != k) continue; + double d = dist[i * num_segments + j]; + if (d > new_distance) { + new_distance = d; + } + } + } + } + + dist[merge_cluster1 * num_segments + k] = new_distance; + dist[k * num_segments + merge_cluster1] = new_distance; + } + } + + // Map cluster IDs to 0-based speaker IDs + int * cluster_to_speaker = (int*)malloc(num_segments * sizeof(int)); + if (!cluster_to_speaker) { + free(cluster_assignment); + free(cluster_active); + if (ctx->distance_matrix) { + free(ctx->distance_matrix); + ctx->distance_matrix = NULL; + } + return -1; + } + + for (int i = 0; i < num_segments; i++) { + cluster_to_speaker[i] = -1; + } + + int num_speakers = 0; + for (int i = 0; i < num_segments; i++) { + int cluster_id = cluster_assignment[i]; + if (cluster_to_speaker[cluster_id] == -1) { + cluster_to_speaker[cluster_id] = num_speakers++; + } + ctx->speaker_ids[i] = cluster_to_speaker[cluster_id]; + } + + ctx->num_speakers = num_speakers; + + free(cluster_assignment); + free(cluster_active); + free(cluster_to_speaker); + + return 0; +} diff --git a/src/whisper-diarize.h b/src/whisper-diarize.h new file mode 100644 index 00000000000..afe61e8db58 
--- /dev/null +++ b/src/whisper-diarize.h @@ -0,0 +1,79 @@ +#ifndef WHISPER_DIARIZE_H +#define WHISPER_DIARIZE_H + +#include "whisper-speaker.h" +#include "ggml.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// Mel-spectrogram computation (80-bin, n_fft=400, hop=160, fmin=0, fmax=8000) +float * whisper_compute_mel_80(const float * samples, int n_samples); + +// Get number of mel frames for given sample count +int whisper_get_mel_n_frames(int n_samples); + +// Free mel buffer +void whisper_mel_free(float * mel); + +// Speaker encoder +struct whisper_speaker_encoder; + +// Create encoder for given mel frame count +struct whisper_speaker_encoder * whisper_speaker_encoder_new( + struct whisper_speaker_model * model, + int n_frames, + int device +); + +// Free encoder context +void whisper_speaker_encoder_free(struct whisper_speaker_encoder * encoder); + +// Run forward pass: mel [T, 80] → embedding [192] +bool whisper_speaker_encoder_compute( + struct whisper_speaker_encoder * encoder, + const float * mel, + float * embedding +); + +// Clustering (agglomerative hierarchical) +enum whisper_linkage_type { + WHISPER_LINKAGE_AVERAGE = 0, // average linkage (default) + WHISPER_LINKAGE_COMPLETE = 1 // complete linkage (more conservative) +}; + +struct whisper_clustering_context { + int num_segments; + int embedding_dim; + double * distance_matrix; + int * speaker_ids; + int num_speakers; +}; + +struct whisper_clustering_context * whisper_clustering_context_create(int num_segments); + +void whisper_clustering_context_free(struct whisper_clustering_context * ctx); + +double * compute_distance_matrix( + const float * embeddings, + int num_segments, + int embedding_dim +); + +// Run agglomerative clustering on embeddings +// target_speakers > 0: force exact count; == 0: auto-detect using threshold +// Returns 0 on success, -1 on error +int whisper_clustering_cluster( + struct whisper_clustering_context * ctx, + const float * embeddings, + int target_speakers, + float 
threshold, + int linkage_type +); + +#ifdef __cplusplus +} +#endif + +#endif // WHISPER_DIARIZE_H diff --git a/src/whisper-speaker.cpp b/src/whisper-speaker.cpp new file mode 100644 index 00000000000..28c6c4c7d0d --- /dev/null +++ b/src/whisper-speaker.cpp @@ -0,0 +1,275 @@ +#include "whisper-speaker.h" +#include +#include +#include +#include +#include + +struct whisper_speaker_model { + struct ggml_context * ctx; + std::vector tensors; + std::vector tensor_names; + int embedding_dim; + int n_tensors; +}; + +// Load GGML speaker model from file +whisper_speaker_model * whisper_speaker_load_from_file(const char * path_model) { + FILE * fin = fopen(path_model, "rb"); + if (!fin) { + fprintf(stderr, "Failed to open model file: %s\n", path_model); + return nullptr; + } + + // Read magic number (must be 0x67676d6c = "ggml") + uint32_t magic; + if (fread(&magic, sizeof(magic), 1, fin) != 1) { + fprintf(stderr, "Failed to read magic number\n"); + fclose(fin); + return nullptr; + } + if (magic != 0x67676d6c) { // "ggml" + fprintf(stderr, "Invalid GGML magic: 0x%x (expected 0x67676d6c)\n", magic); + fclose(fin); + return nullptr; + } + printf("GGML magic valid: 0x%08x\n", magic); + + // Read model type string (length-prefixed UTF-8) + int str_len; + if (fread(&str_len, sizeof(str_len), 1, fin) != 1) { + fprintf(stderr, "Failed to read model type length\n"); + fclose(fin); + return nullptr; + } + + if (str_len < 0 || str_len > 256) { + fprintf(stderr, "Invalid model type length: %d\n", str_len); + fclose(fin); + return nullptr; + } + + char model_type[257]; + if (fread(model_type, str_len, 1, fin) != 1) { + fprintf(stderr, "Failed to read model type\n"); + fclose(fin); + return nullptr; + } + model_type[str_len] = '\0'; + printf("Model type: %s\n", model_type); + + // Read version (major, minor, patch) + int major, minor, patch; + if (fread(&major, sizeof(major), 1, fin) != 1 || + fread(&minor, sizeof(minor), 1, fin) != 1 || + fread(&patch, sizeof(patch), 1, fin) != 1) { + 
fprintf(stderr, "Failed to read version\n"); + fclose(fin); + return nullptr; + } + printf("Version: %d.%d.%d\n", major, minor, patch); + + // Read hyperparameters + int embedding_dim; + if (fread(&embedding_dim, sizeof(embedding_dim), 1, fin) != 1) { + fprintf(stderr, "Failed to read embedding_dim\n"); + fclose(fin); + return nullptr; + } + printf("Embedding dimension: %d\n", embedding_dim); + + int n_channels; + if (fread(&n_channels, sizeof(n_channels), 1, fin) != 1) { + fprintf(stderr, "Failed to read n_channels\n"); + fclose(fin); + return nullptr; + } + printf("Internal channels: %d\n", n_channels); + + // Read tensor count (for verification) + int n_tensors_expected; + if (fread(&n_tensors_expected, sizeof(n_tensors_expected), 1, fin) != 1) { + fprintf(stderr, "Failed to read tensor count\n"); + fclose(fin); + return nullptr; + } + printf("Expected tensors: %d\n", n_tensors_expected); + + // Create ggml context with sufficient memory for all tensors + // ~500 MB should accommodate ECAPA-TDNN model weights + size_t ctx_size = 500 * 1024 * 1024; // 500 MB + struct ggml_init_params ggml_params = { + .mem_size = ctx_size, + .mem_buffer = malloc(ctx_size), + .no_alloc = false, + }; + + if (!ggml_params.mem_buffer) { + fprintf(stderr, "Failed to allocate GGML context buffer\n"); + fclose(fin); + return nullptr; + } + + struct ggml_context * ctx = ggml_init(ggml_params); + if (!ctx) { + fprintf(stderr, "Failed to create ggml context\n"); + free(ggml_params.mem_buffer); + fclose(fin); + return nullptr; + } + printf("GGML context created (%zu bytes)\n", ctx_size); + + // Create speaker model structure + whisper_speaker_model * model = new whisper_speaker_model(); + model->ctx = ctx; + model->embedding_dim = embedding_dim; + model->n_tensors = 0; + + printf("\nLoading tensors:\n"); + + // Load tensors + for (int t = 0; t < n_tensors_expected; ++t) { + // Read tensor header: n_dims, name_len + int n_dims; + if (fread(&n_dims, sizeof(n_dims), 1, fin) != 1) { + 
fprintf(stderr, "Failed to read n_dims for tensor %d\n", t); + break; + } + + int name_len; + if (fread(&name_len, sizeof(name_len), 1, fin) != 1) { + fprintf(stderr, "Failed to read name_len for tensor %d\n", t); + break; + } + + // Sanity checks + if (n_dims < 0 || n_dims > 8) { + fprintf(stderr, "Invalid n_dims for tensor %d: %d\n", t, n_dims); + break; + } + if (name_len < 0 || name_len > 512) { + fprintf(stderr, "Invalid name_len for tensor %d: %d\n", t, name_len); + break; + } + + // Read dimensions (convert to int64_t) + int64_t dims[8] = {0}; + for (int i = 0; i < n_dims; ++i) { + int dim; + if (fread(&dim, sizeof(int), 1, fin) != 1) { + fprintf(stderr, "Failed to read dim %d for tensor %d\n", i, t); + break; + } + dims[i] = (int64_t)dim; + } + + // Read tensor name (not null-terminated in binary) + char name[513]; + if (fread(name, name_len, 1, fin) != 1) { + fprintf(stderr, "Failed to read tensor name for tensor %d\n", t); + break; + } + name[name_len] = '\0'; + + // Create tensor in ggml context + struct ggml_tensor * tensor = ggml_new_tensor(ctx, GGML_TYPE_F32, n_dims, dims); + if (!tensor) { + fprintf(stderr, "Failed to create tensor: %s\n", name); + break; + } + + // Read tensor data (float32) + size_t nelements = ggml_nelements(tensor); + size_t bytes_read = fread(tensor->data, sizeof(float), nelements, fin); + if (bytes_read != nelements) { + fprintf(stderr, "Failed to read tensor data for %s: got %zu, expected %zu\n", + name, bytes_read, nelements); + break; + } + + ggml_set_name(tensor, name); + model->tensors.push_back(tensor); + model->tensor_names.push_back(std::string(name)); + + printf(" [%d] %s: ", t + 1, name); + for (int i = 0; i < n_dims; ++i) { + printf("%lld", (long long)dims[i]); + if (i < n_dims - 1) printf("x"); + } + printf(" (%zu elements, %.2f MB)\n", nelements, (nelements * sizeof(float)) / 1024.0 / 1024.0); + + model->n_tensors++; + } + + printf("\nModel loaded: %d / %d tensors\n", model->n_tensors, n_tensors_expected); + + if 
(model->n_tensors != n_tensors_expected) { + fprintf(stderr, "Warning: Loaded %d tensors but expected %d\n", model->n_tensors, n_tensors_expected); + } + + fclose(fin); + return model; +} + +void whisper_speaker_validate(whisper_speaker_model * model) { + if (!model) { + fprintf(stderr, "Error: Model is nullptr\n"); + return; + } + + printf("\n=== Model Validation ===\n"); + printf("Embedding dimension: %d\n", model->embedding_dim); + printf("Total tensors loaded: %d\n", model->n_tensors); + + if (model->ctx) { + printf("Context allocated\n"); + } + + if (model->embedding_dim == 192) { + printf("Embedding dimension correct (192)\n"); + } else { + printf("WARNING: Embedding dimension unexpected: %d (expected 192)\n", model->embedding_dim); + } + + if (model->n_tensors > 0) { + printf("Model structure valid (%d tensors)\n", model->n_tensors); + } else { + printf("ERROR: No tensors loaded\n"); + } +} + +int whisper_speaker_get_embedding_dim(whisper_speaker_model * model) { + return model ? model->embedding_dim : -1; +} + +int whisper_speaker_get_tensor_count(whisper_speaker_model * model) { + return model ? 
model->n_tensors : -1; +} + +struct ggml_tensor * whisper_speaker_get_tensor(struct whisper_speaker_model * model, int idx) { + if (!model || idx < 0 || idx >= model->n_tensors) { + return nullptr; + } + return model->tensors[idx]; +} + +struct ggml_tensor * whisper_speaker_find_tensor(struct whisper_speaker_model * model, const char * name) { + if (!model || !name) return nullptr; + for (int i = 0; i < model->n_tensors; i++) { + if (model->tensor_names[i] == name) { + return model->tensors[i]; + } + } + return nullptr; +} + +void whisper_speaker_free(whisper_speaker_model * model) { + if (!model) return; + + if (model->ctx) { + // Free context (buffer is managed internally by ggml) + ggml_free(model->ctx); + } + + delete model; +} diff --git a/src/whisper.cpp b/src/whisper.cpp index 86bfafeaad8..aa77aad28d1 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -14,6 +14,11 @@ #include "openvino/whisper-openvino-encoder.h" #endif +#ifdef WHISPER_DIARIZE +#include "whisper-diarize.h" +#include "whisper-speaker.h" +#endif + #include #include #include @@ -467,6 +472,8 @@ struct whisper_segment { std::vector tokens; bool speaker_turn_next; + + int speaker_id = -1; // Speaker assignment from diarization; -1 if diarization disabled }; struct whisper_batch { @@ -932,6 +939,12 @@ struct whisper_state { bool has_vad_segments = false; std::vector vad_mapping_table; + + // Speaker diarization context + struct whisper_speaker_model * diarize_model = nullptr; + struct whisper_speaker_encoder * diarize_encoder = nullptr; + std::vector diarize_embeddings; + struct whisper_clustering_context * diarize_clustering = nullptr; }; struct whisper_context { @@ -3840,6 +3853,25 @@ void whisper_free_state(struct whisper_state * state) { state->vad_context = nullptr; } +#ifdef WHISPER_DIARIZE + // Free diarization context + if (state->diarize_model) { + whisper_speaker_free(state->diarize_model); + state->diarize_model = nullptr; + } + + if (state->diarize_encoder) { + 
whisper_speaker_encoder_free(state->diarize_encoder); + state->diarize_encoder = nullptr; + } + + if (state->diarize_clustering) { + whisper_clustering_context_free(state->diarize_clustering); + state->diarize_clustering = nullptr; + } + +#endif + delete state; } } @@ -5989,6 +6021,12 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.vad_model_path =*/ nullptr, /* vad_params =*/ whisper_vad_default_params(), + + // Speaker diarization defaults + /*.diarize =*/ false, + /*.diarize_model_path =*/ nullptr, + /*.diarize_threshold =*/ 0.5f, + /*.diarize_speakers =*/ 0, }; switch (strategy) { @@ -6968,6 +7006,32 @@ int whisper_full_with_state( prompt_init.push_back(whisper_token_not(ctx)); } +#ifdef WHISPER_DIARIZE + // Lazy load speaker embedding model if diarization enabled + if (params.diarize && params.diarize_model_path) { + if (!state->diarize_model) { + state->diarize_model = whisper_speaker_load_from_file( + params.diarize_model_path + ); + if (!state->diarize_model) { + WHISPER_LOG_ERROR("failed to load speaker embedding model from: %s\n", + params.diarize_model_path); + // Continue without diarization + } + } + } + + // Free leftover encoder from previous calls + if (state->diarize_encoder) { + whisper_speaker_encoder_free(state->diarize_encoder); + state->diarize_encoder = nullptr; + } + + if (params.diarize && state->diarize_model) { + state->diarize_embeddings.clear(); + } +#endif // WHISPER_DIARIZE + int seek = seek_start; std::vector prompt; @@ -7648,6 +7712,43 @@ int whisper_full_with_state( if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) { params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data); } + +#ifdef WHISPER_DIARIZE + // Compute diarization embedding for this segment (if enabled) + if (params.diarize && state->diarize_model) { + whisper_segment & seg = state->result_all.back(); + + // Extract PCM for segment and compute mel + int sample_start = (int)(seg.t0 * 
WHISPER_SAMPLE_RATE / 100); + int sample_end = (int)(seg.t1 * WHISPER_SAMPLE_RATE / 100); + if (sample_start < 0) sample_start = 0; + if (sample_end > n_samples) sample_end = n_samples; + int seg_n_samples = sample_end - sample_start; + + if (seg_n_samples > 1600) { // at least 0.1s + float * mel_data = whisper_compute_mel_80(samples + sample_start, seg_n_samples); + if (mel_data) { + int mel_n_frames = whisper_get_mel_n_frames(seg_n_samples); + if (mel_n_frames > 160) mel_n_frames = 160; // cap for memory + + // Create per-segment encoder with correct n_frames + whisper_speaker_encoder * enc = whisper_speaker_encoder_new( + state->diarize_model, mel_n_frames, 0); + + if (enc) { + float embedding[192] = {0}; + if (whisper_speaker_encoder_compute(enc, mel_data, embedding)) { + state->diarize_embeddings.insert( + state->diarize_embeddings.end(), + embedding, embedding + 192); + } + whisper_speaker_encoder_free(enc); + } + free(mel_data); + } + } + } +#endif // WHISPER_DIARIZE } text = ""; while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { @@ -7693,6 +7794,45 @@ int whisper_full_with_state( if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) { params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data); } + +#ifdef WHISPER_DIARIZE + // Compute diarization embedding for this segment (if enabled) + if (params.diarize && state->diarize_encoder) { + whisper_segment & seg = state->result_all.back(); // Just-added segment + + // Extract mel frames for segment time range [seg.t0, seg.t1] + // Time is in centiseconds; frame index = time_centiseconds / 100 + int frame_start = seg.t0 / 100; + int frame_end = seg.t1 / 100; + int n_frames = frame_end - frame_start; + + if (n_frames > 0 && frame_start >= 0 && frame_end <= state->mel.n_len) { + // Get mel spectrogram for this segment + const float * mel_data = state->mel.data.data() + (frame_start * 80); + + // Allocate embedding buffer for this segment + 
float embedding[192]; // ECAPA-TDNN output size (192-dim) + memset(embedding, 0, sizeof(embedding)); + + // Compute speaker embedding + if (whisper_speaker_encoder_compute(state->diarize_encoder, mel_data, embedding)) { + // Store in buffer for later clustering + state->diarize_embeddings.insert( + state->diarize_embeddings.end(), + embedding, + embedding + 192 + ); + } else { + WHISPER_LOG_WARN("failed to compute speaker embedding for segment %lu\n", + state->result_all.size() - 1); + } + } else { + WHISPER_LOG_WARN("invalid mel frame range for segment %lu: [%d, %d), total frames: %d\n", + state->result_all.size() - 1, frame_start, frame_end, + state->mel.n_len); + } + } +#endif // WHISPER_DIARIZE } } @@ -7728,6 +7868,44 @@ int whisper_full_with_state( } } +#ifdef WHISPER_DIARIZE + // Perform speaker clustering if diarization enabled + if (params.diarize && !state->diarize_embeddings.empty()) { + int num_segments = state->result_all.size(); + int num_embeddings = state->diarize_embeddings.size() / 192; + + if (num_embeddings != num_segments) { + WHISPER_LOG_ERROR("embedding count (%d) != segment count (%d); skipping clustering\n", + num_embeddings, num_segments); + } else { + state->diarize_clustering = whisper_clustering_context_create(num_segments); + if (!state->diarize_clustering) { + WHISPER_LOG_ERROR("failed to create clustering context\n"); + } else { + int ret = whisper_clustering_cluster( + state->diarize_clustering, + state->diarize_embeddings.data(), + params.diarize_speakers, + params.diarize_threshold, + WHISPER_LINKAGE_AVERAGE + ); + + if (ret == 0) { + // Assign speaker IDs to segments + for (int i = 0; i < num_segments; ++i) { + state->result_all[i].speaker_id = + state->diarize_clustering->speaker_ids[i]; + } + WHISPER_LOG_INFO("diarization complete: %d speakers detected\n", + state->diarize_clustering->num_speakers); + } else { + WHISPER_LOG_ERROR("clustering failed with code %d\n", ret); + } + } + } + } +#endif // WHISPER_DIARIZE + return 0; } 
@@ -7997,6 +8175,17 @@ bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, in return ctx->state->result_all[i_segment].speaker_turn_next; } +int whisper_full_get_segment_speaker_id(struct whisper_context * ctx, int segment) { + return whisper_full_get_segment_speaker_id_from_state(ctx->state, segment); +} + +int whisper_full_get_segment_speaker_id_from_state(struct whisper_state * state, int segment) { + if (segment < 0 || segment >= (int) state->result_all.size()) { + return -1; // Invalid segment index + } + return state->result_all[segment].speaker_id; +} + const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) { return state->result_all[i_segment].text.c_str(); } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 09e77ea89c2..69caa521912 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -110,3 +110,71 @@ target_compile_definitions(${VAD_TEST} PRIVATE SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav") add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST}) set_tests_properties(${VAD_TEST} PROPERTIES LABELS "base;en") + +# Speaker embedding inference test +add_executable(test-speaker-embedding test-speaker-embedding.cpp ../src/whisper-diarize.cpp ../src/whisper-speaker.cpp) +target_include_directories(test-speaker-embedding PRIVATE ../include ../ggml/include) +target_link_libraries(test-speaker-embedding PRIVATE ggml m) + +# Embedding quality (speaker discrimination) +add_executable(test-embedding-quality test-embedding-quality.cpp ../src/whisper-diarize.cpp ../src/whisper-speaker.cpp) +target_include_directories(test-embedding-quality PRIVATE ../include ../ggml/include ../src) +target_link_libraries(test-embedding-quality PRIVATE ggml m) + +# Clustering algorithm test +add_executable(test-clustering test-clustering.cpp ../src/whisper-diarize.cpp ../src/whisper-speaker.cpp) +target_include_directories(test-clustering PRIVATE ../include ../ggml/include) 
+target_link_libraries(test-clustering PRIVATE ggml m) +target_compile_options(test-clustering PRIVATE -std=c++11 -Wall -Wextra) + +# Diarization integration test +add_executable(test-diarization-integration test-diarization-integration.cpp) +target_include_directories(test-diarization-integration PRIVATE ../include ../ggml/include) +target_link_libraries(test-diarization-integration PRIVATE whisper) + +# CLI diarization end-to-end test +add_executable(test-cli-diarization test-cli-diarization.cpp) +target_include_directories(test-cli-diarization PRIVATE ../include ../examples) +target_link_libraries(test-cli-diarization PRIVATE whisper common) +target_compile_options(test-cli-diarization PRIVATE -std=c++17) + +# DER benchmark +add_executable(benchmark-der benchmark-der.cpp) +target_include_directories(benchmark-der PRIVATE ../include ../ggml/include) +target_link_libraries(benchmark-der PRIVATE ggml m) +target_compile_options(benchmark-der PRIVATE -std=c++17) + +# AddressSanitizer support +if (WHISPER_SANITIZE_ADDRESS) + message(STATUS "Enabling AddressSanitizer for benchmark-der") + target_compile_options(benchmark-der PRIVATE -fsanitize=address -fno-omit-frame-pointer) + target_link_options(benchmark-der PRIVATE -fsanitize=address) +endif() + +# Register tests +enable_testing() +add_test(NAME speaker-embedding COMMAND test-speaker-embedding WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) +add_test(NAME clustering COMMAND test-clustering WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) +add_test(NAME diarization-integration COMMAND test-diarization-integration WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) +add_test(NAME cli-diarization COMMAND test-cli-diarization WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + +# Register benchmark-der test +if (WHISPER_DIARIZE) + add_test(NAME benchmark-der COMMAND benchmark-der WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + set_tests_properties(benchmark-der PROPERTIES LABELS "diarization;benchmark") + + # ASan leak detection 
(Linux only) + if (WHISPER_SANITIZE_ADDRESS AND CMAKE_SYSTEM_NAME STREQUAL "Linux") + set_tests_properties(benchmark-der PROPERTIES ENVIRONMENT "ASAN_OPTIONS=detect_leaks=1:strip_path_prefix=${CMAKE_SOURCE_DIR}/") + endif() +endif() + +# Test: Reference embedding comparison (feed PyTorch mel → compare embeddings) +add_executable(test-ref-compare test-ref-compare.cpp ../src/whisper-diarize.cpp ../src/whisper-speaker.cpp) +target_include_directories(test-ref-compare PRIVATE ../include ../ggml/include ../src) +target_link_libraries(test-ref-compare PRIVATE ggml m) + +# Test: Mel feature comparison +add_executable(test-mel-compare test-mel-compare.cpp ../src/whisper-diarize.cpp ../src/whisper-speaker.cpp) +target_include_directories(test-mel-compare PRIVATE ../include ../ggml/include ../src) +target_link_libraries(test-mel-compare PRIVATE ggml m) From 9c4f76243e66dff8a10fd413119514d51ec828f6 Mon Sep 17 00:00:00 2001 From: MKY508 Date: Fri, 3 Apr 2026 00:26:31 +0800 Subject: [PATCH 2/5] Fix diarize_clustering memory leak on repeated inference whisper_clustering_context_create() overwrites the old pointer without freeing it first. When the same whisper_state is reused across multiple inference runs, the previous clustering context leaks. Free it before creating a new one. 
--- src/whisper.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/whisper.cpp b/src/whisper.cpp index aa77aad28d1..ac04709d53a 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -7878,6 +7878,12 @@ int whisper_full_with_state( WHISPER_LOG_ERROR("embedding count (%d) != segment count (%d); skipping clustering\n", num_embeddings, num_segments); } else { + // Free previous clustering context if it exists + if (state->diarize_clustering) { + whisper_clustering_context_free(state->diarize_clustering); + state->diarize_clustering = nullptr; + } + state->diarize_clustering = whisper_clustering_context_create(num_segments); if (!state->diarize_clustering) { WHISPER_LOG_ERROR("failed to create clustering context\n"); From 7784465985014e759246ab3f63809e8531a4b5f9 Mon Sep 17 00:00:00 2001 From: MKY508 Date: Sat, 4 Apr 2026 00:14:26 +0800 Subject: [PATCH 3/5] whisper : wire diarization into cli/server and cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - cli/server: add --diarize-model, --diarize-threshold, --diarize-speakers - unify speaker label logic across all output formats (txt/vtt/srt/csv/json/lrc/wts) - fall back to stereo diarization when no model is provided - fix memory leak in whisper_compute_mel_80, move allocs out of hot loop - thread-safe static init with std::call_once - rename hann → hamming (was actually hamming), remove dead code - dynamic ggml context sizing, WHISPER_LOG_* macros in speaker loader - fix n_channels 512 → 1024 in python converter - server: ARGV_NEXT bounds checking for all args --- CMakeLists.txt | 1 + examples/cli/README.md | 21 +++- examples/cli/cli.cpp | 157 ++++++++++++++--------- examples/server/README.md | 11 +- examples/server/server.cpp | 203 ++++++++++++++++++++---------- include/whisper-speaker.h | 1 - include/whisper.h | 2 +- models/convert-speaker-to-ggml.py | 40 +----- src/whisper-diarize.cpp | 81 ++++++------ src/whisper-speaker.cpp | 109 +++++++--------- 
src/whisper.cpp | 198 +++++++++++++++-------------- tests/CMakeLists.txt | 155 ++++++++++++++--------- 12 files changed, 545 insertions(+), 434 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a0f74041321..905169d2c2d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -69,6 +69,7 @@ option(WHISPER_ALL_WARNINGS_3RD_PARTY "whisper: enable all compiler warnings in # build option(WHISPER_FATAL_WARNINGS "whisper: enable -Werror flag" OFF) option(WHISPER_USE_SYSTEM_GGML "whisper: use system-installed GGML library" OFF) +option(WHISPER_DIARIZE "whisper: enable speaker diarization" OFF) # sanitizers option(WHISPER_SANITIZE_THREAD "whisper: enable thread sanitizer" OFF) diff --git a/examples/cli/README.md b/examples/cli/README.md index 65285c3cb66..d44610ee6e3 100644 --- a/examples/cli/README.md +++ b/examples/cli/README.md @@ -30,8 +30,11 @@ options: -tpi, --temperature-inc N [0.20 ] The increment of temperature, between 0 and 1 -debug, --debug-mode [false ] enable debug mode (eg. 
dump log_mel) -tr, --translate [false ] translate from source language to english - -di, --diarize [false ] stereo audio diarization - -tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model) + -di, --diarize [false ] enable speaker diarization + --diarize-model FNAME [ ] speaker embedding model path (GGML .bin) + --diarize-threshold N [0.50 ] clustering distance threshold + --diarize-speakers N [0 ] target speaker count (0 = auto) + -tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model) -nf, --no-fallback [false ] do not use temperature fallback while decoding -otxt, --output-txt [false ] output result in a text file -ovtt, --output-vtt [false ] output result in a vtt file @@ -62,5 +65,15 @@ options: --suppress-regex REGEX [ ] regular expression matching tokens to suppress --grammar GRAMMAR [ ] GBNF grammar to guide decoding --grammar-rule RULE [ ] top-level GBNF grammar rule name - --grammar-penalty N [100.0 ] scales down logits of nongrammar tokens -``` + --grammar-penalty N [100.0 ] scales down logits of nongrammar tokens +``` + +Model-based diarization uses the ECAPA-TDNN speaker embedding model produced by +`models/convert-speaker-to-ggml.py`: + +``` +python models/convert-speaker-to-ggml.py --output models/ggml-speaker-ecapa-tdnn.bin +./build/bin/whisper-cli -m models/ggml-base.en.bin \ + --diarize --diarize-model models/ggml-speaker-ecapa-tdnn.bin \ + -f input.wav +``` diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp index 4e84c1b2750..3ff017708f8 100644 --- a/examples/cli/cli.cpp +++ b/examples/cli/cli.cpp @@ -85,9 +85,13 @@ struct whisper_params { std::string prompt; std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; std::string model = "models/ggml-base.en.bin"; + std::string diarize_model; std::string grammar; std::string grammar_rule; + float diarize_threshold = 0.5f; + int diarize_speakers = 0; + // [TDRZ] speaker turn string std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // 
TODO: set from command line @@ -129,6 +133,16 @@ static char * requires_value_error(const std::string & arg) { exit(0); } +static bool use_model_diarization(const whisper_params & params) { + return params.diarize && !params.diarize_model.empty(); +} + +static bool use_stereo_diarization( + const whisper_params & params, + const std::vector> & pcmf32s) { + return params.diarize && params.diarize_model.empty() && pcmf32s.size() == 2; +} + static bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { if (const char * env_device = std::getenv("WHISPER_ARG_DEVICE")) { params.gpu_device = std::stoi(env_device); @@ -171,6 +185,9 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } + else if ( arg == "--diarize-model") { params.diarize_model = ARGV_NEXT; params.diarize = true; } + else if ( arg == "--diarize-threshold") { params.diarize_threshold = std::stof(ARGV_NEXT); } + else if ( arg == "--diarize-speakers") { params.diarize_speakers = std::stoi(ARGV_NEXT); } else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } @@ -253,7 +270,10 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc); fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? 
"true" : "false"); - fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -di, --diarize [%-7s] enable speaker diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " --diarize-model FNAME [%-7s] speaker embedding model path (GGML .bin)\n", params.diarize_model.c_str()); + fprintf(stderr, " --diarize-threshold N [%-7.2f] clustering distance threshold\n", params.diarize_threshold); + fprintf(stderr, " --diarize-speakers N [%-7d] target speaker count (0 = auto)\n", params.diarize_speakers); fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); @@ -312,7 +332,7 @@ struct whisper_print_user_data { int progress_prev; }; -static std::string estimate_diarization_speaker(std::vector> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) { +static std::string estimate_diarization_speaker(const std::vector> & pcmf32s, int64_t t0, int64_t t1, bool id_only = false) { std::string speaker = ""; const int64_t n_samples = pcmf32s[0].size(); @@ -345,6 +365,45 @@ static std::string estimate_diarization_speaker(std::vector> return speaker; } +static std::string get_segment_speaker_id( + struct whisper_context * ctx, + const whisper_params & params, + const std::vector> & pcmf32s, + int i_segment) { + if (!params.diarize) { + return ""; + } + + if (use_model_diarization(params)) { + const int speaker_id = whisper_full_get_segment_speaker_id(ctx, i_segment); + return speaker_id >= 0 ? 
std::to_string(speaker_id) : ""; + } + + if (use_stereo_diarization(params, pcmf32s)) { + const int64_t t0 = whisper_full_get_segment_t0(ctx, i_segment); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i_segment); + return estimate_diarization_speaker(pcmf32s, t0, t1, true); + } + + return ""; +} + +static std::string format_segment_speaker_label(const std::string & speaker_id) { + if (speaker_id.empty()) { + return ""; + } + + return "(speaker " + speaker_id + ")"; +} + +static std::string format_segment_speaker_vtt(const std::string & speaker_id) { + if (speaker_id.empty()) { + return ""; + } + + return "<v Speaker" + speaker_id + ">"; +} + +static void whisper_print_progress_callback(struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) { int progress_step = ((whisper_print_user_data *) user_data)->params->progress_step; int * progress_prev = &(((whisper_print_user_data *) user_data)->progress_prev); @@ -382,9 +441,7 @@ static void whisper_print_segment_callback(struct whisper_context * ctx, struct printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); } - if (params.diarize && pcmf32s.size() == 2) { - speaker = estimate_diarization_speaker(pcmf32s, t0, t1); - } + speaker = format_segment_speaker_label(get_segment_speaker_id(ctx, params, pcmf32s, i)); if (params.print_colors) { for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
== 2) - { - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - speaker = estimate_diarization_speaker(pcmf32s, t0, t1); - } + const std::string speaker = format_segment_speaker_label(get_segment_speaker_id(ctx, params, pcmf32s, i)); fout << speaker << text << "\n"; } } -static void output_vtt(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) { +static void output_vtt(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, const std::vector<std::vector<float>> & pcmf32s) { fout << "WEBVTT\n\n"; const int n_segments = whisper_full_n_segments(ctx); @@ -472,32 +522,20 @@ const char * text = whisper_full_get_segment_text(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - std::string speaker = ""; - - if (params.diarize && pcmf32s.size() == 2) - { - speaker = estimate_diarization_speaker(pcmf32s, t0, t1, true); - speaker.insert(0, "<v Speaker"); - speaker.append(">"); - } + const std::string speaker = format_segment_speaker_vtt(get_segment_speaker_id(ctx, params, pcmf32s, i)); fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n"; fout << speaker << text << "\n\n"; } } -static void output_srt(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) { +static void output_srt(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, const std::vector<std::vector<float>> & pcmf32s) { const int n_segments = whisper_full_n_segments(ctx); for (int i = 0; i < n_segments; ++i) { const char * text = whisper_full_get_segment_text(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - std::string speaker = ""; - - if (params.diarize && pcmf32s.size() == 2) - { - speaker = estimate_diarization_speaker(pcmf32s, 
t0, t1); - } + const std::string speaker = format_segment_speaker_label(get_segment_speaker_id(ctx, params, pcmf32s, i)); fout << i + 1 + params.offset_n << "\n"; fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n"; @@ -568,11 +606,11 @@ static char * escape_double_quotes_in_csv(const char * str) { return escaped; } -static void output_csv(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector> pcmf32s) { +static void output_csv(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, const std::vector> & pcmf32s) { const int n_segments = whisper_full_n_segments(ctx); + const bool has_speaker_column = use_model_diarization(params) || use_stereo_diarization(params, pcmf32s); fout << "start,end,"; - if (params.diarize && pcmf32s.size() == 2) - { + if (has_speaker_column) { fout << "speaker,"; } fout << "text\n"; @@ -585,9 +623,8 @@ static void output_csv(struct whisper_context * ctx, std::ofstream & fout, const //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds. 
fout << 10 * t0 << "," << 10 * t1 << ","; - if (params.diarize && pcmf32s.size() == 2) - { - fout << estimate_diarization_speaker(pcmf32s, t0, t1, true) << ","; + if (has_speaker_column) { + fout << get_segment_speaker_id(ctx, params, pcmf32s, i) << ","; } fout << "\"" << text_escaped << "\"\n"; } @@ -612,7 +649,7 @@ static void output_json( struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, - std::vector> pcmf32s) { + const std::vector> & pcmf32s) { const bool full = params.output_jsn_full; int indent = 0; @@ -727,13 +764,15 @@ static void output_json( const int n_segments = whisper_full_n_segments(ctx); for (int i = 0; i < n_segments; ++i) { const char * text = whisper_full_get_segment_text(ctx, i); + const std::string speaker = get_segment_speaker_id(ctx, params, pcmf32s, i); + const bool has_speaker = !speaker.empty(); const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); start_obj(nullptr); times_o(t0, t1, false); - value_s("text", text, !params.diarize && !params.tinydiarize && !full); + value_s("text", text, !has_speaker && !params.tinydiarize && !full); if (full) { start_arr("tokens"); @@ -751,11 +790,11 @@ static void output_json( value_f("t_dtw", token.t_dtw, true); end_obj(j == (n - 1)); } - end_arr(!params.diarize && !params.tinydiarize); + end_arr(!has_speaker && !params.tinydiarize); } - if (params.diarize && pcmf32s.size() == 2) { - value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true); + if (has_speaker) { + value_s("speaker", speaker.c_str(), !params.tinydiarize); } if (params.tinydiarize) { @@ -771,7 +810,7 @@ static void output_json( // karaoke video generation // outputs a bash script that uses ffmpeg to generate a video with the subtitles // TODO: font parameter adjustments -static bool output_wts(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector> pcmf32s, const char * fname_inp, 
float t_sec, const char * fname_out) { +static bool output_wts(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, const std::vector> & pcmf32s, const char * fname_inp, float t_sec, const char * fname_out) { static const char * font = params.font_path.c_str(); std::ifstream fin(font); @@ -804,11 +843,7 @@ static bool output_wts(struct whisper_context * ctx, std::ofstream & fout, const fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'"; bool is_first = true; - std::string speaker = ""; - - if (params.diarize && pcmf32s.size() == 2) { - speaker = estimate_diarization_speaker(pcmf32s, t0, t1); - } + const std::string speaker = format_segment_speaker_label(get_segment_speaker_id(ctx, params, pcmf32s, i)); for (int j = 0; j < n; ++j) { const auto & token = tokens[j]; @@ -821,7 +856,7 @@ static bool output_wts(struct whisper_context * ctx, std::ofstream & fout, const std::string txt_fg = ""; // highlight token std::string txt_ul = ""; // underline - if (params.diarize && pcmf32s.size() == 2) { + if (!speaker.empty()) { txt_bg = speaker; txt_fg = speaker; txt_ul = "\\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ "; @@ -892,7 +927,7 @@ static bool output_wts(struct whisper_context * ctx, std::ofstream & fout, const return true; } -static void output_lrc(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector> pcmf32s) { +static void output_lrc(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, const std::vector> & pcmf32s) { fout << "[by:whisper.cpp]\n"; const int n_segments = whisper_full_n_segments(ctx); @@ -909,14 +944,7 @@ static void output_lrc(struct whisper_context * ctx, std::ofstream & fout, const char buf[16]; snprintf(buf, sizeof(buf), "%02d:%02d.%02d", (int) min, (int) sec, (int) ( msec / 10)); std::string timestamp_lrc = std::string(buf); - std::string speaker = ""; - 
- if (params.diarize && pcmf32s.size() == 2) - { - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - speaker = estimate_diarization_speaker(pcmf32s, t0, t1); - } + const std::string speaker = format_segment_speaker_label(get_segment_speaker_id(ctx, params, pcmf32s, i)); fout << '[' << timestamp_lrc << ']' << speaker << text << "\n"; } @@ -1002,6 +1030,12 @@ int main(int argc, char ** argv) { exit(0); } + if (params.diarize_speakers < 0) { + fprintf(stderr, "error: --diarize-speakers must be >= 0\n"); + whisper_print_usage(argc, argv, params); + exit(0); + } + if (params.no_prints) { whisper_log_set(cb_log_disable, NULL); } @@ -1122,8 +1156,9 @@ int main(int argc, char ** argv) { std::vector pcmf32; // mono-channel F32 PCM std::vector> pcmf32s; // stereo-channel F32 PCM + const bool needs_stereo_diarization = params.diarize && params.diarize_model.empty(); - if (!::read_audio_data(fname_inp, pcmf32, pcmf32s, params.diarize)) { + if (!::read_audio_data(fname_inp, pcmf32, pcmf32s, needs_stereo_diarization)) { fprintf(stderr, "error: failed to read audio file '%s'\n", fname_inp.c_str()); continue; } @@ -1152,7 +1187,7 @@ int main(int argc, char ** argv) { params.n_threads, params.n_processors, params.beam_size, params.best_of, params.language.c_str(), params.translate ? "translate" : "transcribe", - params.tinydiarize ? "tdrz = 1, " : "", + params.tinydiarize ? "tdrz = 1, " : (params.diarize ? "diarize = 1, " : ""), params.no_timestamps ? 0 : 1); if (params.print_colors) { @@ -1191,6 +1226,10 @@ int main(int argc, char ** argv) { wparams.debug_mode = params.debug_mode; wparams.tdrz_enable = params.tinydiarize; // [TDRZ] + wparams.diarize = use_model_diarization(params); + wparams.diarize_model_path = params.diarize_model.empty() ? 
nullptr : params.diarize_model.c_str(); + wparams.diarize_threshold = params.diarize_threshold; + wparams.diarize_speakers = params.diarize_speakers; wparams.suppress_regex = params.suppress_regex.empty() ? nullptr : params.suppress_regex.c_str(); diff --git a/examples/server/README.md b/examples/server/README.md index ffba5f4edf5..13622ef126e 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -29,7 +29,10 @@ options: -lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail -debug, --debug-mode [false ] enable debug mode (eg. dump log_mel) -tr, --translate [false ] translate from source language to english - -di, --diarize [false ] stereo audio diarization + -di, --diarize [false ] enable speaker diarization + --diarize-model FNAME [ ] speaker embedding model path (GGML .bin) + --diarize-threshold N [0.50 ] clustering distance threshold + --diarize-speakers N [0 ] target speaker count (0 = auto) -tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model) -nf, --no-fallback [false ] do not use temperature fallback while decoding -ps, --print-special [false ] print special tokens @@ -69,6 +72,10 @@ Voice Activity Detection (VAD) options: > [!WARNING] > **Do not run the server example with administrative privileges and ensure it's operated in a sandbox environment, especially since it involves risky operations like accepting user file uploads and using ffmpeg for format conversions. Always validate and sanitize inputs to guard against potential security threats.** +When using diarization over HTTP, `diarize_model` is a form field whose value is a +path on the server host, not an uploaded model file. The speaker embedding model +must already exist on the machine running `whisper-server`. 
+ ## request examples **/inference** @@ -78,6 +85,8 @@ curl 127.0.0.1:8080/inference \ -F file="@" \ -F temperature="0.0" \ -F temperature_inc="0.2" \ +-F diarize="true" \ +-F diarize_model="/absolute/path/on/server/models/ggml-speaker-ecapa-tdnn.bin" \ -F response_format="json" ``` diff --git a/examples/server/server.cpp b/examples/server/server.cpp index f6a7a83181a..9e50b523cc5 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -112,9 +112,13 @@ struct whisper_params { std::string prompt = ""; std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; std::string model = "models/ggml-base.en.bin"; + std::string diarize_model = ""; std::string response_format = json_format; + float diarize_threshold = 0.5f; + int diarize_speakers = 0; + // [TDRZ] speaker turn string std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line @@ -155,7 +159,10 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); - fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -di, --diarize [%-7s] enable speaker diarization\n", params.diarize ? 
"true" : "false"); + fprintf(stderr, " --diarize-model FNAME [%-7s] speaker embedding model path (GGML .bin)\n", params.diarize_model.c_str()); + fprintf(stderr, " --diarize-threshold N [%-7.2f] clustering distance threshold\n", params.diarize_threshold); + fprintf(stderr, " --diarize-speakers N [%-7d] target speaker count (0 = auto)\n", params.diarize_speakers); fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); @@ -199,6 +206,11 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, "\n"); } +static char * requires_value_error(const std::string & arg) { + fprintf(stderr, "error: argument %s requires value\n", arg.c_str()); + exit(0); +} + bool whisper_params_parse(int argc, char ** argv, whisper_params & params, server_params & sparams) { if (const char * env_device = std::getenv("WHISPER_ARG_DEVICE")) { params.gpu_device = std::stoi(env_device); @@ -211,63 +223,67 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve whisper_print_usage(argc, argv, params, sparams); exit(0); } - else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } - else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); } - else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); } - else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); } - else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } - else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } - else if (arg == "-ml" || arg 
== "--max-len") { params.max_len = std::stoi(argv[++i]); } - else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } - else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); } - else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } - else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } - else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } - else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } + #define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg)) + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); } + else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(ARGV_NEXT); } + else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(ARGV_NEXT); } + else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(ARGV_NEXT); } + else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(ARGV_NEXT); } + else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(ARGV_NEXT); } + else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(ARGV_NEXT); } + else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(ARGV_NEXT); } + else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(ARGV_NEXT); } + else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(ARGV_NEXT); } + else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(ARGV_NEXT); } + else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(ARGV_NEXT); } + else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(ARGV_NEXT); } else if (arg == "-debug"|| arg == 
"--debug-mode") { params.debug_mode = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } + else if ( arg == "--diarize-model") { params.diarize_model = ARGV_NEXT; params.diarize = true; } + else if ( arg == "--diarize-threshold"){ params.diarize_threshold = std::stof(ARGV_NEXT); } + else if ( arg == "--diarize-speakers") { params.diarize_speakers = std::stoi(ARGV_NEXT); } else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } - else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; } + else if (arg == "-fp" || arg == "--font-path") { params.font_path = ARGV_NEXT; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } else if (arg == "-pr" || arg == "--print-realtime") { params.print_realtime = true; } else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } - else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } + else if (arg == "-l" || arg == "--language") { params.language = ARGV_NEXT; } else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; } - else if ( arg == "--prompt") { params.prompt = argv[++i]; } - else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } - else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; } - else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } + else if ( arg == "--prompt") { params.prompt = ARGV_NEXT; } + else if (arg == "-m" || arg == "--model") { params.model = 
ARGV_NEXT; } + else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = ARGV_NEXT; } + else if (arg == "-dtw" || arg == "--dtw") { params.dtw = ARGV_NEXT; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } - else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(argv[++i]); } + else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(ARGV_NEXT); } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } - else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); } + else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(ARGV_NEXT); } else if (arg == "-nlp" || arg == "--no-language-probabilities") { params.no_language_probabilities = true; } // server params - else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); } - else if ( arg == "--host") { sparams.hostname = argv[++i]; } - else if ( arg == "--public") { sparams.public_path = argv[++i]; } - else if ( arg == "--request-path") { sparams.request_path = argv[++i]; } - else if ( arg == "--inference-path") { sparams.inference_path = argv[++i]; } + else if ( arg == "--port") { sparams.port = std::stoi(ARGV_NEXT); } + else if ( arg == "--host") { sparams.hostname = ARGV_NEXT; } + else if ( arg == "--public") { sparams.public_path = ARGV_NEXT; } + else if ( arg == "--request-path") { sparams.request_path = ARGV_NEXT; } + else if ( arg == "--inference-path") { sparams.inference_path = ARGV_NEXT; } else if ( arg == "--convert") { sparams.ffmpeg_converter = true; } - else if ( arg == "--tmp-dir") { sparams.tmp_dir = argv[++i]; } + else if ( arg == "--tmp-dir") { sparams.tmp_dir = ARGV_NEXT; } // Voice Activity Detection (VAD) else if ( arg == "--vad") { 
params.vad = true; } - else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = argv[++i]; } - else if (arg == "-vt" || arg == "--vad-threshold") { params.vad_threshold = std::stof(argv[++i]); } - else if (arg == "-vspd" || arg == "--vad-min-speech-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(argv[++i]); } - else if (arg == "-vsd" || arg == "--vad-min-silence-duration-ms") { params.vad_min_silence_duration_ms = std::stoi(argv[++i]); } - else if (arg == "-vmsd" || arg == "--vad-max-speech-duration-s") { params.vad_max_speech_duration_s = std::stof(argv[++i]); } - else if (arg == "-vp" || arg == "--vad-speech-pad-ms") { params.vad_speech_pad_ms = std::stoi(argv[++i]); } - else if (arg == "-vo" || arg == "--vad-samples-overlap") { params.vad_samples_overlap = std::stof(argv[++i]); } + else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = ARGV_NEXT; } + else if (arg == "-vt" || arg == "--vad-threshold") { params.vad_threshold = std::stof(ARGV_NEXT); } + else if (arg == "-vspd" || arg == "--vad-min-speech-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(ARGV_NEXT); } + else if (arg == "-vsd" || arg == "--vad-min-silence-duration-ms") { params.vad_min_silence_duration_ms = std::stoi(ARGV_NEXT); } + else if (arg == "-vmsd" || arg == "--vad-max-speech-duration-s") { params.vad_max_speech_duration_s = std::stof(ARGV_NEXT); } + else if (arg == "-vp" || arg == "--vad-speech-pad-ms") { params.vad_speech_pad_ms = std::stoi(ARGV_NEXT); } + else if (arg == "-vo" || arg == "--vad-samples-overlap") { params.vad_samples_overlap = std::stof(ARGV_NEXT); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params, sparams); @@ -344,7 +360,17 @@ bool convert_to_wav(const std::string & temp_filename, std::string & error_resp) return true; } -std::string estimate_diarization_speaker(std::vector> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) { +static bool 
use_model_diarization(const whisper_params & params) {
+    return params.diarize && !params.diarize_model.empty();
+}
+
+static bool use_stereo_diarization(
+        const whisper_params & params,
+        const std::vector<std::vector<float>> & pcmf32s) {
+    return params.diarize && params.diarize_model.empty() && pcmf32s.size() == 2;
+}
+
+std::string estimate_diarization_speaker(const std::vector<std::vector<float>> & pcmf32s, int64_t t0, int64_t t1, bool id_only = false) {
     std::string speaker = "";
     const int64_t n_samples = pcmf32s[0].size();
@@ -377,6 +403,45 @@ std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s
     return speaker;
 }

+static std::string get_segment_speaker_id(
+        struct whisper_context * ctx,
+        const whisper_params & params,
+        const std::vector<std::vector<float>> & pcmf32s,
+        int i_segment) {
+    if (!params.diarize) {
+        return "";
+    }
+
+    if (use_model_diarization(params)) {
+        const int speaker_id = whisper_full_get_segment_speaker_id(ctx, i_segment);
+        return speaker_id >= 0 ? std::to_string(speaker_id) : "";
+    }
+
+    if (use_stereo_diarization(params, pcmf32s)) {
+        const int64_t t0 = whisper_full_get_segment_t0(ctx, i_segment);
+        const int64_t t1 = whisper_full_get_segment_t1(ctx, i_segment);
+        return estimate_diarization_speaker(pcmf32s, t0, t1, true);
+    }
+
+    return "";
+}
+
+static std::string format_segment_speaker_label(const std::string & speaker_id) {
+    if (speaker_id.empty()) {
+        return "";
+    }
+
+    return "(speaker " + speaker_id + ")";
+}
+
+static std::string format_segment_speaker_vtt(const std::string & speaker_id) {
+    if (speaker_id.empty()) {
+        return "";
+    }
+
+    return "<v Speaker" + speaker_id + ">";
+}
+
 void whisper_print_progress_callback(struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) {
     int progress_step = ((whisper_print_user_data *) user_data)->params->progress_step;
     int * progress_prev = &(((whisper_print_user_data *) user_data)->progress_prev);
@@ -414,9 +479,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
             printf("[%s --> %s]
", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
         }

-        if (params.diarize && pcmf32s.size() == 2) {
-            speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
-        }
+        speaker = format_segment_speaker_label(get_segment_speaker_id(ctx, params, pcmf32s, i));

         if (params.print_colors) {
             for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
@@ -454,19 +517,12 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
         }
     }
 }

-std::string output_str(struct whisper_context * ctx, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
+std::string output_str(struct whisper_context * ctx, const whisper_params & params, const std::vector<std::vector<float>> & pcmf32s) {
     std::stringstream result;
     const int n_segments = whisper_full_n_segments(ctx);
     for (int i = 0; i < n_segments; ++i) {
         const char * text = whisper_full_get_segment_text(ctx, i);
-        std::string speaker = "";
-
-        if (params.diarize && pcmf32s.size() == 2)
-        {
-            const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
-            const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
-            speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
-        }
+        const std::string speaker = format_segment_speaker_label(get_segment_speaker_id(ctx, params, pcmf32s, i));

         result << speaker << text << "\n";
     }
@@ -538,6 +594,19 @@ void get_req_parameters(const Request & req, whisper_params & params) {
         params.diarize = parse_str_to_bool(req.get_file_value("diarize").content);
     }
+    if (req.has_file("diarize_model"))
+    {
+        params.diarize_model = req.get_file_value("diarize_model").content;
+        params.diarize = true;
+    }
+    if (req.has_file("diarize_threshold"))
+    {
+        params.diarize_threshold = std::stof(req.get_file_value("diarize_threshold").content);
+    }
+    if (req.has_file("diarize_speakers"))
+    {
+        params.diarize_speakers = std::stoi(req.get_file_value("diarize_speakers").content);
+    }
     if (req.has_file("tinydiarize"))
     {
         params.tinydiarize = parse_str_to_bool(req.get_file_value("tinydiarize").content);
@@ -643,6 +712,12 @@ int main(int argc,
char ** argv) { exit(0); } + if (params.diarize_speakers < 0) { + fprintf(stderr, "error: --diarize-speakers must be >= 0\n"); + whisper_print_usage(argc, argv, params, sparams); + exit(0); + } + if (sparams.ffmpeg_converter) { check_ffmpeg_availibility(); } @@ -826,6 +901,7 @@ int main(int argc, char ** argv) { // audio arrays std::vector pcmf32; // mono-channel F32 PCM std::vector> pcmf32s; // stereo-channel F32 PCM + const bool needs_stereo_diarization = params.diarize && params.diarize_model.empty(); if (sparams.ffmpeg_converter) { // if file is not wav, convert to wav @@ -844,7 +920,7 @@ int main(int argc, char ** argv) { } // read audio content into pcmf32 - if (!::read_audio_data(temp_filename, pcmf32, pcmf32s, params.diarize)) + if (!::read_audio_data(temp_filename, pcmf32, pcmf32s, needs_stereo_diarization)) { fprintf(stderr, "error: failed to read WAV file '%s'\n", temp_filename.c_str()); const std::string error_resp = "{\"error\":\"failed to read WAV file\"}"; @@ -856,7 +932,7 @@ int main(int argc, char ** argv) { // remove temp file std::remove(temp_filename.c_str()); } else { - if (!::read_audio_data(audio_file.content, pcmf32, pcmf32s, params.diarize)) + if (!::read_audio_data(audio_file.content, pcmf32, pcmf32s, needs_stereo_diarization)) { fprintf(stderr, "error: failed to read audio data\n"); const std::string error_resp = "{\"error\":\"failed to read audio data\"}"; @@ -893,7 +969,7 @@ int main(int argc, char ** argv) { params.n_threads, params.n_processors, params.language.c_str(), params.translate ? "translate" : "transcribe", - params.tinydiarize ? "tdrz = 1, " : "", + params.tinydiarize ? "tdrz = 1, " : (params.diarize ? "diarize = 1, " : ""), params.no_timestamps ? 
0 : 1); fprintf(stderr, "\n"); @@ -926,6 +1002,10 @@ int main(int argc, char ** argv) { wparams.debug_mode = params.debug_mode; wparams.tdrz_enable = params.tinydiarize; // [TDRZ] + wparams.diarize = use_model_diarization(params); + wparams.diarize_model_path = params.diarize_model.empty() ? nullptr : params.diarize_model.c_str(); + wparams.diarize_threshold = params.diarize_threshold; + wparams.diarize_speakers = params.diarize_speakers; wparams.initial_prompt = params.prompt.c_str(); @@ -1006,12 +1086,7 @@ int main(int argc, char ** argv) { const char * text = whisper_full_get_segment_text(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - std::string speaker = ""; - - if (params.diarize && pcmf32s.size() == 2) - { - speaker = estimate_diarization_speaker(pcmf32s, t0, t1); - } + const std::string speaker = format_segment_speaker_label(get_segment_speaker_id(ctx, params, pcmf32s, i)); ss << i + 1 + params.offset_n << "\n"; ss << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n"; @@ -1028,14 +1103,7 @@ int main(int argc, char ** argv) { const char * text = whisper_full_get_segment_text(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - std::string speaker = ""; - - if (params.diarize && pcmf32s.size() == 2) - { - speaker = estimate_diarization_speaker(pcmf32s, t0, t1, true); - speaker.insert(0, ""); - } + const std::string speaker = format_segment_speaker_vtt(get_segment_speaker_id(ctx, params, pcmf32s, i)); ss << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n"; ss << speaker << text << "\n\n"; @@ -1072,12 +1140,17 @@ int main(int argc, char ** argv) { {"id", i}, {"text", whisper_full_get_segment_text(ctx, i)}, }; + const std::string speaker = get_segment_speaker_id(ctx, params, pcmf32s, i); if (!params.no_timestamps) { segment["start"] = whisper_full_get_segment_t0(ctx, i) * 0.01; segment["end"] = 
whisper_full_get_segment_t1(ctx, i) * 0.01; } + if (!speaker.empty()) { + segment["speaker"] = speaker; + } + float total_logprob = 0; const int n_tokens = whisper_full_n_tokens(ctx, i); for (int j = 0; j < n_tokens; ++j) { diff --git a/include/whisper-speaker.h b/include/whisper-speaker.h index 42212a557b8..35d6a50ec17 100644 --- a/include/whisper-speaker.h +++ b/include/whisper-speaker.h @@ -2,7 +2,6 @@ #define WHISPER_SPEAKER_H #include "ggml.h" -#include #ifdef __cplusplus extern "C" { diff --git a/include/whisper.h b/include/whisper.h index 530484ae242..a9fded8ddb4 100644 --- a/include/whisper.h +++ b/include/whisper.h @@ -591,7 +591,7 @@ extern "C" { // Speaker diarization params bool diarize; // Enable speaker diarization (default: false) - const char * diarize_model_path; // Path to speaker embedding model file (GGUF format) + const char * diarize_model_path; // Path to speaker embedding model file (GGML .bin format) float diarize_threshold; // Distance threshold for clustering (default: 0.5f) int diarize_speakers; // Target speaker count; 0 = auto-detect (default: 0) }; diff --git a/models/convert-speaker-to-ggml.py b/models/convert-speaker-to-ggml.py index f6cac2713f9..16d13966040 100644 --- a/models/convert-speaker-to-ggml.py +++ b/models/convert-speaker-to-ggml.py @@ -33,42 +33,6 @@ import numpy as np from pathlib import Path -def fuse_batch_norm_weights(conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias, eps=1e-5): - """ - Fuse BatchNorm into conv weights for inference. 
- - Args: - conv_weight: [out_c, in_c, kernel_size] - conv_bias: [out_c] - bn_mean, bn_var, bn_weight, bn_bias: [out_c] each - - Returns: - fused_weight: [out_c, in_c, kernel_size] - fused_bias: [out_c] - """ - # Convert to numpy if needed - if isinstance(conv_weight, torch.Tensor): - conv_weight = conv_weight.cpu().numpy() - if isinstance(conv_bias, torch.Tensor): - conv_bias = conv_bias.cpu().numpy() - if isinstance(bn_mean, torch.Tensor): - bn_mean = bn_mean.cpu().numpy() - if isinstance(bn_var, torch.Tensor): - bn_var = bn_var.cpu().numpy() - if isinstance(bn_weight, torch.Tensor): - bn_weight = bn_weight.cpu().numpy() - if isinstance(bn_bias, torch.Tensor): - bn_bias = bn_bias.cpu().numpy() - - # Fusion formula: W_fused = W * γ / sqrt(σ² + ε), b_fused = β + γ * (b - μ) / sqrt(σ² + ε) - scale = bn_weight / np.sqrt(bn_var + eps) # [out_c] - - # Broadcast scale across weight dimensions [out_c, in_c, kernel_size] - fused_weight = conv_weight.astype(np.float32) * scale[:, np.newaxis, np.newaxis] - fused_bias = bn_bias + scale * (conv_bias.astype(np.float32) - bn_mean) - - return fused_weight, fused_bias - def load_speaker_model(model_name: str, tmp_dir: str = None): """ Load SpeakerRecognition model from SpeechBrain. 
@@ -167,8 +131,8 @@ def convert_speaker_model(model_name: str, output_path: str, test_mode: bool = F
     print(f"Version: {version_major}.{version_minor}.{version_patch}")

     # Write hyperparameters
-    embedding_dim = 192  # ECAPA-TDNN output dimension
-    n_channels = 512     # Internal architecture parameter
+    embedding_dim = 192   # ECAPA-TDNN output dimension
+    n_channels = 1024     # Internal channel width
     fout.write(struct.pack('i', embedding_dim))
     fout.write(struct.pack('i', n_channels))
     print(f"Embedding dimension: {embedding_dim}")
diff --git a/src/whisper-diarize.cpp b/src/whisper-diarize.cpp
index 02f795c87ee..2bf34e1180d 100644
--- a/src/whisper-diarize.cpp
+++ b/src/whisper-diarize.cpp
@@ -6,6 +6,7 @@
 #include
 #include
 #include
+#include <mutex>

 // Define logging macros for consistency with whisper.cpp
 #define WHISPER_LOG_ERROR(...) fprintf(stderr, "[ERROR] " __VA_ARGS__)
@@ -25,17 +26,13 @@
 // FFT constants
 #define FFT_SIZE 512 // Next power of 2 for efficiency

-static float g_hann_window[WHISPER_N_FFT] = {0};
-static int g_hann_computed = 0;
+static float g_hamming_window[WHISPER_N_FFT] = {0};
+static std::once_flag g_hamming_once;

-static void compute_hann_window() {
-    if (g_hann_computed) return;
-
-    // Hamming window (matching SpeechBrain pretrained model)
+static void compute_hamming_window() {
     for (int i = 0; i < WHISPER_N_FFT; i++) {
-        g_hann_window[i] = 0.54f - 0.46f * cosf(2.0f * M_PI * i / (WHISPER_N_FFT - 1));
+        g_hamming_window[i] = 0.54f - 0.46f * cosf(2.0f * M_PI * i / (WHISPER_N_FFT - 1));
     }
-    g_hann_computed = 1;
 }

 // Cooley-Tukey radix-2 FFT (complex-to-complex, in-place)
@@ -128,12 +125,25 @@ static float * create_mel_filters(int n_mels, int n_fft, int sample_rate) {
 }

 // Compute 80-bin mel-spectrogram from PCM samples
+static float * g_mel_filters = NULL;
+static std::once_flag g_mel_filters_once;
+static const int g_n_fft_bins = 1 + WHISPER_N_FFT / 2; // 201 bins for 400-point DFT
+
+static void init_mel_filters() {
+    g_mel_filters =
create_mel_filters(MEL_N_BINS, g_n_fft_bins, WHISPER_SAMPLE_RATE); +} + float * whisper_compute_mel_80(const float * samples, int n_samples) { if (!samples || n_samples <= 0) { return NULL; } - compute_hann_window(); + std::call_once(g_hamming_once, compute_hamming_window); + std::call_once(g_mel_filters_once, init_mel_filters); + + if (!g_mel_filters) { + return NULL; + } // Center padding: add n_fft/2 samples on both sides int pad = WHISPER_N_FFT / 2; // 200 samples @@ -157,20 +167,18 @@ float * whisper_compute_mel_80(const float * samples, int n_samples) { // Allocate output mel array [n_frames, 80] float * mel = (float *)calloc(n_frames * MEL_N_BINS, sizeof(float)); if (!mel) { + free(padded_samples); return NULL; } - // Create mel filterbank once (n_fft=400 → 201 bins) - static float * mel_filters = NULL; - static int mel_filters_initialized = 0; - int n_fft_bins = 1 + WHISPER_N_FFT / 2; // 201 bins for 400-point DFT - if (!mel_filters_initialized) { - mel_filters = create_mel_filters(MEL_N_BINS, n_fft_bins, WHISPER_SAMPLE_RATE); - mel_filters_initialized = 1; - } - - if (!mel_filters) { + // Allocate FFT and magnitude buffers once for the loop + float * fft_buf = (float *)malloc(FFT_SIZE * 2 * sizeof(float)); + float * mag = (float *)malloc(g_n_fft_bins * sizeof(float)); + if (!fft_buf || !mag) { + free(fft_buf); + free(mag); free(mel); + free(padded_samples); return NULL; } @@ -179,12 +187,10 @@ float * whisper_compute_mel_80(const float * samples, int n_samples) { int offset = t * WHISPER_HOP_LENGTH; // Extract frame, apply Hamming window, zero-pad to 512 as complex interleaved - float * fft_buf = (float *)calloc(FFT_SIZE * 2, sizeof(float)); + memset(fft_buf, 0, FFT_SIZE * 2 * sizeof(float)); for (int i = 0; i < WHISPER_N_FFT; i++) { - fft_buf[2*i] = padded_samples[offset + i] * g_hann_window[i]; - // imaginary = 0 (calloc) + fft_buf[2*i] = padded_samples[offset + i] * g_hamming_window[i]; } - // positions WHISPER_N_FFT..FFT_SIZE-1 are zero-padded (calloc) 
// In-place 512-point FFT fft_radix2(fft_buf, FFT_SIZE); @@ -195,8 +201,7 @@ float * whisper_compute_mel_80(const float * samples, int n_samples) { // Map: for target bin j (400-point), find 512-point bin at same frequency // freq_j = j * sr / 400 → k_512 = j * 512 / 400 = j * 1.28 // Use linear interpolation between adjacent 512-point bins - float * mag = (float *)malloc(n_fft_bins * sizeof(float)); - for (int j = 0; j < n_fft_bins; j++) { + for (int j = 0; j < g_n_fft_bins; j++) { float k_f = j * (float)FFT_SIZE / WHISPER_N_FFT; // fractional 512-bin index int k0 = (int)k_f; float frac = k_f - k0; @@ -214,18 +219,18 @@ float * whisper_compute_mel_80(const float * samples, int n_samples) { // Apply mel filterbank for (int m = 0; m < MEL_N_BINS; m++) { float mel_val = 0.0f; - for (int k = 0; k < n_fft_bins; k++) { - mel_val += mag[k] * mel_filters[m * n_fft_bins + k]; + for (int k = 0; k < g_n_fft_bins; k++) { + mel_val += mag[k] * g_mel_filters[m * g_n_fft_bins + k]; } // dB scale: 10 * log10(max(x, 1e-10)) mel[t * MEL_N_BINS + m] = 10.0f * log10f(fmaxf(mel_val, 1e-10f)); } - - free(fft_buf); - free(mag); } + free(fft_buf); + free(mag); + // top_db clipping: clamp to (max_db - 80) float max_db = -1e30f; for (int i = 0; i < n_frames * MEL_N_BINS; i++) { @@ -259,16 +264,6 @@ void whisper_mel_free(float * mel) { // Speaker encoder forward pass -static struct ggml_tensor * apply_simple_norm( - struct ggml_context * ctx, - struct ggml_tensor * x) { - if (!x) { - WHISPER_LOG_ERROR("apply_simple_norm: NULL input tensor\n"); - return x; - } - return ggml_norm(ctx, x, 1e-5f); -} - // Reshape conv1d output to 4D for broadcasting static struct ggml_tensor * ensure_4d_from_conv1d(struct ggml_context * ctx, struct ggml_tensor * t) { // Conv1d outputs 3D: [OW, OC, batch] @@ -391,9 +386,9 @@ struct whisper_speaker_encoder * whisper_speaker_encoder_new( // Dynamic context size: base 200MB + ~0.5MB per frame for intermediate tensors size_t ctx_bytes = (size_t)200 * 1024 * 1024 + 
(size_t)n_frames * 512 * 1024; struct ggml_init_params params = { - .mem_size = ctx_bytes, - .mem_buffer = NULL, - .no_alloc = false, + ctx_bytes, + NULL, + false, }; encoder->ctx = ggml_init(params); diff --git a/src/whisper-speaker.cpp b/src/whisper-speaker.cpp index 28c6c4c7d0d..623fa202f17 100644 --- a/src/whisper-speaker.cpp +++ b/src/whisper-speaker.cpp @@ -5,6 +5,10 @@ #include #include +#define WHISPER_LOG_ERROR(...) fprintf(stderr, "[ERROR] " __VA_ARGS__) +#define WHISPER_LOG_WARN(...) fprintf(stderr, "[WARN] " __VA_ARGS__) +#define WHISPER_LOG_INFO(...) fprintf(stderr, "[INFO] " __VA_ARGS__) + struct whisper_speaker_model { struct ggml_context * ctx; std::vector tensors; @@ -29,95 +33,89 @@ whisper_speaker_model * whisper_speaker_load_from_file(const char * path_model) return nullptr; } if (magic != 0x67676d6c) { // "ggml" - fprintf(stderr, "Invalid GGML magic: 0x%x (expected 0x67676d6c)\n", magic); + WHISPER_LOG_ERROR("invalid GGML magic: 0x%x (expected 0x67676d6c)\n", magic); fclose(fin); return nullptr; } - printf("GGML magic valid: 0x%08x\n", magic); // Read model type string (length-prefixed UTF-8) int str_len; if (fread(&str_len, sizeof(str_len), 1, fin) != 1) { - fprintf(stderr, "Failed to read model type length\n"); + WHISPER_LOG_ERROR("failed to read model type length\n"); fclose(fin); return nullptr; } if (str_len < 0 || str_len > 256) { - fprintf(stderr, "Invalid model type length: %d\n", str_len); + WHISPER_LOG_ERROR("invalid model type length: %d\n", str_len); fclose(fin); return nullptr; } char model_type[257]; if (fread(model_type, str_len, 1, fin) != 1) { - fprintf(stderr, "Failed to read model type\n"); + WHISPER_LOG_ERROR("failed to read model type\n"); fclose(fin); return nullptr; } model_type[str_len] = '\0'; - printf("Model type: %s\n", model_type); + WHISPER_LOG_INFO("speaker model type: %s\n", model_type); // Read version (major, minor, patch) int major, minor, patch; if (fread(&major, sizeof(major), 1, fin) != 1 || fread(&minor, 
sizeof(minor), 1, fin) != 1 || fread(&patch, sizeof(patch), 1, fin) != 1) { - fprintf(stderr, "Failed to read version\n"); + WHISPER_LOG_ERROR("failed to read version\n"); fclose(fin); return nullptr; } - printf("Version: %d.%d.%d\n", major, minor, patch); + WHISPER_LOG_INFO("speaker model version: %d.%d.%d\n", major, minor, patch); // Read hyperparameters int embedding_dim; if (fread(&embedding_dim, sizeof(embedding_dim), 1, fin) != 1) { - fprintf(stderr, "Failed to read embedding_dim\n"); + WHISPER_LOG_ERROR("failed to read embedding_dim\n"); fclose(fin); return nullptr; } - printf("Embedding dimension: %d\n", embedding_dim); int n_channels; if (fread(&n_channels, sizeof(n_channels), 1, fin) != 1) { - fprintf(stderr, "Failed to read n_channels\n"); + WHISPER_LOG_ERROR("failed to read n_channels\n"); fclose(fin); return nullptr; } - printf("Internal channels: %d\n", n_channels); // Read tensor count (for verification) int n_tensors_expected; if (fread(&n_tensors_expected, sizeof(n_tensors_expected), 1, fin) != 1) { - fprintf(stderr, "Failed to read tensor count\n"); + WHISPER_LOG_ERROR("failed to read tensor count\n"); fclose(fin); return nullptr; } - printf("Expected tensors: %d\n", n_tensors_expected); - // Create ggml context with sufficient memory for all tensors - // ~500 MB should accommodate ECAPA-TDNN model weights - size_t ctx_size = 500 * 1024 * 1024; // 500 MB + WHISPER_LOG_INFO("speaker model: embedding_dim=%d, n_tensors=%d\n", embedding_dim, n_tensors_expected); + + // Estimate context size from file: file_size + overhead for tensor metadata + long cur_pos = ftell(fin); + fseek(fin, 0, SEEK_END); + long file_size = ftell(fin); + fseek(fin, cur_pos, SEEK_SET); + size_t ctx_size = (size_t)(file_size - cur_pos) + (size_t)n_tensors_expected * 1024 + 16 * 1024 * 1024; + struct ggml_init_params ggml_params = { - .mem_size = ctx_size, - .mem_buffer = malloc(ctx_size), - .no_alloc = false, + ctx_size, + nullptr, + false, }; - if (!ggml_params.mem_buffer) { - 
fprintf(stderr, "Failed to allocate GGML context buffer\n"); - fclose(fin); - return nullptr; - } - struct ggml_context * ctx = ggml_init(ggml_params); if (!ctx) { - fprintf(stderr, "Failed to create ggml context\n"); - free(ggml_params.mem_buffer); + WHISPER_LOG_ERROR("failed to create ggml context (%zu bytes)\n", ctx_size); fclose(fin); return nullptr; } - printf("GGML context created (%zu bytes)\n", ctx_size); // Create speaker model structure whisper_speaker_model * model = new whisper_speaker_model(); @@ -125,30 +123,28 @@ whisper_speaker_model * whisper_speaker_load_from_file(const char * path_model) model->embedding_dim = embedding_dim; model->n_tensors = 0; - printf("\nLoading tensors:\n"); - // Load tensors for (int t = 0; t < n_tensors_expected; ++t) { // Read tensor header: n_dims, name_len int n_dims; if (fread(&n_dims, sizeof(n_dims), 1, fin) != 1) { - fprintf(stderr, "Failed to read n_dims for tensor %d\n", t); + WHISPER_LOG_ERROR("failed to read n_dims for tensor %d\n", t); break; } int name_len; if (fread(&name_len, sizeof(name_len), 1, fin) != 1) { - fprintf(stderr, "Failed to read name_len for tensor %d\n", t); + WHISPER_LOG_ERROR("failed to read name_len for tensor %d\n", t); break; } // Sanity checks if (n_dims < 0 || n_dims > 8) { - fprintf(stderr, "Invalid n_dims for tensor %d: %d\n", t, n_dims); + WHISPER_LOG_ERROR("invalid n_dims for tensor %d: %d\n", t, n_dims); break; } if (name_len < 0 || name_len > 512) { - fprintf(stderr, "Invalid name_len for tensor %d: %d\n", t, name_len); + WHISPER_LOG_ERROR("invalid name_len for tensor %d: %d\n", t, name_len); break; } @@ -157,7 +153,7 @@ whisper_speaker_model * whisper_speaker_load_from_file(const char * path_model) for (int i = 0; i < n_dims; ++i) { int dim; if (fread(&dim, sizeof(int), 1, fin) != 1) { - fprintf(stderr, "Failed to read dim %d for tensor %d\n", i, t); + WHISPER_LOG_ERROR("failed to read dim %d for tensor %d\n", i, t); break; } dims[i] = (int64_t)dim; @@ -166,7 +162,7 @@ 
whisper_speaker_model * whisper_speaker_load_from_file(const char * path_model) // Read tensor name (not null-terminated in binary) char name[513]; if (fread(name, name_len, 1, fin) != 1) { - fprintf(stderr, "Failed to read tensor name for tensor %d\n", t); + WHISPER_LOG_ERROR("failed to read tensor name for tensor %d\n", t); break; } name[name_len] = '\0'; @@ -174,7 +170,7 @@ whisper_speaker_model * whisper_speaker_load_from_file(const char * path_model) // Create tensor in ggml context struct ggml_tensor * tensor = ggml_new_tensor(ctx, GGML_TYPE_F32, n_dims, dims); if (!tensor) { - fprintf(stderr, "Failed to create tensor: %s\n", name); + WHISPER_LOG_ERROR("failed to create tensor: %s\n", name); break; } @@ -182,7 +178,7 @@ whisper_speaker_model * whisper_speaker_load_from_file(const char * path_model) size_t nelements = ggml_nelements(tensor); size_t bytes_read = fread(tensor->data, sizeof(float), nelements, fin); if (bytes_read != nelements) { - fprintf(stderr, "Failed to read tensor data for %s: got %zu, expected %zu\n", + WHISPER_LOG_ERROR("failed to read tensor data for %s: got %zu, expected %zu\n", name, bytes_read, nelements); break; } @@ -191,20 +187,14 @@ whisper_speaker_model * whisper_speaker_load_from_file(const char * path_model) model->tensors.push_back(tensor); model->tensor_names.push_back(std::string(name)); - printf(" [%d] %s: ", t + 1, name); - for (int i = 0; i < n_dims; ++i) { - printf("%lld", (long long)dims[i]); - if (i < n_dims - 1) printf("x"); - } - printf(" (%zu elements, %.2f MB)\n", nelements, (nelements * sizeof(float)) / 1024.0 / 1024.0); - model->n_tensors++; } - printf("\nModel loaded: %d / %d tensors\n", model->n_tensors, n_tensors_expected); + WHISPER_LOG_INFO("speaker model loaded: %d / %d tensors (%.1f MB)\n", + model->n_tensors, n_tensors_expected, ctx_size / 1024.0 / 1024.0); if (model->n_tensors != n_tensors_expected) { - fprintf(stderr, "Warning: Loaded %d tensors but expected %d\n", model->n_tensors, n_tensors_expected); 
+ WHISPER_LOG_WARN("loaded %d tensors but expected %d\n", model->n_tensors, n_tensors_expected); } fclose(fin); @@ -213,28 +203,19 @@ whisper_speaker_model * whisper_speaker_load_from_file(const char * path_model) void whisper_speaker_validate(whisper_speaker_model * model) { if (!model) { - fprintf(stderr, "Error: Model is nullptr\n"); + WHISPER_LOG_ERROR("speaker model is nullptr\n"); return; } - printf("\n=== Model Validation ===\n"); - printf("Embedding dimension: %d\n", model->embedding_dim); - printf("Total tensors loaded: %d\n", model->n_tensors); - - if (model->ctx) { - printf("Context allocated\n"); - } + WHISPER_LOG_INFO("speaker model: embedding_dim=%d, n_tensors=%d\n", + model->embedding_dim, model->n_tensors); - if (model->embedding_dim == 192) { - printf("Embedding dimension correct (192)\n"); - } else { - printf("WARNING: Embedding dimension unexpected: %d (expected 192)\n", model->embedding_dim); + if (model->embedding_dim != 192) { + WHISPER_LOG_WARN("unexpected embedding dimension: %d (expected 192)\n", model->embedding_dim); } - if (model->n_tensors > 0) { - printf("Model structure valid (%d tensors)\n", model->n_tensors); - } else { - printf("ERROR: No tensors loaded\n"); + if (model->n_tensors <= 0) { + WHISPER_LOG_ERROR("no tensors loaded in speaker model\n"); } } diff --git a/src/whisper.cpp b/src/whisper.cpp index ac04709d53a..842326d82fa 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -942,8 +942,8 @@ struct whisper_state { // Speaker diarization context struct whisper_speaker_model * diarize_model = nullptr; - struct whisper_speaker_encoder * diarize_encoder = nullptr; std::vector diarize_embeddings; + std::vector diarize_segment_indices; struct whisper_clustering_context * diarize_clustering = nullptr; }; @@ -964,6 +964,85 @@ struct whisper_context { std::string path_model; // populated by whisper_init_from_file_with_params() }; +#ifdef WHISPER_DIARIZE +static bool whisper_compute_diarization_embedding( + struct whisper_state * 
state, + const float * samples, + int n_samples, + int segment_index) { + if (state == nullptr || state->diarize_model == nullptr || samples == nullptr) { + return false; + } + + if (segment_index < 0 || segment_index >= (int) state->result_all.size()) { + return false; + } + + const whisper_segment & seg = state->result_all[segment_index]; + + int sample_start = (int) (seg.t0 * WHISPER_SAMPLE_RATE / 100); + int sample_end = (int) (seg.t1 * WHISPER_SAMPLE_RATE / 100); + + sample_start = std::max(sample_start, 0); + sample_end = std::min(sample_end, n_samples); + + const int seg_n_samples = sample_end - sample_start; + if (seg_n_samples <= 1600) { + return false; + } + + float * mel_data = whisper_compute_mel_80(samples + sample_start, seg_n_samples); + if (mel_data == nullptr) { + WHISPER_LOG_WARN("%s: failed to compute mel for segment %d\n", __func__, segment_index); + return false; + } + + const int mel_n_frames = whisper_get_mel_n_frames(seg_n_samples); + + struct whisper_speaker_encoder * enc = + whisper_speaker_encoder_new(state->diarize_model, mel_n_frames, 0); + if (enc == nullptr) { + whisper_mel_free(mel_data); + WHISPER_LOG_WARN("%s: failed to create speaker encoder for segment %d\n", __func__, segment_index); + return false; + } + + float embedding[192] = {0}; + const bool ok = whisper_speaker_encoder_compute(enc, mel_data, embedding); + + whisper_speaker_encoder_free(enc); + whisper_mel_free(mel_data); + + if (!ok) { + WHISPER_LOG_WARN("%s: failed to compute speaker embedding for segment %d\n", __func__, segment_index); + return false; + } + + state->diarize_embeddings.insert( + state->diarize_embeddings.end(), + embedding, + embedding + 192); + state->diarize_segment_indices.push_back(segment_index); + + return true; +} + +static void whisper_collect_diarization_embeddings( + struct whisper_state * state, + const float * samples, + int n_samples, + int n_new_segments) { + if (state == nullptr || state->diarize_model == nullptr || n_new_segments <= 0) { 
+ return; + } + + const int start = std::max(0, (int) state->result_all.size() - n_new_segments); + for (int segment_index = start; segment_index < (int) state->result_all.size(); ++segment_index) { + whisper_compute_diarization_embedding(state, samples, n_samples, segment_index); + } +} +#endif + struct whisper_global { // We save the log callback globally ggml_log_callback log_callback = whisper_log_callback_default; @@ -3860,11 +3939,6 @@ void whisper_free_state(struct whisper_state * state) { state->diarize_model = nullptr; } - if (state->diarize_encoder) { - whisper_speaker_encoder_free(state->diarize_encoder); - state->diarize_encoder = nullptr; - } - if (state->diarize_clustering) { whisper_clustering_context_free(state->diarize_clustering); state->diarize_clustering = nullptr; @@ -7021,15 +7095,13 @@ int whisper_full_with_state( } } - // Free leftover encoder from previous calls - if (state->diarize_encoder) { - whisper_speaker_encoder_free(state->diarize_encoder); - state->diarize_encoder = nullptr; + if (state->diarize_clustering) { + whisper_clustering_context_free(state->diarize_clustering); + state->diarize_clustering = nullptr; } - if (params.diarize && state->diarize_model) { - state->diarize_embeddings.clear(); - } + state->diarize_embeddings.clear(); + state->diarize_segment_indices.clear(); #endif // WHISPER_DIARIZE int seek = seek_start; @@ -7714,39 +7786,9 @@ int whisper_full_with_state( } #ifdef WHISPER_DIARIZE - // Compute diarization embedding for this segment (if enabled) + // Compute diarization embeddings for all newly finalized segments. 
if (params.diarize && state->diarize_model) { - whisper_segment & seg = state->result_all.back(); - - // Extract PCM for segment and compute mel - int sample_start = (int)(seg.t0 * WHISPER_SAMPLE_RATE / 100); - int sample_end = (int)(seg.t1 * WHISPER_SAMPLE_RATE / 100); - if (sample_start < 0) sample_start = 0; - if (sample_end > n_samples) sample_end = n_samples; - int seg_n_samples = sample_end - sample_start; - - if (seg_n_samples > 1600) { // at least 0.1s - float * mel_data = whisper_compute_mel_80(samples + sample_start, seg_n_samples); - if (mel_data) { - int mel_n_frames = whisper_get_mel_n_frames(seg_n_samples); - if (mel_n_frames > 160) mel_n_frames = 160; // cap for memory - - // Create per-segment encoder with correct n_frames - whisper_speaker_encoder * enc = whisper_speaker_encoder_new( - state->diarize_model, mel_n_frames, 0); - - if (enc) { - float embedding[192] = {0}; - if (whisper_speaker_encoder_compute(enc, mel_data, embedding)) { - state->diarize_embeddings.insert( - state->diarize_embeddings.end(), - embedding, embedding + 192); - } - whisper_speaker_encoder_free(enc); - } - free(mel_data); - } - } + whisper_collect_diarization_embeddings(state, samples, n_samples, n_new); } #endif // WHISPER_DIARIZE } @@ -7796,41 +7838,9 @@ int whisper_full_with_state( } #ifdef WHISPER_DIARIZE - // Compute diarization embedding for this segment (if enabled) - if (params.diarize && state->diarize_encoder) { - whisper_segment & seg = state->result_all.back(); // Just-added segment - - // Extract mel frames for segment time range [seg.t0, seg.t1] - // Time is in centiseconds; frame index = time_centiseconds / 100 - int frame_start = seg.t0 / 100; - int frame_end = seg.t1 / 100; - int n_frames = frame_end - frame_start; - - if (n_frames > 0 && frame_start >= 0 && frame_end <= state->mel.n_len) { - // Get mel spectrogram for this segment - const float * mel_data = state->mel.data.data() + (frame_start * 80); - - // Allocate embedding buffer for this segment - 
float embedding[192]; // ECAPA-TDNN output size (192-dim) - memset(embedding, 0, sizeof(embedding)); - - // Compute speaker embedding - if (whisper_speaker_encoder_compute(state->diarize_encoder, mel_data, embedding)) { - // Store in buffer for later clustering - state->diarize_embeddings.insert( - state->diarize_embeddings.end(), - embedding, - embedding + 192 - ); - } else { - WHISPER_LOG_WARN("failed to compute speaker embedding for segment %lu\n", - state->result_all.size() - 1); - } - } else { - WHISPER_LOG_WARN("invalid mel frame range for segment %lu: [%d, %d), total frames: %d\n", - state->result_all.size() - 1, frame_start, frame_end, - state->mel.n_len); - } + // Compute diarization embeddings for all newly finalized segments. + if (params.diarize && state->diarize_model) { + whisper_collect_diarization_embeddings(state, samples, n_samples, n_new); } #endif // WHISPER_DIARIZE } @@ -7871,20 +7881,14 @@ int whisper_full_with_state( #ifdef WHISPER_DIARIZE // Perform speaker clustering if diarization enabled if (params.diarize && !state->diarize_embeddings.empty()) { - int num_segments = state->result_all.size(); - int num_embeddings = state->diarize_embeddings.size() / 192; + const int num_embeddings = state->diarize_embeddings.size() / 192; + const int num_indexed_segments = state->diarize_segment_indices.size(); - if (num_embeddings != num_segments) { - WHISPER_LOG_ERROR("embedding count (%d) != segment count (%d); skipping clustering\n", - num_embeddings, num_segments); + if (num_embeddings != num_indexed_segments) { + WHISPER_LOG_ERROR("embedding count (%d) != diarized segment count (%d); skipping clustering\n", + num_embeddings, num_indexed_segments); } else { - // Free previous clustering context if it exists - if (state->diarize_clustering) { - whisper_clustering_context_free(state->diarize_clustering); - state->diarize_clustering = nullptr; - } - - state->diarize_clustering = whisper_clustering_context_create(num_segments); + 
state->diarize_clustering = whisper_clustering_context_create(num_embeddings); if (!state->diarize_clustering) { WHISPER_LOG_ERROR("failed to create clustering context\n"); } else { @@ -7897,13 +7901,13 @@ int whisper_full_with_state( ); if (ret == 0) { - // Assign speaker IDs to segments - for (int i = 0; i < num_segments; ++i) { - state->result_all[i].speaker_id = + // Assign speaker IDs to the segments that produced embeddings. + for (int i = 0; i < num_embeddings; ++i) { + state->result_all[state->diarize_segment_indices[i]].speaker_id = state->diarize_clustering->speaker_ids[i]; } - WHISPER_LOG_INFO("diarization complete: %d speakers detected\n", - state->diarize_clustering->num_speakers); + WHISPER_LOG_INFO("diarization complete: %d speakers detected across %d segments\n", + state->diarize_clustering->num_speakers, num_embeddings); } else { WHISPER_LOG_ERROR("clustering failed with code %d\n", ret); } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 69caa521912..4f22937cbfd 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -111,70 +111,103 @@ target_compile_definitions(${VAD_TEST} PRIVATE add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST}) set_tests_properties(${VAD_TEST} PROPERTIES LABELS "base;en") -# Speaker embedding inference test -add_executable(test-speaker-embedding test-speaker-embedding.cpp ../src/whisper-diarize.cpp ../src/whisper-speaker.cpp) -target_include_directories(test-speaker-embedding PRIVATE ../include ../ggml/include) -target_link_libraries(test-speaker-embedding PRIVATE ggml m) - -# Embedding quality (speaker discrimination) -add_executable(test-embedding-quality test-embedding-quality.cpp ../src/whisper-diarize.cpp ../src/whisper-speaker.cpp) -target_include_directories(test-embedding-quality PRIVATE ../include ../ggml/include ../src) -target_link_libraries(test-embedding-quality PRIVATE ggml m) - -# Clustering algorithm test -add_executable(test-clustering test-clustering.cpp ../src/whisper-diarize.cpp 
../src/whisper-speaker.cpp) -target_include_directories(test-clustering PRIVATE ../include ../ggml/include) -target_link_libraries(test-clustering PRIVATE ggml m) -target_compile_options(test-clustering PRIVATE -std=c++11 -Wall -Wextra) - -# Diarization integration test -add_executable(test-diarization-integration test-diarization-integration.cpp) -target_include_directories(test-diarization-integration PRIVATE ../include ../ggml/include) -target_link_libraries(test-diarization-integration PRIVATE whisper) - -# CLI diarization end-to-end test -add_executable(test-cli-diarization test-cli-diarization.cpp) -target_include_directories(test-cli-diarization PRIVATE ../include ../examples) -target_link_libraries(test-cli-diarization PRIVATE whisper common) -target_compile_options(test-cli-diarization PRIVATE -std=c++17) - -# DER benchmark -add_executable(benchmark-der benchmark-der.cpp) -target_include_directories(benchmark-der PRIVATE ../include ../ggml/include) -target_link_libraries(benchmark-der PRIVATE ggml m) -target_compile_options(benchmark-der PRIVATE -std=c++17) - -# AddressSanitizer support -if (WHISPER_SANITIZE_ADDRESS) - message(STATUS "Enabling AddressSanitizer for benchmark-der") - target_compile_options(benchmark-der PRIVATE -fsanitize=address -fno-omit-frame-pointer) - target_link_options(benchmark-der PRIVATE -fsanitize=address) -endif() - -# Register tests -enable_testing() -add_test(NAME speaker-embedding COMMAND test-speaker-embedding WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) -add_test(NAME clustering COMMAND test-clustering WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) -add_test(NAME diarization-integration COMMAND test-diarization-integration WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) -add_test(NAME cli-diarization COMMAND test-cli-diarization WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) +function(whisper_add_optional_test_executable TARGET_NAME SOURCE_FILE) + if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/${SOURCE_FILE}") + 
add_executable(${TARGET_NAME} ${SOURCE_FILE} ${ARGN}) + endif() +endfunction() -# Register benchmark-der test if (WHISPER_DIARIZE) - add_test(NAME benchmark-der COMMAND benchmark-der WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) - set_tests_properties(benchmark-der PROPERTIES LABELS "diarization;benchmark") + whisper_add_optional_test_executable( + test-speaker-embedding + test-speaker-embedding.cpp + ../src/whisper-diarize.cpp + ../src/whisper-speaker.cpp) + if (TARGET test-speaker-embedding) + target_include_directories(test-speaker-embedding PRIVATE ../include ../ggml/include) + target_link_libraries(test-speaker-embedding PRIVATE ggml m) + add_test(NAME speaker-embedding COMMAND test-speaker-embedding WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + endif() - # ASan leak detection (Linux only) - if (WHISPER_SANITIZE_ADDRESS AND CMAKE_SYSTEM_NAME STREQUAL "Linux") - set_tests_properties(benchmark-der PROPERTIES ENVIRONMENT "ASAN_OPTIONS=detect_leaks=1:strip_path_prefix=${CMAKE_SOURCE_DIR}/") + whisper_add_optional_test_executable( + test-embedding-quality + test-embedding-quality.cpp + ../src/whisper-diarize.cpp + ../src/whisper-speaker.cpp) + if (TARGET test-embedding-quality) + target_include_directories(test-embedding-quality PRIVATE ../include ../ggml/include ../src) + target_link_libraries(test-embedding-quality PRIVATE ggml m) + endif() + + whisper_add_optional_test_executable( + test-clustering + test-clustering.cpp + ../src/whisper-diarize.cpp + ../src/whisper-speaker.cpp) + if (TARGET test-clustering) + target_include_directories(test-clustering PRIVATE ../include ../ggml/include) + target_link_libraries(test-clustering PRIVATE ggml m) + target_compile_options(test-clustering PRIVATE -std=c++11 -Wall -Wextra) + add_test(NAME clustering COMMAND test-clustering WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + endif() + + whisper_add_optional_test_executable( + test-diarization-integration + test-diarization-integration.cpp) + if (TARGET 
test-diarization-integration) + target_include_directories(test-diarization-integration PRIVATE ../include ../ggml/include) + target_link_libraries(test-diarization-integration PRIVATE whisper) + add_test(NAME diarization-integration COMMAND test-diarization-integration WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + endif() + + whisper_add_optional_test_executable( + test-cli-diarization + test-cli-diarization.cpp) + if (TARGET test-cli-diarization) + target_include_directories(test-cli-diarization PRIVATE ../include ../examples) + target_link_libraries(test-cli-diarization PRIVATE whisper common) + target_compile_options(test-cli-diarization PRIVATE -std=c++17) + add_test(NAME cli-diarization COMMAND test-cli-diarization WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + endif() + + whisper_add_optional_test_executable( + benchmark-der + benchmark-der.cpp) + if (TARGET benchmark-der) + target_include_directories(benchmark-der PRIVATE ../include ../ggml/include) + target_link_libraries(benchmark-der PRIVATE ggml m) + target_compile_options(benchmark-der PRIVATE -std=c++17) + add_test(NAME benchmark-der COMMAND benchmark-der WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + set_tests_properties(benchmark-der PROPERTIES LABELS "diarization;benchmark") + + if (WHISPER_SANITIZE_ADDRESS) + message(STATUS "Enabling AddressSanitizer for benchmark-der") + target_compile_options(benchmark-der PRIVATE -fsanitize=address -fno-omit-frame-pointer) + target_link_options(benchmark-der PRIVATE -fsanitize=address) + endif() + + if (WHISPER_SANITIZE_ADDRESS AND CMAKE_SYSTEM_NAME STREQUAL "Linux") + set_tests_properties(benchmark-der PROPERTIES ENVIRONMENT "ASAN_OPTIONS=detect_leaks=1:strip_path_prefix=${CMAKE_SOURCE_DIR}/") + endif() endif() -endif() -# Test: Reference embedding comparison (feed PyTorch mel → compare embeddings) -add_executable(test-ref-compare test-ref-compare.cpp ../src/whisper-diarize.cpp ../src/whisper-speaker.cpp) -target_include_directories(test-ref-compare 
PRIVATE ../include ../ggml/include ../src) -target_link_libraries(test-ref-compare PRIVATE ggml m) + whisper_add_optional_test_executable( + test-ref-compare + test-ref-compare.cpp + ../src/whisper-diarize.cpp + ../src/whisper-speaker.cpp) + if (TARGET test-ref-compare) + target_include_directories(test-ref-compare PRIVATE ../include ../ggml/include ../src) + target_link_libraries(test-ref-compare PRIVATE ggml m) + endif() -# Test: Mel feature comparison -add_executable(test-mel-compare test-mel-compare.cpp ../src/whisper-diarize.cpp ../src/whisper-speaker.cpp) -target_include_directories(test-mel-compare PRIVATE ../include ../ggml/include ../src) -target_link_libraries(test-mel-compare PRIVATE ggml m) + whisper_add_optional_test_executable( + test-mel-compare + test-mel-compare.cpp + ../src/whisper-diarize.cpp + ../src/whisper-speaker.cpp) + if (TARGET test-mel-compare) + target_include_directories(test-mel-compare PRIVATE ../include ../ggml/include ../src) + target_link_libraries(test-mel-compare PRIVATE ggml m) + endif() +endif() From 7f92aa025575e69fc51fb677487c4ba45a2b3ef1 Mon Sep 17 00:00:00 2001 From: MKY508 Date: Sat, 4 Apr 2026 02:13:45 +0800 Subject: [PATCH 4/5] whisper : fix speaker encoder OOM on certain audio lengths The ggml context pool could run out of space for some segment lengths where the estimate was a few MB short. Add 10% margin to the allocation. 
--- src/whisper-diarize.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/whisper-diarize.cpp b/src/whisper-diarize.cpp index 2bf34e1180d..44457cf8ea1 100644 --- a/src/whisper-diarize.cpp +++ b/src/whisper-diarize.cpp @@ -383,8 +383,9 @@ struct whisper_speaker_encoder * whisper_speaker_encoder_new( encoder->n_frames = n_frames; encoder->n_mels = 80; // Fixed for ECAPA-TDNN - // Dynamic context size: base 200MB + ~0.5MB per frame for intermediate tensors + // Dynamic context size: base 200MB + ~0.5MB per frame for intermediate tensors + 10% margin size_t ctx_bytes = (size_t)200 * 1024 * 1024 + (size_t)n_frames * 512 * 1024; + ctx_bytes = ctx_bytes + ctx_bytes / 10; struct ggml_init_params params = { ctx_bytes, NULL, From c34ed63976887c704bf0b2296847e6e781a6c07d Mon Sep 17 00:00:00 2001 From: MKY508 Date: Sat, 4 Apr 2026 05:05:17 +0800 Subject: [PATCH 5/5] whisper : improve diarization window assignment --- examples/cli/cli.cpp | 2 +- examples/server/server.cpp | 2 +- src/whisper.cpp | 596 +++++++++++++++++++++++++++++++++---- 3 files changed, 540 insertions(+), 60 deletions(-) diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp index 3ff017708f8..8d2a089a53f 100644 --- a/examples/cli/cli.cpp +++ b/examples/cli/cli.cpp @@ -89,7 +89,7 @@ struct whisper_params { std::string grammar; std::string grammar_rule; - float diarize_threshold = 0.5f; + float diarize_threshold = 0.70f; int diarize_speakers = 0; // [TDRZ] speaker turn string diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 9e50b523cc5..1f2fb2f1a41 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -116,7 +116,7 @@ struct whisper_params { std::string response_format = json_format; - float diarize_threshold = 0.5f; + float diarize_threshold = 0.70f; int diarize_speakers = 0; // [TDRZ] speaker turn string diff --git a/src/whisper.cpp b/src/whisper.cpp index 842326d82fa..3194c0c4523 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ 
// NOTE(review): reconstructed from a patch whose template arguments were
// stripped by markup-eating (e.g. "std::vector diarize_embeddings"); element
// types below are restored from how each container is indexed and filled.
// The patch also adds two parallel fields to whisper_state:
//     std::vector<int> diarize_window_starts;   // window start, in samples
//     std::vector<int> diarize_window_ends;     // window end (exclusive), in samples
// kept in lockstep with diarize_embeddings (192 floats per window) and
// diarize_segment_indices.

#ifdef WHISPER_DIARIZE
// Compute a single 192-dim speaker embedding for
// samples[range_start, range_start + range_n_samples) and append it, together
// with its sample window, to the state's diarization buffers.
// Returns false when the mel or encoder step fails; state is unchanged then.
static bool whisper_compute_embedding_for_range(
        struct whisper_state * state,
        const float * samples,
        int range_start,
        int range_n_samples,
        int segment_index) {
    float * mel_data = whisper_compute_mel_80(samples + range_start, range_n_samples);
    if (mel_data == nullptr) {
        return false;
    }

    const int mel_n_frames = whisper_get_mel_n_frames(range_n_samples);

    struct whisper_speaker_encoder * enc = whisper_speaker_encoder_new(state->diarize_model, mel_n_frames, 0);
    if (enc == nullptr) {
        whisper_mel_free(mel_data);
        return false;
    }

    // NOTE(review): the actual encoder forward pass lives in unchanged context
    // lines the patch does not show; the call below is the presumed shape of
    // that hidden code — confirm against the full source before relying on it.
    float embedding[192];
    const bool ok = whisper_speaker_encoder_compute(enc, mel_data, mel_n_frames, embedding);

    whisper_speaker_encoder_free(enc);
    whisper_mel_free(mel_data);

    if (!ok) {
        return false;
    }

    state->diarize_embeddings.insert(state->diarize_embeddings.end(),
            embedding, embedding + 192);
    state->diarize_segment_indices.push_back(segment_index);
    state->diarize_window_starts.push_back(range_start);
    state->diarize_window_ends.push_back(range_start + range_n_samples);

    return true;
}

// Sub-window size for speaker embedding extraction (2 seconds of 16 kHz audio)
static const int DIARIZE_WINDOW_SAMPLES = 2 * WHISPER_SAMPLE_RATE;
// Hop between consecutive windows (1 second => 50% overlap)
static const int DIARIZE_HOP_SAMPLES = 1 * WHISPER_SAMPLE_RATE;
// Minimum duration of a speaker turn before a split is preserved
// (1.5 s, expressed in whisper's 10 ms timestamp units)
static const int64_t DIARIZE_MIN_TURN_TIME = 150;

// Majority vote over the clustered speaker IDs of the given embedding windows.
// Returns -1 when no vote can be cast; ties break toward the lowest speaker ID
// (std::map iterates keys in ascending order).
static int whisper_diarize_majority_speaker(
        const struct whisper_state * state,
        const std::vector<int> & window_indices) {
    if (state == nullptr || state->diarize_clustering == nullptr || window_indices.empty()) {
        return -1;
    }

    std::map<int, int> speaker_votes;
    for (int window_index : window_indices) {
        if (window_index < 0 || window_index >= state->diarize_clustering->num_segments) {
            continue;
        }

        speaker_votes[state->diarize_clustering->speaker_ids[window_index]]++;
    }

    int best_speaker = -1;
    int best_votes = 0;
    for (std::map<int, int>::const_iterator it = speaker_votes.begin(); it != speaker_votes.end(); ++it) {
        if (it->second > best_votes) {
            best_votes = it->second;
            best_speaker = it->first;
        }
    }

    return best_speaker;
}

// Pick the speaker whose embedding windows overlap the token's time span the
// most (overlap measured in samples). Falls back to fallback_speaker when the
// token has no usable timestamps or no window overlaps it.
static int whisper_diarize_assign_token_speaker(
        const struct whisper_state * state,
        const std::vector<int> & window_indices,
        const struct whisper_token_data & token,
        int fallback_speaker) {
    if (state == nullptr || state->diarize_clustering == nullptr || window_indices.empty()) {
        return fallback_speaker;
    }

    if (token.t0 < 0 || token.t1 < 0 || token.t1 <= token.t0) {
        return fallback_speaker;
    }

    // t0/t1 are in 10 ms units -> convert to sample indices.
    const int token_start = (int) (token.t0 * WHISPER_SAMPLE_RATE / 100);
    const int token_end   = (int) (token.t1 * WHISPER_SAMPLE_RATE / 100);

    std::map<int, int> speaker_overlap;
    for (int window_index : window_indices) {
        if (window_index < 0 || window_index >= state->diarize_clustering->num_segments) {
            continue;
        }

        const int overlap_start = std::max(token_start, state->diarize_window_starts[window_index]);
        const int overlap_end   = std::min(token_end,   state->diarize_window_ends[window_index]);
        if (overlap_end <= overlap_start) {
            continue;
        }

        const int speaker_id = state->diarize_clustering->speaker_ids[window_index];
        speaker_overlap[speaker_id] += overlap_end - overlap_start;
    }

    int best_speaker = fallback_speaker;
    int best_overlap = 0;
    for (std::map<int, int>::const_iterator it = speaker_overlap.begin(); it != speaker_overlap.end(); ++it) {
        if (it->second > best_overlap) {
            best_overlap = it->second;
            best_speaker = it->first;
        }
    }

    return best_speaker;
}

// Cosine distance (1 - cosine similarity) between two dim-length vectors,
// accumulated in double for stability. Degenerate (zero-norm) inputs get the
// neutral distance 1.0.
static float whisper_diarize_cosine_distance(
        const float * lhs,
        const float * rhs,
        int dim) {
    double dot = 0.0;
    double lhs_norm = 0.0;
    double rhs_norm = 0.0;

    for (int i = 0; i < dim; ++i) {
        dot      += (double) lhs[i] * rhs[i];
        lhs_norm += (double) lhs[i] * lhs[i];
        rhs_norm += (double) rhs[i] * rhs[i];
    }

    if (lhs_norm <= 0.0 || rhs_norm <= 0.0) {
        return 1.0f;
    }

    return 1.0f - (float) (dot / (sqrt(lhs_norm) * sqrt(rhs_norm)));
}

// Post-process clustering output: repeatedly fold the smallest cluster (fewer
// than 3 windows) into its nearest cluster by centroid cosine distance, but
// never below 3 clusters total. Finally renumber speaker IDs densely from 0.
// Only called when the user did not fix the speaker count (diarize_speakers <= 0).
static void whisper_diarize_merge_small_clusters(struct whisper_state * state) {
    if (state == nullptr || state->diarize_clustering == nullptr) {
        return;
    }

    const int num_embeddings = state->diarize_clustering->num_segments;
    const int embedding_dim  = 192;
    if (num_embeddings <= 0) {
        return;
    }

    while (true) {
        // Group window indices by current speaker ID.
        std::map<int, std::vector<int> > clusters;
        for (int i = 0; i < num_embeddings; ++i) {
            clusters[state->diarize_clustering->speaker_ids[i]].push_back(i);
        }

        if ((int) clusters.size() <= 3) {
            break;
        }

        // Mean embedding per cluster.
        std::map<int, std::vector<float> > centroids;
        for (std::map<int, std::vector<int> >::const_iterator it = clusters.begin(); it != clusters.end(); ++it) {
            std::vector<float> centroid(embedding_dim, 0.0f);
            for (int window_index : it->second) {
                const float * embedding = state->diarize_embeddings.data() + window_index * embedding_dim;
                for (int j = 0; j < embedding_dim; ++j) {
                    centroid[j] += embedding[j];
                }
            }

            const float inv = 1.0f / it->second.size();
            for (int j = 0; j < embedding_dim; ++j) {
                centroid[j] *= inv;
            }

            centroids[it->first] = std::move(centroid);
        }

        int smallest_speaker = -1;
        size_t smallest_size = std::numeric_limits<size_t>::max();
        for (std::map<int, std::vector<int> >::const_iterator it = clusters.begin(); it != clusters.end(); ++it) {
            if (it->second.size() < smallest_size) {
                smallest_size = it->second.size();
                smallest_speaker = it->first;
            }
        }

        // Stop once every cluster has at least 3 windows.
        if (smallest_speaker < 0 || smallest_size >= 3) {
            break;
        }

        int best_target = -1;
        float best_distance = std::numeric_limits<float>::max();
        for (std::map<int, std::vector<int> >::const_iterator it = clusters.begin(); it != clusters.end(); ++it) {
            if (it->first == smallest_speaker) {
                continue;
            }

            const float distance = whisper_diarize_cosine_distance(
                    centroids[smallest_speaker].data(),
                    centroids[it->first].data(),
                    embedding_dim);
            if (distance < best_distance) {
                best_distance = distance;
                best_target   = it->first;
            }
        }

        if (best_target < 0) {
            break;
        }

        for (int window_index : clusters[smallest_speaker]) {
            state->diarize_clustering->speaker_ids[window_index] = best_target;
        }
    }

    // Renumber surviving speakers densely starting at 0, in first-seen order.
    std::map<int, int> speaker_remap;
    int next_speaker = 0;
    for (int i = 0; i < num_embeddings; ++i) {
        const int speaker_id = state->diarize_clustering->speaker_ids[i];
        if (speaker_remap.find(speaker_id) == speaker_remap.end()) {
            speaker_remap[speaker_id] = next_speaker++;
        }
        state->diarize_clustering->speaker_ids[i] = speaker_remap[speaker_id];
    }
    state->diarize_clustering->num_speakers = next_speaker;
}

// Build a sub-segment of `segment` covering tokens [token_start, token_end]
// (inclusive), assign it the given speaker, recompute its text and its t0/t1
// from token timestamps (clamped into the parent segment), and append it to
// out_segments. Invalid token ranges are ignored.
static void whisper_diarize_append_token_range(
        struct whisper_context & ctx,
        const whisper_segment & segment,
        int token_start,
        int token_end,
        int speaker_id,
        bool speaker_turn_next,
        std::vector<whisper_segment> & out_segments) {
    if (token_start < 0 || token_end < token_start || token_end >= (int) segment.tokens.size()) {
        return;
    }

    whisper_segment split = segment;
    split.tokens.assign(segment.tokens.begin() + token_start, segment.tokens.begin() + token_end + 1);
    split.text.clear();
    split.speaker_id = speaker_id;
    split.speaker_turn_next = speaker_turn_next;

    int64_t split_t0 = segment.t0;
    int64_t split_t1 = segment.t1;

    // First spoken token with a valid start time defines the split start.
    for (int i = 0; i < (int) split.tokens.size(); ++i) {
        const whisper_token_data & token = split.tokens[i];
        if (token.id >= whisper_token_eot(&ctx)) {
            continue;
        }

        if (token.t0 >= 0) {
            split_t0 = token.t0;
            break;
        }
    }

    // Last spoken token whose end does not precede the chosen start defines
    // the split end.
    for (int i = (int) split.tokens.size() - 1; i >= 0; --i) {
        const whisper_token_data & token = split.tokens[i];
        if (token.id >= whisper_token_eot(&ctx)) {
            continue;
        }

        if (token.t1 >= split_t0) {
            split_t1 = token.t1;
            break;
        }
    }

    // Clamp inside the parent segment and keep t1 >= t0.
    split.t0 = std::max(segment.t0, split_t0);
    split.t1 = std::max(split.t0, std::min(segment.t1, split_t1));

    // Rebuild the text from the spoken tokens only.
    for (int i = 0; i < (int) split.tokens.size(); ++i) {
        const whisper_token_data & token = split.tokens[i];
        if (token.id >= whisper_token_eot(&ctx)) {
            continue;
        }

        split.text += whisper_token_to_str(&ctx, token.id);
    }

    out_segments.push_back(std::move(split));
}

// Map clustered window speakers onto transcription segments. Each segment gets
// per-token speaker votes (by window overlap), the votes are smoothed, runs
// shorter than DIARIZE_MIN_TURN_TIME are absorbed into a neighbour, and a
// segment whose tokens end up with multiple speakers is split into one
// sub-segment per speaker run. Rewrites state.result_all in place.
static void whisper_diarize_assign_segments(
        struct whisper_context & ctx,
        struct whisper_state & state) {
    if (state.diarize_clustering == nullptr) {
        return;
    }

    std::vector<whisper_segment> diarized_segments;
    diarized_segments.reserve(state.result_all.size());

    for (int segment_index = 0; segment_index < (int) state.result_all.size(); ++segment_index) {
        whisper_segment & segment = state.result_all[segment_index];

        const int segment_start = std::max(0, (int) (segment.t0 * WHISPER_SAMPLE_RATE / 100));
        const int segment_end   = std::max(segment_start, (int) (segment.t1 * WHISPER_SAMPLE_RATE / 100));

        // All embedding windows that overlap this segment's sample span.
        std::vector<int> window_indices;
        for (int i = 0; i < (int) state.diarize_window_starts.size(); ++i) {
            if (state.diarize_window_ends[i] <= segment_start ||
                state.diarize_window_starts[i] >= segment_end) {
                continue;
            }
            window_indices.push_back(i);
        }

        if (window_indices.empty()) {
            diarized_segments.push_back(segment);
            continue;
        }

        const int fallback_speaker = whisper_diarize_majority_speaker(&state, window_indices);

        // Per-token speaker assignment; special tokens (>= EOT) are skipped.
        std::vector<int> spoken_token_indices;
        std::vector<int> token_speakers;
        spoken_token_indices.reserve(segment.tokens.size());
        token_speakers.reserve(segment.tokens.size());

        for (int token_index = 0; token_index < (int) segment.tokens.size(); ++token_index) {
            const whisper_token_data & token = segment.tokens[token_index];
            if (token.id >= whisper_token_eot(&ctx)) {
                continue;
            }

            spoken_token_indices.push_back(token_index);
            token_speakers.push_back(
                    whisper_diarize_assign_token_speaker(&state, window_indices, token, fallback_speaker));
        }

        if (token_speakers.empty()) {
            segment.speaker_id = fallback_speaker;
            diarized_segments.push_back(segment);
            continue;
        }

        // 3-tap smoothing: a one-token blip between two tokens of the same
        // speaker takes that speaker.
        for (int i = 1; i + 1 < (int) token_speakers.size(); ++i) {
            if (token_speakers[i - 1] == token_speakers[i + 1] &&
                token_speakers[i] != token_speakers[i - 1]) {
                token_speakers[i] = token_speakers[i - 1];
            }
        }

        // Repeatedly absorb speaker runs shorter than DIARIZE_MIN_TURN_TIME
        // into a neighbouring run; restart the scan after every merge.
        bool merged_short_run = true;
        while (merged_short_run && token_speakers.size() > 1) {
            merged_short_run = false;

            int run_start = 0;
            while (run_start < (int) token_speakers.size()) {
                int run_end = run_start + 1;
                while (run_end < (int) token_speakers.size() &&
                       token_speakers[run_end] == token_speakers[run_start]) {
                    ++run_end;
                }

                const whisper_token_data & run_first = segment.tokens[spoken_token_indices[run_start]];
                const whisper_token_data & run_last  = segment.tokens[spoken_token_indices[run_end - 1]];
                const int64_t run_t0 = run_first.t0 >= 0 ? run_first.t0 : segment.t0;
                const int64_t run_t1 = run_last.t1 >= run_t0 ? run_last.t1 : run_t0;

                if (run_t1 - run_t0 < DIARIZE_MIN_TURN_TIME &&
                    (run_start > 0 || run_end < (int) token_speakers.size())) {
                    int replacement_speaker = fallback_speaker;

                    if (run_start == 0) {
                        // Leading run: adopt the following speaker.
                        replacement_speaker = token_speakers[run_end];
                    } else if (run_end == (int) token_speakers.size()) {
                        // Trailing run: adopt the preceding speaker.
                        replacement_speaker = token_speakers[run_start - 1];
                    } else if (token_speakers[run_start - 1] == token_speakers[run_end]) {
                        replacement_speaker = token_speakers[run_start - 1];
                    } else {
                        // Neighbours disagree: side with the longer neighbour run.
                        int prev_start = run_start - 1;
                        while (prev_start > 0 && token_speakers[prev_start - 1] == token_speakers[run_start - 1]) {
                            --prev_start;
                        }

                        int next_end = run_end;
                        while (next_end < (int) token_speakers.size() &&
                               token_speakers[next_end] == token_speakers[run_end]) {
                            ++next_end;
                        }

                        const whisper_token_data & prev_first = segment.tokens[spoken_token_indices[prev_start]];
                        const whisper_token_data & prev_last  = segment.tokens[spoken_token_indices[run_start - 1]];
                        const whisper_token_data & next_first = segment.tokens[spoken_token_indices[run_end]];
                        const whisper_token_data & next_last  = segment.tokens[spoken_token_indices[next_end - 1]];

                        const int64_t prev_t0 = prev_first.t0 >= 0 ? prev_first.t0 : segment.t0;
                        const int64_t prev_t1 = prev_last.t1 >= prev_t0 ? prev_last.t1 : prev_t0;
                        const int64_t next_t0 = next_first.t0 >= 0 ? next_first.t0 : segment.t0;
                        const int64_t next_t1 = next_last.t1 >= next_t0 ? next_last.t1 : next_t0;

                        const int64_t prev_dur = prev_t1 - prev_t0;
                        const int64_t next_dur = next_t1 - next_t0;

                        replacement_speaker =
                            prev_dur >= next_dur ? token_speakers[run_start - 1] : token_speakers[run_end];
                    }

                    for (int i = run_start; i < run_end; ++i) {
                        token_speakers[i] = replacement_speaker;
                    }

                    merged_short_run = true;
                    break; // run boundaries changed — rescan from the start
                }

                run_start = run_end;
            }
        }

        bool uniform_speaker = true;
        for (int i = 1; i < (int) token_speakers.size(); ++i) {
            if (token_speakers[i] != token_speakers[0]) {
                uniform_speaker = false;
                break;
            }
        }

        if (uniform_speaker) {
            segment.speaker_id = token_speakers[0];
            diarized_segments.push_back(segment);
            continue;
        }

        // Emit one sub-segment per speaker run; only the final run inherits
        // the original segment's speaker_turn_next flag.
        int run_start = 0;
        while (run_start < (int) token_speakers.size()) {
            int run_end = run_start + 1;
            while (run_end < (int) token_speakers.size() &&
                   token_speakers[run_end] == token_speakers[run_start]) {
                ++run_end;
            }

            whisper_diarize_append_token_range(
                    ctx,
                    segment,
                    spoken_token_indices[run_start],
                    spoken_token_indices[run_end - 1],
                    token_speakers[run_start],
                    run_end == (int) token_speakers.size() ? segment.speaker_turn_next : false,
                    diarized_segments);

            run_start = run_end;
        }
    }

    state.result_all.swap(diarized_segments);
}

// Compute embeddings for samples[range_start, range_end): short ranges get one
// embedding; longer ranges are cut into 2 s windows with 1 s hop, skipping
// near-silent windows (RMS < 0.01). If every window is skipped or fails, fall
// back to a single whole-range embedding.
static bool whisper_compute_diarization_embedding(
        struct whisper_state * state,
        const float * samples,
        int n_samples,
        int range_start,
        int range_end) {
    if (state == nullptr || state->diarize_model == nullptr || samples == nullptr) {
        return false;
    }

    const int sample_start = std::max(0, range_start);
    const int sample_end   = std::min(range_end, n_samples);
    const int range_n_samples = sample_end - sample_start;
    if (range_n_samples <= 1600) {
        // <= 100 ms at 16 kHz — too short for a meaningful embedding.
        return false;
    }

    if (range_n_samples <= DIARIZE_WINDOW_SAMPLES) {
        return whisper_compute_embedding_for_range(state, samples, sample_start, range_n_samples, -1);
    }

    std::vector<int> offsets;
    for (int offset = 0; offset + DIARIZE_WINDOW_SAMPLES <= range_n_samples; offset += DIARIZE_HOP_SAMPLES) {
        offsets.push_back(offset);
    }

    // Extra tail window so the end of the range is always covered.
    const int tail_offset = range_n_samples - DIARIZE_WINDOW_SAMPLES;
    if (tail_offset >= 0 && (offsets.empty() || tail_offset > offsets.back())) {
        offsets.push_back(tail_offset);
    }

    bool any_ok = false;
    for (int i = 0; i < (int) offsets.size(); ++i) {
        const int offset = offsets[i];
        const float * win = samples + sample_start + offset;

        // RMS silence gate.
        float energy = 0.0f;
        for (int j = 0; j < DIARIZE_WINDOW_SAMPLES; j++) {
            energy += win[j] * win[j];
        }
        energy = sqrtf(energy / DIARIZE_WINDOW_SAMPLES);
        if (energy < 0.01f) {
            continue;
        }

        if (whisper_compute_embedding_for_range(state, samples, sample_start + offset, DIARIZE_WINDOW_SAMPLES, -1)) {
            any_ok = true;
        }
    }

    if (!any_ok) {
        return whisper_compute_embedding_for_range(state, samples, sample_start, range_n_samples, -1);
    }

    return any_ok;
}

// Rebuild all diarization embeddings after transcription: merge the finalized
// segments' sample spans into contiguous regions (ranges shorter than ~100 ms
// are dropped; touching/overlapping spans coalesce), then window each region.
static void whisper_collect_diarization_embeddings(
        struct whisper_state * state,
        const float * samples,
        int n_samples) {
    if (state == nullptr || state->diarize_model == nullptr || state->result_all.empty()) {
        return;
    }

    // Gap (in samples) allowed between segments that still merge into one
    // region; 0 means only touching/overlapping segments merge.
    static const int DIARIZE_REGION_GAP_SAMPLES = 0;

    state->diarize_embeddings.clear();
    state->diarize_segment_indices.clear();
    state->diarize_window_starts.clear();
    state->diarize_window_ends.clear();

    std::vector<std::pair<int, int> > regions;
    regions.reserve(state->result_all.size());

    for (int i = 0; i < (int) state->result_all.size(); ++i) {
        const whisper_segment & segment = state->result_all[i];
        const int sample_start = std::max(0, (int) (segment.t0 * WHISPER_SAMPLE_RATE / 100));
        const int sample_end   = std::min(n_samples, (int) (segment.t1 * WHISPER_SAMPLE_RATE / 100));

        if (sample_end - sample_start <= 1600) {
            continue;
        }

        if (!regions.empty() && sample_start <= regions.back().second + DIARIZE_REGION_GAP_SAMPLES) {
            regions.back().second = std::max(regions.back().second, sample_end);
        } else {
            regions.push_back(std::make_pair(sample_start, sample_end));
        }
    }

    for (int i = 0; i < (int) regions.size(); ++i) {
        whisper_compute_diarization_embedding(
                state,
                samples,
                n_samples,
                regions[i].first,
                regions[i].second);
    }
}
#endif // WHISPER_DIARIZE
- if (params.diarize && state->diarize_model) { - whisper_collect_diarization_embeddings(state, samples, n_samples, n_new); - } -#endif // WHISPER_DIARIZE } text = ""; while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { @@ -7825,7 +8304,7 @@ int whisper_full_with_state( int n_new = 1; - if (params.token_timestamps) { + if (params.token_timestamps || params.diarize) { whisper_exp_compute_token_level_timestamps( *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum); @@ -7837,12 +8316,6 @@ int whisper_full_with_state( params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data); } -#ifdef WHISPER_DIARIZE - // Compute diarization embeddings for all newly finalized segments. - if (params.diarize && state->diarize_model) { - whisper_collect_diarization_embeddings(state, samples, n_samples, n_new); - } -#endif // WHISPER_DIARIZE } } @@ -7880,13 +8353,21 @@ int whisper_full_with_state( #ifdef WHISPER_DIARIZE // Perform speaker clustering if diarization enabled + if (params.diarize && state->diarize_model) { + whisper_collect_diarization_embeddings(state, samples, n_samples); + } + if (params.diarize && !state->diarize_embeddings.empty()) { const int num_embeddings = state->diarize_embeddings.size() / 192; const int num_indexed_segments = state->diarize_segment_indices.size(); - - if (num_embeddings != num_indexed_segments) { - WHISPER_LOG_ERROR("embedding count (%d) != diarized segment count (%d); skipping clustering\n", - num_embeddings, num_indexed_segments); + const int num_window_starts = state->diarize_window_starts.size(); + const int num_window_ends = state->diarize_window_ends.size(); + + if (num_embeddings != num_indexed_segments || + num_embeddings != num_window_starts || + num_embeddings != num_window_ends) { + WHISPER_LOG_ERROR("embedding metadata mismatch (embeddings=%d, segment_indices=%d, starts=%d, ends=%d); skipping clustering\n", + num_embeddings, num_indexed_segments, 
num_window_starts, num_window_ends); } else { state->diarize_clustering = whisper_clustering_context_create(num_embeddings); if (!state->diarize_clustering) { @@ -7901,13 +8382,12 @@ int whisper_full_with_state( ); if (ret == 0) { - // Assign speaker IDs to the segments that produced embeddings. - for (int i = 0; i < num_embeddings; ++i) { - state->result_all[state->diarize_segment_indices[i]].speaker_id = - state->diarize_clustering->speaker_ids[i]; + if (params.diarize_speakers <= 0) { + whisper_diarize_merge_small_clusters(state); } + whisper_diarize_assign_segments(*ctx, *state); WHISPER_LOG_INFO("diarization complete: %d speakers detected across %d segments\n", - state->diarize_clustering->num_speakers, num_embeddings); + state->diarize_clustering->num_speakers, (int) state->result_all.size()); } else { WHISPER_LOG_ERROR("clustering failed with code %d\n", ret); }