Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 57 additions & 2 deletions packages/transformers/src/models/whisper/modeling_whisper.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { cat, mean, Tensor, stack, std_mean } from '../../utils/tensor.js';
import { PreTrainedModel } from '../modeling_utils.js';
import { PreTrainedModel, encoder_forward, decoder_forward } from '../modeling_utils.js';
import { WhisperGenerationConfig } from './generation_whisper.js';
import { whisper_language_to_code } from './common_whisper.js';
import {
Expand Down Expand Up @@ -39,6 +39,50 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
);
}

/**
* Detects the language of the input audio by running a single forward pass.
* Feeds `<|startoftranscript|>` as the only decoder input and examines
* the logits at the next token position, masked to language tokens only.
* @param {Tensor} input_features The log-mel spectrogram input.
* @param {WhisperGenerationConfig} generation_config The generation config containing `lang_to_id` and `decoder_start_token_id`.
* @returns {Promise<number>} The detected language token ID.
*/
async detect_language(input_features, generation_config) {
// 1. Encode audio
const encoder_outputs = (await encoder_forward(this, { input_features })).last_hidden_state;

// 2. Prepare decoder input: just the <|startoftranscript|> token
const sot_token = generation_config.decoder_start_token_id;
const decoder_input_ids = new Tensor('int64', BigInt64Array.from([BigInt(sot_token)]), [1, 1]);

// 3. Run decoder forward pass
const decoder_outputs = await decoder_forward(
this,
{
input_ids: decoder_input_ids,
encoder_hidden_states: encoder_outputs,
},
true,
);

// 4. Get logits at the last (only) position and convert to float32
const logits_data = decoder_outputs.logits[0][0].to('float32').data;

// 5. Mask non-language tokens to -Infinity, then argmax
const lang_token_ids = new Set(Object.values(generation_config.lang_to_id));
let max_score = -Infinity;
let detected_token_id = -1;
for (let i = 0; i < logits_data.length; i++) {
if (!lang_token_ids.has(i)) continue;
if (logits_data[i] > max_score) {
max_score = logits_data[i];
detected_token_id = i;
}
}

return detected_token_id;
}

/**
*
* @param {WhisperGenerationConfig} generation_config
Expand All @@ -56,7 +100,6 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
const task = generation_config.task;
if (generation_config.is_multilingual) {
if (!language) {
// TODO: Implement language detection
logger.warn('No language specified - defaulting to English (en).');
language = 'en';
}
Expand Down Expand Up @@ -116,6 +159,18 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
}) {
generation_config = this._prepare_generation_config(generation_config, kwargs);

// Auto-detect language if not specified on multilingual models
if (generation_config.is_multilingual && !generation_config.language && inputs) {
const detected_token_id = await this.detect_language(inputs, generation_config);
// Reverse lookup: token_id -> language token string -> language code
const id_to_lang = Object.fromEntries(Object.entries(generation_config.lang_to_id).map(([k, v]) => [v, k]));
const language_token = id_to_lang[detected_token_id]; // e.g., "<|ko|>"
if (language_token) {
generation_config.language = language_token.slice(2, -2); // e.g., "ko"
logger.info(`Detected language: ${generation_config.language}`);
}
}

const init_tokens = kwargs.decoder_input_ids ?? this._retrieve_init_tokens(generation_config);

if (generation_config.return_timestamps) {
Expand Down