diff --git a/.gitignore b/.gitignore index 02553cfa..e3dc9717 100644 --- a/.gitignore +++ b/.gitignore @@ -60,6 +60,7 @@ manifest_and_icons manifest_and_icons.zip tests/database_test.db-wal internal/web/dist +web/dist # Development files diff --git a/cmd/server/main.go b/cmd/server/main.go index acdd3beb..06375602 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -239,6 +239,8 @@ func registerAdapters(cfg *config.Config) { adapters.NewVoxtralAdapter(voxtralEnvPath)) registry.RegisterTranscriptionAdapter("openai_whisper", adapters.NewOpenAIAdapter(cfg.OpenAIAPIKey)) + registry.RegisterTranscriptionAdapter("whisper_api", + adapters.NewWhisperAPIAdapter(cfg.WhisperAPIURL, cfg.WhisperAPIKey)) // Register diarization adapters registry.RegisterDiarizationAdapter("pyannote", diff --git a/internal/config/config.go b/internal/config/config.go index 72b99c3d..cc4250f5 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -41,6 +41,10 @@ type Config struct { // Hugging Face configuration HFToken string + + // Whisper API configuration + WhisperAPIURL string + WhisperAPIKey string } // Load loads configuration from environment variables and .env file @@ -70,6 +74,8 @@ func Load() *Config { SecureCookies: getEnv("SECURE_COOKIES", defaultSecure) == "true", OpenAIAPIKey: getEnv("OPENAI_API_KEY", ""), HFToken: getEnv("HF_TOKEN", ""), + WhisperAPIURL: getEnv("WHISPER_API_URL", ""), + WhisperAPIKey: getEnv("WHISPER_API_KEY", ""), } } diff --git a/internal/models/transcription.go b/internal/models/transcription.go index 632c6c3b..e7171373 100644 --- a/internal/models/transcription.go +++ b/internal/models/transcription.go @@ -129,6 +129,9 @@ type WhisperXParams struct { // OpenAI settings APIKey *string `json:"api_key,omitempty" gorm:"type:text"` + // External API Settings + APIURL *string `json:"api_url,omitempty" gorm:"type:text"` + // Voxtral settings MaxNewTokens *int `json:"max_new_tokens,omitempty" gorm:"type:int"` } diff --git a/internal/transcription/adapters/whisper_api_adapter.go b/internal/transcription/adapters/whisper_api_adapter.go new file mode 100644 index 00000000..15bc5b0f --- /dev/null +++ b/internal/transcription/adapters/whisper_api_adapter.go @@ -0,0 +1,299 @@ +package adapters + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "os" + "path/filepath" + "time" + + "scriberr/internal/transcription/interfaces" + "scriberr/pkg/logger" +) + +type WhisperAPIAdapter struct { + *BaseAdapter + apiURL string // global default from server config + apiKey string // global default from server config +} + +// NewWhisperAPIAdapter creates a new Whisper API adapter; globalURL/globalKey are server +// defaults overridden by per-job params when provided. +func NewWhisperAPIAdapter(globalURL, globalKey string) *WhisperAPIAdapter { + capabilities := interfaces.ModelCapabilities{ + ModelID: "whisper_api", + ModelFamily: "whisper_api", + DisplayName: "External Whisper API", + Description: "External Whisper API compatible with OpenAI's /v1/audio/transcriptions format", + Version: "1.0", + SupportedLanguages: []string{ + "en", "zh", "de", "es", "ru", "ko", "fr", "ja", "pt", "tr", "pl", "ca", "nl", + "ar", "sv", "it", "id", "hi", "fi", "vi", "he", "uk", "el", "ms", "cs", "ro", + "da", "hu", "ta", "no", "th", "ur", "hr", "bg", "lt", "la", "mi", "ml", "cy", + "sk", "te", "fa", "lv", "bn", "sr", "az", "sl", "kn", "et", "mk", "br", "eu", + "is", "hy", "ne", "mn", "bs", "kk", "sq", "sw", "gl", "mr", "pa", "si", "km", + "sn", "yo", "so", "af", "oc", "ka", "be", "tg", "sd", "gu", "am", "yi", "lo", + "uz", "fo", "ht", "ps", "tk", "nn", "mt", "sa", "lb", "my", "bo", "tl", "mg", + "as", "tt", "haw", "ln", "ha", "ba", "jw", "su", "auto", + }, + SupportedFormats: []string{"flac", "mp3", "mp4", "mpeg", "mpga", "m4a", "ogg", "wav", "webm"}, + RequiresGPU: false, + MemoryRequirement: 0, + Features: map[string]bool{ + "timestamps": true, + "word_level": true, + "diarization": true, + "translation": true, + "language_detection": true, + "vad": true, + }, + Metadata: map[string]string{ + "provider": "external_whisper_api", + }, + } + + schema := []interfaces.ParameterSchema{ + { + Name: "api_url", + Type: "string", + Required: true, + Description: "External Whisper API URL (e.g. http://localhost:8000/v1/audio/transcriptions)", + Group: "authentication", + }, + { + Name: "api_key", + Type: "string", + Required: false, + Description: "API Key if required by the external API", + Group: "authentication", + }, + { + Name: "model", + Type: "string", + Required: false, + Default: "whisper-1", + Options: []string{"whisper-1", "large-v3", "large-v2", "large-v1", "medium", "small", "base", "tiny"}, + Description: "Model name/ID to use", + Group: "basic", + }, + { + Name: "language", + Type: "string", + Required: false, + Description: "Language of the input audio (ISO-639-1)", + Group: "basic", + }, + } + + baseAdapter := NewBaseAdapter("whisper_api", "", capabilities, schema) + + return &WhisperAPIAdapter{ + BaseAdapter: baseAdapter, + apiURL: globalURL, + apiKey: globalKey, + } +} + +func (a *WhisperAPIAdapter) GetSupportedModels() []string { + return []string{"whisper-1"} +} + +func (a *WhisperAPIAdapter) PrepareEnvironment(ctx context.Context) error { + a.initialized = true + return nil +} + +func (a *WhisperAPIAdapter) GetEstimatedProcessingTime(input interfaces.AudioInput) time.Duration { + audioDuration := input.Duration + if audioDuration == 0 { + return 30 * time.Second + } + return time.Duration(float64(audioDuration) * 0.15) +} + +// Transcribe processes audio using the external Whisper API +// +//nolint:gocyclo // API interaction involves many steps +func (a *WhisperAPIAdapter) Transcribe(ctx context.Context, input interfaces.AudioInput, params map[string]interface{}, procCtx interfaces.ProcessingContext) (*interfaces.TranscriptResult, error) { + startTime := time.Now() + a.LogProcessingStart(input, procCtx) + defer func() { + a.LogProcessingEnd(procCtx, time.Since(startTime), nil) + }() + + writeLog := func(format string, args ...interface{}) { + logPath := filepath.Join(procCtx.OutputDirectory, "transcription.log") + f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + logger.Error("Failed to open log file", "path", logPath, "error", err) + return + } + defer f.Close() + + msg := fmt.Sprintf(format, args...) + timestamp := time.Now().Format("2006-01-02 15:04:05") + fmt.Fprintf(f, "[%s] %s\n", timestamp, msg) + } + + writeLog("Starting external Whisper API transcription for job %s", procCtx.JobID) + writeLog("Input file: %s", input.FilePath) + + if err := a.ValidateAudioInput(input); err != nil { + writeLog("Error: Invalid audio input: %v", err) + return nil, fmt.Errorf("invalid audio input: %w", err) + } + + // Apply fallback: global config → per-job param → error + apiUrl := a.apiURL + if jobURL := a.GetStringParameter(params, "api_url"); jobURL != "" { + apiUrl = jobURL + } + if apiUrl == "" { + writeLog("Error: api_url is required but not provided (set WHISPER_API_URL or provide api_url in job params)") + return nil, fmt.Errorf("api_url is required but not provided") + } + + apiKey := a.apiKey + if jobKey := a.GetStringParameter(params, "api_key"); jobKey != "" { + apiKey = jobKey + } + + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + file, err := os.Open(input.FilePath) + if err != nil { + writeLog("Error: Failed to open audio file: %v", err) + return nil, fmt.Errorf("failed to open audio file: %w", err) + } + defer file.Close() + + part, err := writer.CreateFormFile("file", filepath.Base(input.FilePath)) + if err != nil { + writeLog("Error: Failed to create form file: %v", err) + return nil, fmt.Errorf("failed to create form file: %w", err) + } + if _, err := io.Copy(part, file); err != nil { + writeLog("Error: Failed to copy file content: %v", err) + return nil, fmt.Errorf("failed to copy file content: %w", err) + } + + model := a.GetStringParameter(params, "model") + if model == "" { + model = "whisper-1" + } + writeLog("Model: %s", model) + _ = writer.WriteField("model", model) + + // Request timestamps + _ = writer.WriteField("response_format", "verbose_json") + _ = writer.WriteField("timestamp_granularities[]", "word") + _ = writer.WriteField("timestamp_granularities[]", "segment") + + if lang := a.GetStringParameter(params, "language"); lang != "" { + writeLog("Language: %s", lang) + _ = writer.WriteField("language", lang) + } + + if err := writer.Close(); err != nil { + writeLog("Error: Failed to close multipart writer: %v", err) + return nil, fmt.Errorf("failed to close multipart writer: %w", err) + } + + writeLog("Sending request to %s...", apiUrl) + req, err := http.NewRequestWithContext(ctx, "POST", apiUrl, body) + if err != nil { + writeLog("Error: Failed to create request: %v", err) + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", writer.FormDataContentType()) + if apiKey != "" { + req.Header.Set("Authorization", "Bearer "+apiKey) + } + + client := &http.Client{ + Timeout: 30 * time.Minute, // Generous timeout for large files + } + resp, err := client.Do(req) + if err != nil { + writeLog("Error: Request failed: %v", err) + return nil, fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + writeLog("Error: API error (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(respBody)) + } + + writeLog("Response received. Parsing...") + + var apiResponse struct { + Language string `json:"language"` + Duration float64 `json:"duration"` + Text string `json:"text"` + Segments []struct { + Start float64 `json:"start"` + End float64 `json:"end"` + Text string `json:"text"` + } `json:"segments"` + Words []struct { + Word string `json:"word"` + Start float64 `json:"start"` + End float64 `json:"end"` + } `json:"words"` + } + + if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil { + writeLog("Error: Failed to decode response: %v", err) + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + writeLog("Transcription completed successfully.") + + result := &interfaces.TranscriptResult{ + Language: apiResponse.Language, + Text: apiResponse.Text, + Segments: make([]interfaces.TranscriptSegment, len(apiResponse.Segments)), + WordSegments: make([]interfaces.TranscriptWord, len(apiResponse.Words)), + ProcessingTime: time.Since(startTime), + ModelUsed: model, + Metadata: a.CreateDefaultMetadata(params), + } + + if len(apiResponse.Segments) > 0 { + for i, seg := range apiResponse.Segments { + result.Segments[i] = interfaces.TranscriptSegment{ + Start: seg.Start, + End: seg.End, + Text: seg.Text, + } + } + } else if apiResponse.Text != "" { + // Fallback if segments aren't present + result.Segments = []interfaces.TranscriptSegment{ + { + Start: 0, + End: apiResponse.Duration, + Text: apiResponse.Text, + }, + } + } + + for i, word := range apiResponse.Words { + result.WordSegments[i] = interfaces.TranscriptWord{ + Word: word.Word, + Start: word.Start, + End: word.End, + } + } + + return result, nil +} diff --git a/internal/transcription/adapters_test.go b/internal/transcription/adapters_test.go index ca09ca87..946d6617 100644 --- a/internal/transcription/adapters_test.go +++ b/internal/transcription/adapters_test.go @@ -531,6 +531,94 @@ func TestParameterConversion(t *testing.T) { } } +func TestWhisperAPIAdapter(t *testing.T) { + reg := registry.GetRegistry() + registry.RegisterTranscriptionAdapter("whisper_api", adapters.NewWhisperAPIAdapter("http://localhost:9000/v1/audio/transcriptions", "test-key")) + + adapter, err := reg.GetTranscriptionAdapter("whisper_api") + if err != nil { + t.Fatalf("Failed to get WhisperAPI adapter: %v", err) + } + + capabilities := adapter.GetCapabilities() + if capabilities.ModelFamily != "whisper_api" { + t.Errorf("Expected model family 'whisper_api', got '%s'", capabilities.ModelFamily) + } + if !capabilities.Features["timestamps"] { + t.Error("Expected timestamps feature") + } + if capabilities.RequiresGPU { + t.Error("Expected RequiresGPU to be false") + } + + schema := adapter.GetParameterSchema() + if len(schema) == 0 { + t.Error("Expected non-empty parameter schema") + } + + // api_url must be present in schema + var hasAPIURL bool + for _, p := range schema { + if p.Name == "api_url" { + hasAPIURL = true + break + } + } + if !hasAPIURL { + t.Error("Expected api_url in parameter schema") + } + + validParams := map[string]interface{}{ + "api_url": "http://localhost:9000/v1/audio/transcriptions", + "model": "large-v3", + } + if err := adapter.ValidateParameters(validParams); err != nil { + t.Errorf("Valid parameters failed validation: %v", err) + } +} + +func TestWhisperAPIParamConversion(t *testing.T) { + mockRepo := new(MockJobRepository) + service := NewUnifiedTranscriptionService(mockRepo, "data/temp", "data/transcripts") + + apiURL := "http://localhost:9000/v1/audio/transcriptions" + apiKey := "secret" + lang := "fr" + + params := models.WhisperXParams{ + Model: "large-v3", + APIURL: &apiURL, + APIKey: &apiKey, + Language: &lang, + } + + m := service.convertToWhisperAPIParams(params) + + if m["model"] != "large-v3" { + t.Errorf("Expected model 'large-v3', got '%v'", m["model"]) + } + if m["api_url"] != apiURL { + t.Errorf("Expected api_url '%s', got '%v'", apiURL, m["api_url"]) + } + if m["api_key"] != apiKey { + t.Errorf("Expected api_key '%s', got '%v'", apiKey, m["api_key"]) + } + if m["language"] != lang { + t.Errorf("Expected language '%s', got '%v'", lang, m["language"]) + } + + // Empty optional fields should not be present + emptyURL := "" + paramsNoURL := models.WhisperXParams{Model: "whisper-1", APIURL: &emptyURL} + m2 := service.convertToWhisperAPIParams(paramsNoURL) + if _, ok := m2["api_url"]; ok { + t.Error("Empty api_url should not appear in param map") + } + if _, ok := m2["language"]; ok { + t.Error("Nil language should not appear in param map") + } +} + // Helper functions func stringPtr(s string) *string { return &s diff --git a/internal/transcription/unified_service.go b/internal/transcription/unified_service.go index e17ef8b4..01e91199 100644 --- a/internal/transcription/unified_service.go +++ b/internal/transcription/unified_service.go @@ -29,12 +29,14 @@ const ( ModelSortformer = "sortformer" ModelOpenAI = "openai_whisper" ModelVoxtral = "voxtral" + ModelWhisperAPI = "whisper_api" ModelDiarization31 = "pyannote/speaker-diarization-3.1" FamilyNvidiaCanary = "nvidia_canary" FamilyNvidiaParakeet = "nvidia_parakeet" FamilyWhisper = "whisper" FamilyOpenAI = "openai" FamilyMistralVoxtral = "mistral_voxtral" + FamilyWhisperAPI = "whisper_api" DiarizeSortformer = "nvidia_sortformer" OutputFormatJSON = "json" ) @@ -387,6 +389,8 @@ func (u *UnifiedTranscriptionService) selectModels(params models.WhisperXParams) transcriptionModelID = ModelOpenAI case FamilyMistralVoxtral: transcriptionModelID = ModelVoxtral + case FamilyWhisperAPI: + transcriptionModelID = ModelWhisperAPI default: transcriptionModelID = ModelWhisperX // Default fallback } @@ -565,12 +569,35 @@ func (u *UnifiedTranscriptionService) convertParametersForModel(params models.Wh return u.convertToOpenAIParams(params) case ModelVoxtral: return u.convertToVoxtralParams(params) + case ModelWhisperAPI: + return u.convertToWhisperAPIParams(params) default: // Fallback to legacy conversion return u.parametersToMap(params) } } +// convertToWhisperAPIParams converts to external Whisper API parameters +func (u *UnifiedTranscriptionService) convertToWhisperAPIParams(params models.WhisperXParams) map[string]interface{} { + paramMap := map[string]interface{}{ + "model": params.Model, + } + + if params.Language != nil { + paramMap["language"] = *params.Language + } + + if params.APIURL != nil && *params.APIURL != "" { + paramMap["api_url"] = *params.APIURL + } + + if params.APIKey != nil && *params.APIKey != "" { + paramMap["api_key"] = *params.APIKey + } + + return paramMap +} + // convertToOpenAIParams converts to OpenAI-specific parameters func (u *UnifiedTranscriptionService) convertToOpenAIParams(params models.WhisperXParams) map[string]interface{} { paramMap := map[string]interface{}{ diff --git a/tests/adapter_registration_test.go b/tests/adapter_registration_test.go index 3b9203c7..d29f0368 100644 --- a/tests/adapter_registration_test.go +++ b/tests/adapter_registration_test.go @@ -45,6 +45,12 @@ func TestAdapterEnvPathInjection(t *testing.T) { if pyannote == nil { t.Fatal("NewPyAnnoteAdapter returned nil") } + + // Test WhisperAPI adapter + whisperAPI := adapters.NewWhisperAPIAdapter("http://localhost:9000/v1/audio/transcriptions", "") + if whisperAPI == nil { + t.Fatal("NewWhisperAPIAdapter returned nil") + } } // TestRegisterAdapters tests that registerAdapters correctly registers all adapters @@ -67,6 +73,8 @@ func TestRegisterAdapters(t *testing.T) { adapters.NewParakeetAdapter(nvidiaEnvPath)) registry.RegisterTranscriptionAdapter("canary", adapters.NewCanaryAdapter(nvidiaEnvPath)) + registry.RegisterTranscriptionAdapter("whisper_api", + adapters.NewWhisperAPIAdapter("", "")) registry.RegisterDiarizationAdapter("pyannote", adapters.NewPyAnnoteAdapter(nvidiaEnvPath)) @@ -75,8 +83,8 @@ func TestRegisterAdapters(t *testing.T) { // Verify registrations transcriptionAdapters := registry.GetTranscriptionAdapters() - if len(transcriptionAdapters) != 3 { - t.Errorf("Expected 3 transcription adapters, got %d", len(transcriptionAdapters)) + if len(transcriptionAdapters) != 4 { + t.Errorf("Expected 4 transcription adapters, got %d", len(transcriptionAdapters)) } // Check specific adapters are registered @@ -89,6 +97,9 @@ func TestRegisterAdapters(t *testing.T) { if _, exists := transcriptionAdapters["canary"]; !exists { t.Error("canary adapter not registered") } + if _, exists := transcriptionAdapters["whisper_api"]; !exists { + t.Error("whisper_api adapter not registered") + } diarizationAdapters := registry.GetDiarizationAdapters() if len(diarizationAdapters) != 2 { @@ -132,6 +143,11 @@ func TestAdaptersUseConfigPaths(t *testing.T) { if pyannote == nil { t.Error("PyAnnote adapter should accept custom path") } + + whisperAPI := adapters.NewWhisperAPIAdapter("http://custom.host/v1/audio/transcriptions", "key") + if whisperAPI == nil { + t.Error("WhisperAPI adapter should accept custom URL and key") + } } // TestClearRegistry tests the registry clear function diff --git a/web/frontend/src/components/transcription/TranscriptionConfigDialog.tsx b/web/frontend/src/components/transcription/TranscriptionConfigDialog.tsx index 01f5da56..9a80b235 100644 --- a/web/frontend/src/components/transcription/TranscriptionConfigDialog.tsx +++ b/web/frontend/src/components/transcription/TranscriptionConfigDialog.tsx @@ -84,6 +84,7 @@ export interface WhisperXParams { attention_context_right: number; is_multi_track_enabled: boolean; api_key?: string; + api_url?: string; max_new_tokens?: number; } @@ -427,6 +428,9 @@ export const TranscriptionConfigDialog = memo(function TranscriptionConfigDialog OpenAI + + External Whisper API + @@ -481,6 +485,14 @@ export const TranscriptionConfigDialog = memo(function TranscriptionConfigDialog updateParam={updateParam} /> )} + + {params.model_family === "whisper_api" && ( + + )} {/* Footer */} @@ -1130,6 +1142,159 @@ function OpenAIConfig({ ); } +function WhisperAPIConfig({ params, updateParam, isMultiTrack }: ConfigProps) { + return ( +
+
+
+ + updateParam('api_url', e.target.value)} + className={inputClassName} + /> + + + + updateParam('api_key', e.target.value)} + className={inputClassName} + /> + + + + updateParam('model', e.target.value)} + className={inputClassName} + /> + + + + + +
+
+ + {!isMultiTrack && ( +
+
+
+ updateParam('diarize', v)} + /> + +
+ + {params.diarize && ( +
+ + + + +
+ + updateParam('min_speakers', e.target.value ? parseInt(e.target.value) : undefined)} + className={inputClassName} + /> + + + updateParam('max_speakers', e.target.value ? parseInt(e.target.value) : undefined)} + className={inputClassName} + /> + +
+ + {params.diarize_model === "pyannote" && ( + <> + + updateParam('hf_token', e.target.value || undefined)} + className={inputClassName} + /> + + +
+

Voice Detection Tuning (for noisy/distant audio)

+
+ + updateParam('vad_onset', parseFloat(e.target.value) || 0.5)} + className={inputClassName} + /> + + + updateParam('vad_offset', parseFloat(e.target.value) || 0.363)} + className={inputClassName} + /> + +
+
+ + )} +
+ )} +
+
+ )} +
+ ); +} + function VoxtralConfig({ params, updateParam }: ConfigProps) { return (