11package bedrock_test
22
33import (
4+ "context"
45 "net/http"
6+ "os"
57 "testing"
68
7- "github.com/vxcontrol/langchaingo/httputil"
89 "github.com/vxcontrol/langchaingo/internal/httprr"
910 "github.com/vxcontrol/langchaingo/llms"
1011 "github.com/vxcontrol/langchaingo/llms/bedrock"
@@ -13,13 +14,25 @@ import (
1314 "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
1415)
1516
16- func setupTest (t * testing.T ) (* bedrockruntime.Client , error ) {
17- t .Helper ()
17+ func setUpTestWithTransport (rr * httprr.RecordReplay ) (* bedrockruntime.Client , error ) {
18+ // Configure request scrubbing to remove dynamic AWS headers
19+ rr .ScrubReq (func (req * http.Request ) error {
20+ req .Header .Del ("Amz-Sdk-Invocation-Id" )
21+ req .Header .Del ("Amz-Sdk-Request" )
22+ req .Header .Del ("X-Amz-Date" )
23+ return nil
24+ })
1825
19- cfg , err := config .LoadDefaultConfig (t .Context ())
26+ httpClient := & http.Client {
27+ Transport : rr ,
28+ }
29+
30+ cfg , err := config .LoadDefaultConfig (context .Background (),
31+ config .WithHTTPClient (httpClient ))
2032 if err != nil {
2133 return nil , err
2234 }
35+
2336 client := bedrockruntime .NewFromConfig (cfg )
2437 return client , nil
2538}
@@ -37,12 +50,8 @@ func TestAmazonOutput(t *testing.T) {
3750 t .Parallel ()
3851 }
3952
40- // Replace httputil.DefaultClient with httprr client
41- oldClient := httputil .DefaultClient
42- httputil .DefaultClient = rr .Client ()
43- defer func () { httputil .DefaultClient = oldClient }()
44-
45- client , err := setupTest (t )
53+ // Configure AWS client to use httprr transport
54+ client , err := setUpTestWithTransport (rr )
4655 if err != nil {
4756 t .Fatal (err )
4857 }
@@ -68,8 +77,6 @@ func TestAmazonOutput(t *testing.T) {
6877
6978 // All the test models.
7079 models := []string {
71- bedrock .ModelAi21J2MidV1 ,
72- bedrock .ModelAi21J2UltraV1 ,
7380 bedrock .ModelAmazonTitanTextLiteV1 ,
7481 bedrock .ModelAmazonTitanTextExpressV1 ,
7582 bedrock .ModelAnthropicClaudeV3Sonnet ,
@@ -79,10 +86,11 @@ func TestAmazonOutput(t *testing.T) {
7986 bedrock .ModelAnthropicClaudeInstantV1 ,
8087 bedrock .ModelCohereCommandTextV14 ,
8188 bedrock .ModelCohereCommandLightTextV14 ,
82- bedrock .ModelMetaLlama213bChatV1 ,
83- bedrock .ModelMetaLlama270bChatV1 ,
8489 bedrock .ModelMetaLlama38bInstructV1 ,
8590 bedrock .ModelMetaLlama370bInstructV1 ,
91+ bedrock .ModelAmazonNovaMicroV1 ,
92+ bedrock .ModelAmazonNovaLiteV1 ,
93+ bedrock .ModelAmazonNovaProV1 ,
8694 }
8795
8896 for _ , model := range models {
@@ -97,3 +105,125 @@ func TestAmazonOutput(t *testing.T) {
97105 }
98106 }
99107}
108+
109+ func TestAmazonNova (t * testing.T ) {
110+ httprr .SkipIfNoCredentialsAndRecordingMissing (t , "AWS_ACCESS_KEY_ID" )
111+
112+ rr := httprr .OpenForTest (t , http .DefaultTransport )
113+ defer rr .Close ()
114+
115+ // Only run tests in parallel when not recording (to avoid rate limits)
116+ if ! rr .Recording () {
117+ t .Parallel ()
118+ }
119+
120+ // Configure AWS client to use httprr transport
121+ client , err := setUpTestWithTransport (rr )
122+ if err != nil {
123+ t .Fatal (err )
124+ }
125+ llm , err := bedrock .New (bedrock .WithClient (client ))
126+ if err != nil {
127+ t .Fatal (err )
128+ }
129+
130+ msgs := []llms.MessageContent {
131+ {
132+ Role : llms .ChatMessageTypeSystem ,
133+ Parts : []llms.ContentPart {
134+ llms .TextPart ("You know all about AI." ),
135+ },
136+ },
137+ {
138+ Role : llms .ChatMessageTypeHuman ,
139+ Parts : []llms.ContentPart {
140+ llms .TextPart ("Explain AI in 10 words or less." ),
141+ },
142+ },
143+ }
144+
145+ // All the test models.
146+ models := []string {
147+ bedrock .ModelAmazonNovaMicroV1 ,
148+ bedrock .ModelAmazonNovaLiteV1 ,
149+ bedrock .ModelAmazonNovaProV1 ,
150+ }
151+
152+ ctx := context .Background ()
153+
154+ for _ , model := range models {
155+ t .Logf ("Model output for %s:-" , model )
156+
157+ resp , err := llm .GenerateContent (ctx , msgs , llms .WithModel (model ), llms .WithMaxTokens (4096 ))
158+ if err != nil {
159+ t .Fatal (err )
160+ }
161+ for i , choice := range resp .Choices {
162+ t .Logf ("Choice %d: %s" , i , choice .Content )
163+ }
164+ }
165+ }
166+
167+ func TestAnthropicNovaImage (t * testing.T ) {
168+ httprr .SkipIfNoCredentialsAndRecordingMissing (t , "AWS_ACCESS_KEY_ID" )
169+
170+ rr := httprr .OpenForTest (t , http .DefaultTransport )
171+ defer rr .Close ()
172+
173+ // Only run tests in parallel when not recording (to avoid rate limits)
174+ if ! rr .Recording () {
175+ t .Parallel ()
176+ }
177+
178+ // Configure AWS client to use httprr transport
179+ client , err := setUpTestWithTransport (rr )
180+ if err != nil {
181+ t .Fatal (err )
182+ }
183+ llm , err := bedrock .New (bedrock .WithClient (client ))
184+ if err != nil {
185+ t .Fatal (err )
186+ }
187+
188+ image , err := os .ReadFile ("testdata/wikipage.jpg" )
189+ mimeType := "image/jpeg"
190+ if err != nil {
191+ t .Fatal (err )
192+ }
193+
194+ msgs := []llms.MessageContent {
195+ {
196+ Role : llms .ChatMessageTypeSystem ,
197+ Parts : []llms.ContentPart {
198+ llms .TextPart ("You know all about AI." ),
199+ },
200+ },
201+ {
202+ Role : llms .ChatMessageTypeHuman ,
203+ Parts : []llms.ContentPart {
204+ llms .TextPart ("Explain AI according to the image. Provide quotes from the image." ),
205+ llms .BinaryPart (mimeType , image ),
206+ },
207+ },
208+ }
209+
210+ // All the test models.
211+ models := []string {
212+ bedrock .ModelAmazonNovaLiteV1 ,
213+ bedrock .ModelAmazonNovaProV1 ,
214+ }
215+
216+ ctx := context .Background ()
217+
218+ for _ , model := range models {
219+ t .Logf ("Model output for %s:-" , model )
220+
221+ resp , err := llm .GenerateContent (ctx , msgs , llms .WithModel (model ), llms .WithMaxTokens (4096 ))
222+ if err != nil {
223+ t .Fatal (err )
224+ }
225+ for i , choice := range resp .Choices {
226+ t .Logf ("Choice %d: %s" , i , choice .Content )
227+ }
228+ }
229+ }
0 commit comments