Skip to content

Commit edd87ef

Browse files
hwbrzzl and Copilot authored
feat(ai): add transcription support (#1470)
* feat(ai): add transcription support
* fix(ai): address transcription review comments
* fix(ai): remove unused transcription test field
* fix(ai): reject typed nil transcription files
* optimize
* fix(ai): restore image.Of and add FromStorage test coverage

  Agent-Logs-Url: https://github.com/goravel/framework/sessions/fd686dee-9139-4d5c-9ad1-2922227e7dd6

  Co-authored-by: hwbrzzl <24771476+hwbrzzl@users.noreply.github.com>
* optimize
* optimize

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: hwbrzzl <24771476+hwbrzzl@users.noreply.github.com>
1 parent c78a1d9 commit edd87ef

24 files changed

Lines changed: 1675 additions & 77 deletions

ai/application.go

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,10 @@ func (r *Application) Image(prompt string) contractsai.ImageRequest {
4444
return NewImageRequest(r.ctx, r, prompt)
4545
}
4646

47+
// Transcription returns a transcription request for file. It delegates to
// NewTranscriptionRequest, passing the application's stored context and the
// application itself.
func (r *Application) Transcription(file contractsai.StorableFile) contractsai.TranscriptionRequest {
48+
return NewTranscriptionRequest(r.ctx, r, file)
49+
}
50+
4751
func (r *Application) putFile(ctx context.Context, file contractsai.StorableFile, options ...contractsai.Option) (contractsai.FileResponse, error) {
4852
_, providerName, provider, err := r.resolveProvider(options)
4953
if err != nil {
@@ -128,6 +132,23 @@ func (r *Application) image(ctx context.Context, prompt contractsai.ImagePrompt,
128132
return imageProvider.Image(ctx, prompt)
129133
}
130134

135+
// transcription resolves a provider from options and forwards prompt to it.
//
// If prompt.Model is empty it is filled from the resolved options' Model
// before the call. Any error from resolveProvider is returned as-is. If the
// resolved provider does not implement contractsai.TranscriptionProvider, it
// returns errors.AIProviderDoesNotSupportTranscription formatted with the
// provider name.
func (r *Application) transcription(ctx context.Context, prompt contractsai.TranscriptionPrompt, options ...contractsai.Option) (contractsai.TranscriptionResponse, error) {
136+
opts, providerName, provider, err := r.resolveProvider(options)
137+
if err != nil {
138+
return nil, err
139+
}
140+
// Only fall back to the option-level model when the prompt did not set one.
if prompt.Model == "" {
141+
prompt.Model = opts.Model
142+
}
143+
144+
// Transcription is an optional capability; check for it at call time.
transcriptionProvider, ok := provider.(contractsai.TranscriptionProvider)
145+
if !ok {
146+
return nil, errors.AIProviderDoesNotSupportTranscription.Args(providerName)
147+
}
148+
149+
return transcriptionProvider.Transcription(ctx, prompt)
150+
}
151+
131152
func (r *Application) resolveProvider(options []contractsai.Option) (*contractsai.Options, string, contractsai.Provider, error) {
132153
opts := &contractsai.Options{}
133154
for _, option := range options {

ai/application_test.go

Lines changed: 172 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -478,12 +478,72 @@ func TestApplication_Audio(t *testing.T) {
478478
assert.Equal(t, "welcome to goravel", req.prompt)
479479
assert.Equal(t, "default", req.provider)
480480
assert.Equal(t, "gpt-4o-mini-tts", req.model)
481-
assert.Equal(t, defaultMaleVoice, req.voice)
481+
assert.Equal(t, DefaultMaleVoice, req.voice)
482482
assert.Equal(t, "Speak slowly", req.instructions)
483483
assert.Equal(t, 2*time.Second, req.timeout)
484484

485485
assert.Same(t, req, request.Female())
486-
assert.Equal(t, defaultFemaleVoice, req.voice)
486+
assert.Equal(t, DefaultFemaleVoice, req.voice)
487+
}
488+
489+
// TestApplication_Transcription checks the builder state of the request
// returned by Application.Transcription: the context, application, file,
// provider, model, language, diarize flag and timeout set through the chained
// setters must be stored on the underlying *transcriptionRequest. Generate is
// not called here.
func TestApplication_Transcription(t *testing.T) {
490+
ctx := context.Background()
491+
config := contractsai.Config{
492+
Default: "default",
493+
Providers: map[string]contractsai.ProviderConfig{
494+
"default": {Via: mocksai.NewProvider(t)},
495+
},
496+
}
497+
file := mocksai.NewStorableFile(t)
498+
499+
app := NewApplication(ctx, config)
500+
request := app.Transcription(file).Provider("default").Model("gpt-4o-mini-transcribe").Language("en").Diarize().Timeout(2 * time.Second)
501+
502+
req, ok := request.(*transcriptionRequest)
503+
assert.True(t, ok)
504+
assert.Equal(t, ctx, req.ctx)
505+
assert.Equal(t, app, req.app)
506+
assert.Equal(t, file, req.file)
507+
assert.Equal(t, "default", req.provider)
508+
assert.Equal(t, "gpt-4o-mini-transcribe", req.model)
509+
assert.Equal(t, "en", req.language)
510+
assert.True(t, req.diarize)
511+
assert.Equal(t, 2*time.Second, req.timeout)
512+
}
513+
514+
// TestTranscriptionRequest_Generate runs a fully configured transcription
// request against applicationTranscriptionProviderStub. It asserts that
// Generate returns the stub's response without error, and that the stub was
// called with the expected context and a TranscriptionPrompt carrying the
// file, model, language, diarize flag and timeout set on the request.
func TestTranscriptionRequest_Generate(t *testing.T) {
515+
ctx := context.Background()
516+
provider := &applicationTranscriptionProviderStub{}
517+
config := contractsai.Config{
518+
Default: "default",
519+
Providers: map[string]contractsai.ProviderConfig{
520+
"default": {Via: provider},
521+
},
522+
}
523+
file := mocksai.NewStorableFile(t)
524+
525+
app := NewApplication(context.Background(), config)
526+
response := &applicationTranscriptionResponseStub{}
527+
provider.response = response
528+
529+
result, err := app.Transcription(file).
530+
Provider("default").
531+
Model("gpt-4o-mini-transcribe").
532+
Language("en").
533+
Diarize().
534+
Timeout(3 * time.Second).
535+
Generate()
536+
537+
require.NoError(t, err)
538+
assert.Equal(t, response, result)
539+
assert.Equal(t, ctx, provider.ctx)
540+
assert.Equal(t, contractsai.TranscriptionPrompt{
541+
File: file,
542+
Model: "gpt-4o-mini-transcribe",
543+
Language: "en",
544+
Diarize: true,
545+
Timeout: 3 * time.Second,
546+
}, provider.prompt)
487547
}
488548

489549
func TestAudioRequest_Generate(t *testing.T) {
@@ -514,7 +574,7 @@ func TestAudioRequest_Generate(t *testing.T) {
514574
assert.Equal(t, contractsai.AudioPrompt{
515575
Prompt: "welcome to goravel",
516576
Model: "gpt-4o-mini-tts",
517-
Voice: defaultMaleVoice,
577+
Voice: DefaultMaleVoice,
518578
Instructions: "Speak slowly",
519579
Timeout: 3 * time.Second,
520580
}, provider.prompt)
@@ -554,7 +614,7 @@ func TestApplication_audio(t *testing.T) {
554614
options: []contractsai.Option{WithProvider("openai")},
555615
prompt: contractsai.AudioPrompt{
556616
Prompt: "welcome to goravel",
557-
Voice: defaultFemaleVoice,
617+
Voice: DefaultFemaleVoice,
558618
},
559619
setup: func() contractsai.Config {
560620
provider := &applicationAudioProviderStub{}
@@ -569,7 +629,7 @@ func TestApplication_audio(t *testing.T) {
569629
},
570630
expectPrompt: contractsai.AudioPrompt{
571631
Prompt: "welcome to goravel",
572-
Voice: defaultFemaleVoice,
632+
Voice: DefaultFemaleVoice,
573633
},
574634
},
575635
{
@@ -607,6 +667,70 @@ func TestApplication_audio(t *testing.T) {
607667
}
608668
}
609669

670+
// TestApplication_transcription is a table-driven test of the unexported
// Application.transcription method. It covers two cases:
//   - a call routed to the "openai" provider via WithProvider, where the stub
//     must receive the prompt unchanged;
//   - a call against the default provider, which is a plain mock provider and
//     is expected to yield AIProviderDoesNotSupportTranscription.
func TestApplication_transcription(t *testing.T) {
671+
file := mocksai.NewStorableFile(t)
672+
tests := []struct {
673+
name string
674+
options []contractsai.Option
675+
prompt contractsai.TranscriptionPrompt
676+
setup func() contractsai.Config
677+
expectError error
678+
expectPrompt contractsai.TranscriptionPrompt
679+
}{
680+
{
681+
name: "success with default model",
682+
options: []contractsai.Option{WithProvider("openai")},
683+
prompt: contractsai.TranscriptionPrompt{
684+
File: file,
685+
},
686+
setup: func() contractsai.Config {
687+
provider := &applicationTranscriptionProviderStub{}
688+
provider.response = &applicationTranscriptionResponseStub{}
689+
return contractsai.Config{
690+
Default: "default",
691+
Providers: map[string]contractsai.ProviderConfig{
692+
"default": {Via: mocksai.NewProvider(t)},
693+
"openai": {Via: provider},
694+
},
695+
}
696+
},
697+
expectPrompt: contractsai.TranscriptionPrompt{File: file},
698+
},
699+
{
700+
name: "provider does not support transcription",
701+
prompt: contractsai.TranscriptionPrompt{
702+
File: file,
703+
},
704+
setup: func() contractsai.Config {
705+
return contractsai.Config{
706+
Default: "default",
707+
Providers: map[string]contractsai.ProviderConfig{
708+
"default": {Via: mocksai.NewProvider(t)},
709+
},
710+
}
711+
},
712+
expectError: errors.AIProviderDoesNotSupportTranscription.Args("default"),
713+
},
714+
}
715+
716+
for _, tt := range tests {
717+
t.Run(tt.name, func(t *testing.T) {
718+
app := NewApplication(context.Background(), tt.setup())
719+
response, err := app.transcription(context.Background(), tt.prompt, tt.options...)
720+
assert.Equal(t, tt.expectError, err)
721+
if tt.expectError != nil {
722+
assert.Nil(t, response)
723+
return
724+
}
725+
726+
require.NotNil(t, response)
727+
// Success cases must register their stub under the "openai" key; the
// recorded prompt is read back from that entry.
provider, ok := app.config.Providers["openai"].Via.(*applicationTranscriptionProviderStub)
728+
require.True(t, ok)
729+
assert.Equal(t, tt.expectPrompt, provider.prompt)
730+
})
731+
}
732+
}
733+
610734
func TestApplication_image(t *testing.T) {
611735
tests := []struct {
612736
name string
@@ -825,6 +949,13 @@ type applicationAudioProviderStub struct {
825949
err error
826950
}
827951

952+
// applicationTranscriptionProviderStub is a hand-written test double for a
// transcription-capable provider. Its Transcription method stores the
// received ctx and prompt in the fields below and returns response and err.
type applicationTranscriptionProviderStub struct {
953+
ctx context.Context
954+
prompt contractsai.TranscriptionPrompt
955+
response contractsai.TranscriptionResponse
956+
err error
957+
}
958+
828959
// Prompt is a placeholder implementation; it ignores its arguments and
// always returns nil, nil.
func (p *applicationAudioProviderStub) Prompt(context.Context, contractsai.AgentPrompt) (contractsai.AgentResponse, error) {
829960
return nil, nil
830961
}
@@ -839,6 +970,20 @@ func (p *applicationAudioProviderStub) Audio(ctx context.Context, prompt contrac
839970
return p.response, p.err
840971
}
841972

973+
// Prompt is a placeholder implementation; it ignores its arguments and
// always returns nil, nil.
func (p *applicationTranscriptionProviderStub) Prompt(context.Context, contractsai.AgentPrompt) (contractsai.AgentResponse, error) {
974+
return nil, nil
975+
}
976+
977+
// Stream is a placeholder implementation; it ignores its arguments and
// always returns nil, nil.
func (p *applicationTranscriptionProviderStub) Stream(context.Context, contractsai.AgentPrompt) (contractsai.StreamableAgentResponse, error) {
978+
return nil, nil
979+
}
980+
981+
// Transcription records ctx and prompt on the stub so tests can inspect them,
// then returns the preconfigured response and err.
func (p *applicationTranscriptionProviderStub) Transcription(ctx context.Context, prompt contractsai.TranscriptionPrompt) (contractsai.TranscriptionResponse, error) {
982+
p.ctx = ctx
983+
p.prompt = prompt
984+
return p.response, p.err
985+
}
986+
842987
type applicationAudioResponseStub struct {
843988
storePath []string
844989
storePathResult string
@@ -902,6 +1047,28 @@ func (r *applicationAudioResponseStub) Then(callback func(contractsai.AudioRespo
9021047
return r
9031048
}
9041049

1050+
// applicationTranscriptionResponseStub is a test double for a transcription
// response. Its accessor methods return the values held in these fields.
type applicationTranscriptionResponseStub struct {
1051+
text string
1052+
segments []contractsai.TranscriptionSegment
1053+
usage contractsai.Usage
1054+
}
1055+
1056+
// Text returns the stub's stored text.
func (r *applicationTranscriptionResponseStub) Text() string { return r.text }
1057+
1058+
// Segments returns a copy of the stored segments, built by appending to a nil
// slice, so the result does not share a backing array with the stub's slice.
func (r *applicationTranscriptionResponseStub) Segments() []contractsai.TranscriptionSegment {
1059+
return append([]contractsai.TranscriptionSegment(nil), r.segments...)
1060+
}
1061+
1062+
// Usage returns the stub's stored usage value.
func (r *applicationTranscriptionResponseStub) Usage() contractsai.Usage { return r.usage }
1063+
1064+
// Then invokes callback with the stub itself when callback is non-nil, and
// returns the stub in either case.
func (r *applicationTranscriptionResponseStub) Then(callback func(contractsai.TranscriptionResponse)) contractsai.TranscriptionResponse {
1065+
if callback != nil {
1066+
callback(r)
1067+
}
1068+
1069+
return r
1070+
}
1071+
9051072
func (m *applicationTestMiddleware) Handle(ctx context.Context, prompt contractsai.AgentPrompt, next contractsai.Next) (contractsai.AgentResponse, error) {
9061073
response, err := next(ctx, prompt)
9071074
if err != nil {

ai/audio_request.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func NewAudioRequest(ctx context.Context, app *Application, prompt string) contr
2323
ctx: ctx,
2424
app: app,
2525
prompt: prompt,
26-
voice: defaultFemaleVoice,
26+
voice: DefaultFemaleVoice,
2727
}
2828
}
2929

@@ -43,12 +43,12 @@ func (r *audioRequest) Voice(voice string) contractsai.AudioRequest {
4343
}
4444

4545
// Male sets the request's voice to the default male voice constant and
// returns the same request so calls can be chained.
func (r *audioRequest) Male() contractsai.AudioRequest {
46-
r.voice = defaultMaleVoice
46+
r.voice = DefaultMaleVoice
4747
return r
4848
}
4949

5050
// Female sets the request's voice to the default female voice constant and
// returns the same request so calls can be chained.
func (r *audioRequest) Female() contractsai.AudioRequest {
51-
r.voice = defaultFemaleVoice
51+
r.voice = DefaultFemaleVoice
5252
return r
5353
}
5454

ai/audio_voice.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
package ai
22

3-
const defaultMaleVoice = "default-male"
4-
const defaultFemaleVoice = "default-female"
3+
// DefaultMaleVoice is the voice identifier that audioRequest.Male assigns to
// a request.
const DefaultMaleVoice = "default-male"
4+
// DefaultFemaleVoice is the voice identifier that audioRequest.Female assigns
// to a request; NewAudioRequest also uses it as the initial voice.
const DefaultFemaleVoice = "default-female"

ai/image/image.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package image
22

33
import (
44
contractsai "github.com/goravel/framework/contracts/ai"
5-
"github.com/goravel/framework/facades"
65
)
76

87
// Quality is an alias for contractsai.ImageQuality.
type Quality = contractsai.ImageQuality
@@ -17,7 +16,3 @@ const (
1716
SizePortrait = contractsai.ImageSizePortrait
1817
SizeLandscape = contractsai.ImageSizeLandscape
1918
)
20-
21-
// Of returns the image request produced by facades.AI().Image for prompt.
func Of(prompt string) contractsai.ImageRequest {
22-
return facades.AI().Image(prompt)
23-
}

ai/image/image_test.go

Lines changed: 0 additions & 31 deletions
This file was deleted.

0 commit comments

Comments
 (0)