Skip to content

Commit 81da609

Browse files
committed
feat(proxy): adapt VoxCPM clone through TTS API
1 parent 8492cba commit 81da609

2 files changed

Lines changed: 264 additions & 1 deletion

File tree

internal/inferencehttp/routes.go

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"io"
99
"log/slog"
1010
"mime"
11+
"mime/multipart"
1112
"net/http"
1213
"net/http/httputil"
1314
"net/url"
@@ -19,6 +20,7 @@ const maxTTSRequestBody = 16 << 20
1920
const (
2021
adapterLiteTTSHTTP = "litetts_http"
2122
adapterMooERGRPC = "mooer_grpc"
23+
adapterVoxCPMClone = "voxcpm_clone"
2224
)
2325

2426
// RegisterRoutes returns a function that registers AIMA inference proxy routes.
@@ -74,10 +76,15 @@ func (d *Deps) handleTTS(w http.ResponseWriter, r *http.Request) {
7476
return
7577
}
7678

77-
if d.adapterFor(model, r.URL.Path) == adapterLiteTTSHTTP {
79+
adapter := d.adapterFor(model, r.URL.Path)
80+
if adapter == adapterLiteTTSHTTP {
7881
d.handleLiteTTS(w, r, backend, raw)
7982
return
8083
}
84+
if adapter == adapterVoxCPMClone && hasTTSReferenceAudio(raw) {
85+
d.handleVoxCPMClone(w, r, backend, raw, body)
86+
return
87+
}
8188

8289
switch r.URL.Path {
8390
case "/v1/tts":
@@ -504,6 +511,122 @@ func (d *Deps) forwardTTSJSON(w http.ResponseWriter, r *http.Request, backend *B
504511
writeBackendResponse(w, resp, respBody)
505512
}
506513

514+
func (d *Deps) handleVoxCPMClone(w http.ResponseWriter, r *http.Request, backend *Backend, raw map[string]any, requestBody []byte) {
515+
body, contentType, err := buildVoxCPMCloneRequest(raw)
516+
if err != nil {
517+
http.Error(w, err.Error(), http.StatusBadRequest)
518+
return
519+
}
520+
521+
resp, respBody, err := d.callBackend(r, backend.Address, "/v1/clone", contentType, body)
522+
if err != nil {
523+
slog.Warn("aima proxy: VoxCPM clone backend request failed", "backend", backend.Address, "err", err)
524+
http.Error(w, "backend unreachable", http.StatusBadGateway)
525+
return
526+
}
527+
528+
if r.URL.Path == "/v1/tts" && resp.StatusCode >= 200 && resp.StatusCode < 300 && isAudioContent(resp.Header.Get("Content-Type")) {
529+
writeAudioJSON(w, respBody, requestBody, resp.Header.Get("Content-Type"), resp.StatusCode)
530+
return
531+
}
532+
if r.URL.Path == "/v1/audio/speech" && resp.StatusCode >= 200 && resp.StatusCode < 300 && writeAudioFromJSON(w, respBody, requestBody, resp.StatusCode) {
533+
return
534+
}
535+
writeBackendResponse(w, resp, respBody)
536+
}
537+
538+
func buildVoxCPMCloneRequest(raw map[string]any) ([]byte, string, error) {
539+
text := extractTTSText(raw)
540+
if text == "" {
541+
return nil, "", fmt.Errorf(`{"error":"missing or invalid input field"}`)
542+
}
543+
refAudio := firstTTSString(raw, "reference_audio", "ref_audio")
544+
if refAudio == "" {
545+
return nil, "", fmt.Errorf(`{"error":"missing or invalid reference_audio field"}`)
546+
}
547+
audio, filename, err := decodeReferenceAudio(refAudio)
548+
if err != nil {
549+
return nil, "", err
550+
}
551+
552+
var body bytes.Buffer
553+
writer := multipart.NewWriter(&body)
554+
if err := writer.WriteField("text", text); err != nil {
555+
return nil, "", err
556+
}
557+
if refText := firstTTSString(raw, "reference_text", "ref_text"); refText != "" {
558+
if err := writer.WriteField("ref_text", refText); err != nil {
559+
return nil, "", err
560+
}
561+
}
562+
for _, key := range []string{"response_format", "temperature", "cfg", "max_length"} {
563+
if value, ok := raw[key]; ok {
564+
if err := writer.WriteField(key, fmt.Sprint(value)); err != nil {
565+
return nil, "", err
566+
}
567+
}
568+
}
569+
part, err := writer.CreateFormFile("ref_audio", filename)
570+
if err != nil {
571+
return nil, "", err
572+
}
573+
if _, err := part.Write(audio); err != nil {
574+
return nil, "", err
575+
}
576+
if err := writer.Close(); err != nil {
577+
return nil, "", err
578+
}
579+
return body.Bytes(), writer.FormDataContentType(), nil
580+
}
581+
582+
func hasTTSReferenceAudio(raw map[string]any) bool {
583+
return firstTTSString(raw, "reference_audio", "ref_audio") != ""
584+
}
585+
586+
func firstTTSString(raw map[string]any, keys ...string) string {
587+
for _, key := range keys {
588+
if value, _ := raw[key].(string); strings.TrimSpace(value) != "" {
589+
return strings.TrimSpace(value)
590+
}
591+
}
592+
return ""
593+
}
594+
595+
func decodeReferenceAudio(value string) ([]byte, string, error) {
596+
value = strings.TrimSpace(value)
597+
if strings.HasPrefix(strings.ToLower(value), "data:") {
598+
return decodeReferenceAudioDataURL(value)
599+
}
600+
audio, err := base64.StdEncoding.DecodeString(value)
601+
if err != nil {
602+
return nil, "", fmt.Errorf(`{"error":"reference_audio must be a data URL or base64 audio"}`)
603+
}
604+
return audio, "reference.wav", nil
605+
}
606+
607+
func decodeReferenceAudioDataURL(value string) ([]byte, string, error) {
608+
comma := strings.IndexByte(value, ',')
609+
if comma < 0 {
610+
return nil, "", fmt.Errorf(`{"error":"invalid reference_audio data URL"}`)
611+
}
612+
meta := value[len("data:"):comma]
613+
payload := value[comma+1:]
614+
if !strings.Contains(strings.ToLower(meta), ";base64") {
615+
return nil, "", fmt.Errorf(`{"error":"reference_audio data URL must be base64 encoded"}`)
616+
}
617+
audio, err := base64.StdEncoding.DecodeString(payload)
618+
if err != nil {
619+
return nil, "", fmt.Errorf(`{"error":"invalid reference_audio base64 data"}`)
620+
}
621+
622+
contentType := strings.TrimSpace(strings.Split(meta, ";")[0])
623+
format := audioFormatFromContentType(contentType)
624+
if format == "" {
625+
format = "wav"
626+
}
627+
return audio, "reference." + format, nil
628+
}
629+
507630
func (d *Deps) callBackend(r *http.Request, targetAddr, targetPath, contentType string, body []byte) (*http.Response, []byte, error) {
508631
if !strings.HasPrefix(targetAddr, "http://") && !strings.HasPrefix(targetAddr, "https://") {
509632
targetAddr = "http://" + targetAddr

internal/inferencehttp/routes_test.go

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package inferencehttp
33
import (
44
"bytes"
55
"context"
6+
"encoding/base64"
67
"encoding/json"
78
"io"
89
"mime/multipart"
@@ -386,6 +387,145 @@ func TestHandleTTSJSONFallsBackToSpeechAndWrapsAudio(t *testing.T) {
386387
}
387388
}
388389

390+
func TestHandleTTSVoxCPMCloneJSONRoutesToCloneAndWrapsAudio(t *testing.T) {
391+
var (
392+
gotPath string
393+
gotContentType string
394+
gotFields map[string][]string
395+
gotRefAudio []byte
396+
gotRefName string
397+
)
398+
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
399+
gotPath = r.URL.Path
400+
gotContentType = r.Header.Get("Content-Type")
401+
if err := r.ParseMultipartForm(1 << 20); err != nil {
402+
t.Fatalf("ParseMultipartForm: %v", err)
403+
}
404+
gotFields = r.MultipartForm.Value
405+
file, header, err := r.FormFile("ref_audio")
406+
if err != nil {
407+
t.Fatalf("FormFile ref_audio: %v", err)
408+
}
409+
defer file.Close()
410+
gotRefName = header.Filename
411+
gotRefAudio, err = io.ReadAll(file)
412+
if err != nil {
413+
t.Fatalf("ReadAll ref_audio: %v", err)
414+
}
415+
w.Header().Set("Content-Type", "audio/wav")
416+
_, _ = w.Write([]byte("RIFFclone"))
417+
}))
418+
defer backend.Close()
419+
420+
deps := &Deps{
421+
Backends: staticBackendLister{backends: map[string]*Backend{
422+
"voxcpm2": {
423+
ModelName: "voxcpm2",
424+
Address: strings.TrimPrefix(backend.URL, "http://"),
425+
Ready: true,
426+
},
427+
}},
428+
Catalog: staticCatalog{adapters: map[string][]Adapter{
429+
"voxcpm2": {{Path: "/v1/tts", Kind: adapterVoxCPMClone}},
430+
}},
431+
}
432+
433+
mux := http.NewServeMux()
434+
RegisterRoutes(deps)(mux)
435+
436+
reqBody := `{"model":"voxcpm2","text":"hello","response_format":"wav","reference_audio":"data:audio/wav;base64,UklGRg==","reference_text":"sample voice"}`
437+
req := httptest.NewRequest(http.MethodPost, "/v1/tts", strings.NewReader(reqBody))
438+
req.Header.Set("Content-Type", "application/json")
439+
w := httptest.NewRecorder()
440+
mux.ServeHTTP(w, req)
441+
442+
if w.Code != http.StatusOK {
443+
t.Fatalf("status = %d, want %d; body=%s", w.Code, http.StatusOK, w.Body.String())
444+
}
445+
if gotPath != "/v1/clone" {
446+
t.Fatalf("backend path = %q, want /v1/clone", gotPath)
447+
}
448+
if !strings.HasPrefix(gotContentType, "multipart/form-data; boundary=") {
449+
t.Fatalf("content-type = %q, want multipart boundary", gotContentType)
450+
}
451+
field := func(name string) string {
452+
if values := gotFields[name]; len(values) > 0 {
453+
return values[0]
454+
}
455+
return ""
456+
}
457+
if field("text") != "hello" {
458+
t.Fatalf("text field = %q, want hello", field("text"))
459+
}
460+
if field("ref_text") != "sample voice" {
461+
t.Fatalf("ref_text field = %q, want sample voice", field("ref_text"))
462+
}
463+
if field("response_format") != "wav" {
464+
t.Fatalf("response_format field = %q, want wav", field("response_format"))
465+
}
466+
if gotRefName != "reference.wav" {
467+
t.Fatalf("ref_audio filename = %q, want reference.wav", gotRefName)
468+
}
469+
if string(gotRefAudio) != "RIFF" {
470+
t.Fatalf("ref_audio = %q, want RIFF", string(gotRefAudio))
471+
}
472+
var resp map[string]any
473+
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
474+
t.Fatalf("Unmarshal response: %v", err)
475+
}
476+
if resp["audio_base64"] != base64.StdEncoding.EncodeToString([]byte("RIFFclone")) {
477+
t.Fatalf("audio_base64 = %#v, want encoded RIFFclone", resp["audio_base64"])
478+
}
479+
if resp["format"] != "wav" {
480+
t.Fatalf("format = %#v, want wav", resp["format"])
481+
}
482+
}
483+
484+
func TestHandleTTSSpeechVoxCPMCloneReturnsAudio(t *testing.T) {
485+
var gotPath string
486+
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
487+
gotPath = r.URL.Path
488+
w.Header().Set("Content-Type", "audio/wav")
489+
_, _ = w.Write([]byte("RIFFclone"))
490+
}))
491+
defer backend.Close()
492+
493+
deps := &Deps{
494+
Backends: staticBackendLister{backends: map[string]*Backend{
495+
"voxcpm2": {
496+
ModelName: "voxcpm2",
497+
Address: strings.TrimPrefix(backend.URL, "http://"),
498+
Ready: true,
499+
},
500+
}},
501+
Catalog: staticCatalog{adapters: map[string][]Adapter{
502+
"voxcpm2": {{Path: "/v1/audio/speech", Kind: adapterVoxCPMClone}},
503+
}},
504+
}
505+
506+
mux := http.NewServeMux()
507+
RegisterRoutes(deps)(mux)
508+
509+
reqBody := `{"model":"voxcpm2","input":"hello","reference_audio":"UklGRg=="}`
510+
req := httptest.NewRequest(http.MethodPost, "/v1/audio/speech", strings.NewReader(reqBody))
511+
req.Header.Set("Content-Type", "application/json")
512+
w := httptest.NewRecorder()
513+
mux.ServeHTTP(w, req)
514+
515+
if w.Code != http.StatusOK {
516+
t.Fatalf("status = %d, want %d; body=%s", w.Code, http.StatusOK, w.Body.String())
517+
}
518+
if gotPath != "/v1/clone" {
519+
t.Fatalf("backend path = %q, want /v1/clone", gotPath)
520+
}
521+
if got := w.Body.String(); got != "RIFFclone" {
522+
t.Fatalf("body = %q, want RIFFclone", got)
523+
}
524+
if ct := w.Header().Get("Content-Type"); ct != "audio/wav" {
525+
t.Fatalf("content-type = %q, want audio/wav", ct)
526+
}
527+
}
528+
389529
func TestHandleTTSSpeechFallsBackToJSONAndDecodesAudio(t *testing.T) {
390530
var paths []string
391531
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

0 commit comments

Comments
 (0)