Skip to content

Commit c78a1d9

Browse files
authored
feat(ai): add audio generation support (goravel#1467)
1 parent d1161fd commit c78a1d9

26 files changed

Lines changed: 2330 additions & 97 deletions

ai/application.go

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,12 @@ func (r *Application) Agent(agent contractsai.Agent, options ...contractsai.Opti
3636
return NewConversation(r.ctx, agent, provider, model, middlewares), nil
3737
}
3838

39-
func (r *Application) Image(prompt string, options ...contractsai.Option) contractsai.ImageRequest {
40-
return NewImageRequest(r.ctx, r, prompt, options...)
39+
func (r *Application) Audio(prompt string) contractsai.AudioRequest {
40+
return NewAudioRequest(r.ctx, r, prompt)
41+
}
42+
43+
func (r *Application) Image(prompt string) contractsai.ImageRequest {
44+
return NewImageRequest(r.ctx, r, prompt)
4145
}
4246

4347
func (r *Application) putFile(ctx context.Context, file contractsai.StorableFile, options ...contractsai.Option) (contractsai.FileResponse, error) {
@@ -90,6 +94,23 @@ func (r *Application) deleteFile(ctx context.Context, id string, options ...cont
9094
return fileProvider.DeleteFile(ctx, id)
9195
}
9296

97+
func (r *Application) audio(ctx context.Context, prompt contractsai.AudioPrompt, options ...contractsai.Option) (contractsai.AudioResponse, error) {
98+
opts, providerName, provider, err := r.resolveProvider(options)
99+
if err != nil {
100+
return nil, err
101+
}
102+
if prompt.Model == "" {
103+
prompt.Model = opts.Model
104+
}
105+
106+
audioProvider, ok := provider.(contractsai.AudioProvider)
107+
if !ok {
108+
return nil, errors.AIProviderDoesNotSupportAudio.Args(providerName)
109+
}
110+
111+
return audioProvider.Audio(ctx, prompt)
112+
}
113+
93114
func (r *Application) image(ctx context.Context, prompt contractsai.ImagePrompt, options ...contractsai.Option) (contractsai.ImageResponse, error) {
94115
opts, providerName, provider, err := r.resolveProvider(options)
95116
if err != nil {

ai/application_test.go

Lines changed: 233 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ func TestApplication_Image(t *testing.T) {
347347
}
348348

349349
app := NewApplication(ctx, config)
350-
request := app.Image("draw a cat", WithProvider("default"), WithModel("gpt-image-1"))
350+
request := app.Image("draw a cat").Provider("default").Model("gpt-image-1")
351351

352352
req, ok := request.(*imageRequest)
353353
assert.True(t, ok)
@@ -459,6 +459,154 @@ func TestImageRequest_StoreUsesResponseStore(t *testing.T) {
459459
assert.Empty(t, response.storePath)
460460
}
461461

462+
func TestApplication_Audio(t *testing.T) {
463+
ctx := context.Background()
464+
config := contractsai.Config{
465+
Default: "default",
466+
Providers: map[string]contractsai.ProviderConfig{
467+
"default": {Via: mocksai.NewProvider(t)},
468+
},
469+
}
470+
471+
app := NewApplication(ctx, config)
472+
request := app.Audio("welcome to goravel").Provider("default").Model("gpt-4o-mini-tts").Male().Instructions("Speak slowly").Timeout(2 * time.Second)
473+
474+
req, ok := request.(*audioRequest)
475+
assert.True(t, ok)
476+
assert.Equal(t, ctx, req.ctx)
477+
assert.Equal(t, app, req.app)
478+
assert.Equal(t, "welcome to goravel", req.prompt)
479+
assert.Equal(t, "default", req.provider)
480+
assert.Equal(t, "gpt-4o-mini-tts", req.model)
481+
assert.Equal(t, defaultMaleVoice, req.voice)
482+
assert.Equal(t, "Speak slowly", req.instructions)
483+
assert.Equal(t, 2*time.Second, req.timeout)
484+
485+
assert.Same(t, req, request.Female())
486+
assert.Equal(t, defaultFemaleVoice, req.voice)
487+
}
488+
489+
func TestAudioRequest_Generate(t *testing.T) {
490+
ctx := context.Background()
491+
provider := &applicationAudioProviderStub{}
492+
config := contractsai.Config{
493+
Default: "default",
494+
Providers: map[string]contractsai.ProviderConfig{
495+
"default": {Via: provider},
496+
},
497+
}
498+
499+
app := NewApplication(context.Background(), config)
500+
response := &applicationAudioResponseStub{}
501+
provider.response = response
502+
503+
result, err := app.Audio("welcome to goravel").
504+
Provider("default").
505+
Model("gpt-4o-mini-tts").
506+
Male().
507+
Instructions("Speak slowly").
508+
Timeout(3 * time.Second).
509+
Generate()
510+
511+
require.NoError(t, err)
512+
assert.Equal(t, response, result)
513+
assert.Equal(t, ctx, provider.ctx)
514+
assert.Equal(t, contractsai.AudioPrompt{
515+
Prompt: "welcome to goravel",
516+
Model: "gpt-4o-mini-tts",
517+
Voice: defaultMaleVoice,
518+
Instructions: "Speak slowly",
519+
Timeout: 3 * time.Second,
520+
}, provider.prompt)
521+
}
522+
523+
func TestAudioRequest_StoreUsesResponseStore(t *testing.T) {
524+
provider := &applicationAudioProviderStub{}
525+
app := NewApplication(context.Background(), contractsai.Config{
526+
Default: "default",
527+
Providers: map[string]contractsai.ProviderConfig{
528+
"default": {Via: provider},
529+
},
530+
})
531+
response := &applicationAudioResponseStub{storePathResult: "audio/generated.mp3"}
532+
provider.response = response
533+
534+
path, err := app.Audio("welcome to goravel").Store()
535+
536+
require.NoError(t, err)
537+
assert.Equal(t, "audio/generated.mp3", path)
538+
assert.Equal(t, 1, response.storeCalls)
539+
assert.Equal(t, 0, response.storeAsCalls)
540+
assert.Empty(t, response.storePath)
541+
}
542+
543+
func TestApplication_audio(t *testing.T) {
544+
tests := []struct {
545+
name string
546+
options []contractsai.Option
547+
prompt contractsai.AudioPrompt
548+
setup func() contractsai.Config
549+
expectError error
550+
expectPrompt contractsai.AudioPrompt
551+
}{
552+
{
553+
name: "success with default model",
554+
options: []contractsai.Option{WithProvider("openai")},
555+
prompt: contractsai.AudioPrompt{
556+
Prompt: "welcome to goravel",
557+
Voice: defaultFemaleVoice,
558+
},
559+
setup: func() contractsai.Config {
560+
provider := &applicationAudioProviderStub{}
561+
provider.response = &applicationAudioResponseStub{}
562+
return contractsai.Config{
563+
Default: "default",
564+
Providers: map[string]contractsai.ProviderConfig{
565+
"default": {Via: mocksai.NewProvider(t)},
566+
"openai": {Via: provider},
567+
},
568+
}
569+
},
570+
expectPrompt: contractsai.AudioPrompt{
571+
Prompt: "welcome to goravel",
572+
Voice: defaultFemaleVoice,
573+
},
574+
},
575+
{
576+
name: "provider does not support audio",
577+
prompt: contractsai.AudioPrompt{
578+
Prompt: "welcome to goravel",
579+
},
580+
setup: func() contractsai.Config {
581+
return contractsai.Config{
582+
Default: "default",
583+
Providers: map[string]contractsai.ProviderConfig{
584+
"default": {Via: mocksai.NewProvider(t)},
585+
},
586+
}
587+
},
588+
expectError: errors.AIProviderDoesNotSupportAudio.Args("default"),
589+
},
590+
}
591+
592+
for _, tt := range tests {
593+
t.Run(tt.name, func(t *testing.T) {
594+
app := NewApplication(context.Background(), tt.setup())
595+
response, err := app.audio(context.Background(), tt.prompt, tt.options...)
596+
assert.Equal(t, tt.expectError, err)
597+
if tt.expectError != nil {
598+
assert.Nil(t, response)
599+
return
600+
}
601+
602+
require.NotNil(t, response)
603+
provider, ok := app.config.Providers["openai"].Via.(*applicationAudioProviderStub)
604+
require.True(t, ok)
605+
assert.Equal(t, tt.expectPrompt, provider.prompt)
606+
})
607+
}
608+
}
609+
462610
func TestApplication_image(t *testing.T) {
463611
tests := []struct {
464612
name string
@@ -670,6 +818,90 @@ func (r *applicationImageResponseStub) Then(callback func(contractsai.ImageRespo
670818
return r
671819
}
672820

821+
type applicationAudioProviderStub struct {
822+
ctx context.Context
823+
prompt contractsai.AudioPrompt
824+
response contractsai.AudioResponse
825+
err error
826+
}
827+
828+
func (p *applicationAudioProviderStub) Prompt(context.Context, contractsai.AgentPrompt) (contractsai.AgentResponse, error) {
829+
return nil, nil
830+
}
831+
832+
func (p *applicationAudioProviderStub) Stream(context.Context, contractsai.AgentPrompt) (contractsai.StreamableAgentResponse, error) {
833+
return nil, nil
834+
}
835+
836+
func (p *applicationAudioProviderStub) Audio(ctx context.Context, prompt contractsai.AudioPrompt) (contractsai.AudioResponse, error) {
837+
p.ctx = ctx
838+
p.prompt = prompt
839+
return p.response, p.err
840+
}
841+
842+
// applicationAudioResponseStub is a test double for an audio response that
// counts Store/StoreAs calls and remembers their arguments.
type applicationAudioResponseStub struct {
	// storePath holds a copy of the disk arguments of the last Store call.
	storePath []string
	// storePathResult, when non-empty, is returned by Store without storing.
	storePathResult string
	// storeAsName is the path argument of the last StoreAs call.
	storeAsName string
	// storeAsPath holds a copy of the disk arguments of the last StoreAs call.
	storeAsPath []string
	// storeCalls and storeAsCalls count invocations of Store and StoreAs.
	storeCalls   int
	storeAsCalls int
}
850+
851+
func (r *applicationAudioResponseStub) Content() ([]byte, error) {
852+
return []byte("audio"), nil
853+
}
854+
855+
// MimeType always reports "audio/mpeg".
func (r *applicationAudioResponseStub) MimeType() string {
	return "audio/mpeg"
}
856+
857+
func (r *applicationAudioResponseStub) Store(disk ...string) (string, error) {
858+
r.storeCalls++
859+
r.storePath = append([]string(nil), disk...)
860+
if r.storePathResult != "" {
861+
return r.storePathResult, nil
862+
}
863+
864+
content, err := r.Content()
865+
if err != nil {
866+
return "", err
867+
}
868+
869+
resolvedDisk, err := resolveAudioStoreDisk(disk)
870+
if err != nil {
871+
return "", err
872+
}
873+
874+
return audioStorer{}.Store(content, "generated.mp3", resolvedDisk)
875+
}
876+
877+
func (r *applicationAudioResponseStub) StoreAs(path string, disk ...string) (string, error) {
878+
r.storeAsCalls++
879+
r.storeAsName = path
880+
r.storeAsPath = append([]string(nil), disk...)
881+
882+
content, err := r.Content()
883+
if err != nil {
884+
return "", err
885+
}
886+
887+
resolvedDisk, err := resolveAudioStoreDisk(disk)
888+
if err != nil {
889+
return "", err
890+
}
891+
892+
return audioStorer{}.StoreAs(content, path, resolvedDisk)
893+
}
894+
895+
// Usage always returns nil; the stub does not track usage.
func (r *applicationAudioResponseStub) Usage() contractsai.Usage {
	return nil
}
896+
897+
func (r *applicationAudioResponseStub) Then(callback func(contractsai.AudioResponse)) contractsai.AudioResponse {
898+
if callback != nil {
899+
callback(r)
900+
}
901+
902+
return r
903+
}
904+
673905
func (m *applicationTestMiddleware) Handle(ctx context.Context, prompt contractsai.AgentPrompt, next contractsai.Next) (contractsai.AgentResponse, error) {
674906
response, err := next(ctx, prompt)
675907
if err != nil {

ai/audio_request.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
package ai
2+
3+
import (
4+
"context"
5+
"time"
6+
7+
contractsai "github.com/goravel/framework/contracts/ai"
8+
)
9+
10+
type audioRequest struct {
11+
ctx context.Context
12+
app *Application
13+
prompt string
14+
provider string
15+
model string
16+
voice string
17+
instructions string
18+
timeout time.Duration
19+
}
20+
21+
func NewAudioRequest(ctx context.Context, app *Application, prompt string) contractsai.AudioRequest {
22+
return &audioRequest{
23+
ctx: ctx,
24+
app: app,
25+
prompt: prompt,
26+
voice: defaultFemaleVoice,
27+
}
28+
}
29+
30+
func (r *audioRequest) Model(model string) contractsai.AudioRequest {
31+
r.model = model
32+
return r
33+
}
34+
35+
func (r *audioRequest) Provider(provider string) contractsai.AudioRequest {
36+
r.provider = provider
37+
return r
38+
}
39+
40+
func (r *audioRequest) Voice(voice string) contractsai.AudioRequest {
41+
r.voice = voice
42+
return r
43+
}
44+
45+
func (r *audioRequest) Male() contractsai.AudioRequest {
46+
r.voice = defaultMaleVoice
47+
return r
48+
}
49+
50+
func (r *audioRequest) Female() contractsai.AudioRequest {
51+
r.voice = defaultFemaleVoice
52+
return r
53+
}
54+
55+
func (r *audioRequest) Instructions(instructions string) contractsai.AudioRequest {
56+
r.instructions = instructions
57+
return r
58+
}
59+
60+
func (r *audioRequest) Timeout(timeout time.Duration) contractsai.AudioRequest {
61+
r.timeout = timeout
62+
return r
63+
}
64+
65+
func (r *audioRequest) Store(disk ...string) (string, error) {
66+
response, err := r.Generate()
67+
if err != nil {
68+
return "", err
69+
}
70+
71+
return response.Store(disk...)
72+
}
73+
74+
func (r *audioRequest) StoreAs(path string, disk ...string) (string, error) {
75+
response, err := r.Generate()
76+
if err != nil {
77+
return "", err
78+
}
79+
80+
return response.StoreAs(path, disk...)
81+
}
82+
83+
func (r *audioRequest) Generate() (contractsai.AudioResponse, error) {
84+
options := make([]contractsai.Option, 0, 2)
85+
if r.provider != "" {
86+
options = append(options, WithProvider(r.provider))
87+
}
88+
if r.model != "" {
89+
options = append(options, WithModel(r.model))
90+
}
91+
92+
return r.app.audio(r.ctx, contractsai.AudioPrompt{
93+
Prompt: r.prompt,
94+
Model: r.model,
95+
Voice: r.voice,
96+
Instructions: r.instructions,
97+
Timeout: r.timeout,
98+
}, options...)
99+
}

0 commit comments

Comments
 (0)