Skip to content

Commit 02e034d

Browse files
committed
Merge branch 'main-pull-requests' into main-vxcontrol
2 parents b94dd9c + 98a49b6 commit 02e034d

13 files changed

Lines changed: 939 additions & 19 deletions

agents/executor.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ func (e *Executor) doAction(
134134
}), nil
135135
}
136136

137-
observation, err := tool.Call(ctx, action.ToolInput)
137+
observation, err := tool.Call(ctx, strings.TrimSuffix(action.ToolInput, "\nObservation:"))
138138
if err != nil {
139139
return nil, err
140140
}

agents/executor_test.go

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ type testAgent struct {
2424
err error
2525
inputKeys []string
2626
outputKeys []string
27+
tools []tools.Tool
2728

2829
recordedIntermediateSteps []schema.AgentStep
2930
recordedInputs map[string]string
@@ -51,7 +52,7 @@ func (a testAgent) GetOutputKeys() []string {
5152
}
5253

5354
func (a *testAgent) GetTools() []tools.Tool {
54-
return nil
55+
return a.tools
5556
}
5657

5758
func TestExecutorWithErrorHandler(t *testing.T) {
@@ -173,3 +174,59 @@ func TestExecutorWithOpenAIFunctionAgent(t *testing.T) {
173174
require.True(t, strings.Contains(result, "2012") || strings.Contains(result, "March"),
174175
"correct answer 2012 or March not in response")
175176
}
177+
178+
// mockTool implements the tools.Tool interface for testing
179+
type mockTool struct {
180+
name string
181+
description string
182+
receivedInputPtr *string
183+
}
184+
185+
func (m *mockTool) Name() string {
186+
return m.name
187+
}
188+
189+
func (m *mockTool) Description() string {
190+
return m.description
191+
}
192+
193+
func (m *mockTool) Call(_ context.Context, input string) (string, error) {
194+
*m.receivedInputPtr = input
195+
return "mock result", nil
196+
}
197+
198+
func TestExecutorTrimsObservationSuffix(t *testing.T) {
199+
t.Parallel()
200+
ctx := context.Background()
201+
202+
// Create a mock tool that records what input it receives
203+
var receivedInput string
204+
mockToolInst := &mockTool{
205+
name: "mock_tool",
206+
description: "A mock tool for testing",
207+
receivedInputPtr: &receivedInput,
208+
}
209+
210+
// Create a test agent that returns an action with trailing "\nObservation:"
211+
testAgent := &testAgent{
212+
actions: []schema.AgentAction{
213+
{
214+
Tool: "mock_tool",
215+
ToolInput: "test input\nObservation:",
216+
Log: "Action: mock_tool\nAction Input: test input\nObservation:",
217+
},
218+
},
219+
inputKeys: []string{"input"},
220+
outputKeys: []string{"output"},
221+
tools: []tools.Tool{mockToolInst},
222+
}
223+
224+
executor := agents.NewExecutor(testAgent, agents.WithMaxIterations(1))
225+
226+
_, err := chains.Call(ctx, executor, map[string]any{"input": "test question"})
227+
// We expect ErrNotFinished since our test agent doesn't provide a finish action
228+
require.ErrorIs(t, err, agents.ErrNotFinished)
229+
230+
// Verify that the tool received the input with "\nObservation:" trimmed off
231+
require.Equal(t, "test input", receivedInput, "Tool should receive input with \\nObservation: suffix trimmed")
232+
}

agents/markl_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,16 @@ func TestMRKLOutputParser(t *testing.T) {
3737
expectedFinish: nil,
3838
expectedErr: nil,
3939
},
40+
{
41+
input: "Action: calculator\nAction Input: 5 + 3\nObservation:",
42+
expectedActions: []schema.AgentAction{{
43+
Tool: "calculator",
44+
ToolInput: "5 + 3\nObservation:",
45+
Log: "Action: calculator\nAction Input: 5 + 3\nObservation:",
46+
}},
47+
expectedFinish: nil,
48+
expectedErr: nil,
49+
},
4050
}
4151

4252
a := OneShotZeroAgent{}

docs/package.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,5 +62,10 @@
6262
},
6363
"engines": {
6464
"node": ">=18"
65+
},
66+
"pnpm": {
67+
"overrides": {
68+
"webpack-dev-server": ">=5.2.1"
69+
}
6570
}
6671
}

llms/bedrock/bedrockllm_test.go

Lines changed: 144 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
package bedrock_test
22

33
import (
4+
"context"
45
"net/http"
6+
"os"
57
"testing"
68

7-
"github.com/vxcontrol/langchaingo/httputil"
89
"github.com/vxcontrol/langchaingo/internal/httprr"
910
"github.com/vxcontrol/langchaingo/llms"
1011
"github.com/vxcontrol/langchaingo/llms/bedrock"
@@ -13,13 +14,25 @@ import (
1314
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
1415
)
1516

16-
func setupTest(t *testing.T) (*bedrockruntime.Client, error) {
17-
t.Helper()
17+
func setUpTestWithTransport(rr *httprr.RecordReplay) (*bedrockruntime.Client, error) {
18+
// Configure request scrubbing to remove dynamic AWS headers
19+
rr.ScrubReq(func(req *http.Request) error {
20+
req.Header.Del("Amz-Sdk-Invocation-Id")
21+
req.Header.Del("Amz-Sdk-Request")
22+
req.Header.Del("X-Amz-Date")
23+
return nil
24+
})
1825

19-
cfg, err := config.LoadDefaultConfig(t.Context())
26+
httpClient := &http.Client{
27+
Transport: rr,
28+
}
29+
30+
cfg, err := config.LoadDefaultConfig(context.Background(),
31+
config.WithHTTPClient(httpClient))
2032
if err != nil {
2133
return nil, err
2234
}
35+
2336
client := bedrockruntime.NewFromConfig(cfg)
2437
return client, nil
2538
}
@@ -37,12 +50,8 @@ func TestAmazonOutput(t *testing.T) {
3750
t.Parallel()
3851
}
3952

40-
// Replace httputil.DefaultClient with httprr client
41-
oldClient := httputil.DefaultClient
42-
httputil.DefaultClient = rr.Client()
43-
defer func() { httputil.DefaultClient = oldClient }()
44-
45-
client, err := setupTest(t)
53+
// Configure AWS client to use httprr transport
54+
client, err := setUpTestWithTransport(rr)
4655
if err != nil {
4756
t.Fatal(err)
4857
}
@@ -68,8 +77,6 @@ func TestAmazonOutput(t *testing.T) {
6877

6978
// All the test models.
7079
models := []string{
71-
bedrock.ModelAi21J2MidV1,
72-
bedrock.ModelAi21J2UltraV1,
7380
bedrock.ModelAmazonTitanTextLiteV1,
7481
bedrock.ModelAmazonTitanTextExpressV1,
7582
bedrock.ModelAnthropicClaudeV3Sonnet,
@@ -79,10 +86,11 @@ func TestAmazonOutput(t *testing.T) {
7986
bedrock.ModelAnthropicClaudeInstantV1,
8087
bedrock.ModelCohereCommandTextV14,
8188
bedrock.ModelCohereCommandLightTextV14,
82-
bedrock.ModelMetaLlama213bChatV1,
83-
bedrock.ModelMetaLlama270bChatV1,
8489
bedrock.ModelMetaLlama38bInstructV1,
8590
bedrock.ModelMetaLlama370bInstructV1,
91+
bedrock.ModelAmazonNovaMicroV1,
92+
bedrock.ModelAmazonNovaLiteV1,
93+
bedrock.ModelAmazonNovaProV1,
8694
}
8795

8896
for _, model := range models {
@@ -97,3 +105,125 @@ func TestAmazonOutput(t *testing.T) {
97105
}
98106
}
99107
}
108+
109+
func TestAmazonNova(t *testing.T) {
110+
httprr.SkipIfNoCredentialsAndRecordingMissing(t, "AWS_ACCESS_KEY_ID")
111+
112+
rr := httprr.OpenForTest(t, http.DefaultTransport)
113+
defer rr.Close()
114+
115+
// Only run tests in parallel when not recording (to avoid rate limits)
116+
if !rr.Recording() {
117+
t.Parallel()
118+
}
119+
120+
// Configure AWS client to use httprr transport
121+
client, err := setUpTestWithTransport(rr)
122+
if err != nil {
123+
t.Fatal(err)
124+
}
125+
llm, err := bedrock.New(bedrock.WithClient(client))
126+
if err != nil {
127+
t.Fatal(err)
128+
}
129+
130+
msgs := []llms.MessageContent{
131+
{
132+
Role: llms.ChatMessageTypeSystem,
133+
Parts: []llms.ContentPart{
134+
llms.TextPart("You know all about AI."),
135+
},
136+
},
137+
{
138+
Role: llms.ChatMessageTypeHuman,
139+
Parts: []llms.ContentPart{
140+
llms.TextPart("Explain AI in 10 words or less."),
141+
},
142+
},
143+
}
144+
145+
// All the test models.
146+
models := []string{
147+
bedrock.ModelAmazonNovaMicroV1,
148+
bedrock.ModelAmazonNovaLiteV1,
149+
bedrock.ModelAmazonNovaProV1,
150+
}
151+
152+
ctx := context.Background()
153+
154+
for _, model := range models {
155+
t.Logf("Model output for %s:-", model)
156+
157+
resp, err := llm.GenerateContent(ctx, msgs, llms.WithModel(model), llms.WithMaxTokens(4096))
158+
if err != nil {
159+
t.Fatal(err)
160+
}
161+
for i, choice := range resp.Choices {
162+
t.Logf("Choice %d: %s", i, choice.Content)
163+
}
164+
}
165+
}
166+
167+
func TestAnthropicNovaImage(t *testing.T) {
168+
httprr.SkipIfNoCredentialsAndRecordingMissing(t, "AWS_ACCESS_KEY_ID")
169+
170+
rr := httprr.OpenForTest(t, http.DefaultTransport)
171+
defer rr.Close()
172+
173+
// Only run tests in parallel when not recording (to avoid rate limits)
174+
if !rr.Recording() {
175+
t.Parallel()
176+
}
177+
178+
// Configure AWS client to use httprr transport
179+
client, err := setUpTestWithTransport(rr)
180+
if err != nil {
181+
t.Fatal(err)
182+
}
183+
llm, err := bedrock.New(bedrock.WithClient(client))
184+
if err != nil {
185+
t.Fatal(err)
186+
}
187+
188+
image, err := os.ReadFile("testdata/wikipage.jpg")
189+
mimeType := "image/jpeg"
190+
if err != nil {
191+
t.Fatal(err)
192+
}
193+
194+
msgs := []llms.MessageContent{
195+
{
196+
Role: llms.ChatMessageTypeSystem,
197+
Parts: []llms.ContentPart{
198+
llms.TextPart("You know all about AI."),
199+
},
200+
},
201+
{
202+
Role: llms.ChatMessageTypeHuman,
203+
Parts: []llms.ContentPart{
204+
llms.TextPart("Explain AI according to the image. Provide quotes from the image."),
205+
llms.BinaryPart(mimeType, image),
206+
},
207+
},
208+
}
209+
210+
// All the test models.
211+
models := []string{
212+
bedrock.ModelAmazonNovaLiteV1,
213+
bedrock.ModelAmazonNovaProV1,
214+
}
215+
216+
ctx := context.Background()
217+
218+
for _, model := range models {
219+
t.Logf("Model output for %s:-", model)
220+
221+
resp, err := llm.GenerateContent(ctx, msgs, llms.WithModel(model), llms.WithMaxTokens(4096))
222+
if err != nil {
223+
t.Fatal(err)
224+
}
225+
for i, choice := range resp.Choices {
226+
t.Logf("Choice %d: %s", i, choice.Content)
227+
}
228+
}
229+
}

llms/bedrock/internal/bedrockclient/bedrockclient.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@ type Message struct {
3030
}
3131

3232
func getProvider(modelID string) string {
33+
// Check for Nova models (including inference profiles like us.amazon.nova-*)
34+
if strings.Contains(modelID, ".nova-") || strings.Contains(modelID, "amazon.nova-") {
35+
return "nova"
36+
}
37+
38+
parts := strings.Split(modelID, ".")
39+
40+
// For backward compatibility with the original provider detection
3341
switch {
3442
case strings.Contains(modelID, "ai21"):
3543
return "ai21"
@@ -41,9 +49,14 @@ func getProvider(modelID string) string {
4149
return "cohere"
4250
case strings.Contains(modelID, "meta"):
4351
return "meta"
44-
default:
45-
return ""
4652
}
53+
54+
// Default to using the first part of the model ID
55+
if len(parts) > 0 {
56+
return parts[0]
57+
}
58+
59+
return ""
4760
}
4861

4962
// NewClient creates a new Bedrock client.
@@ -66,6 +79,8 @@ func (c *Client) CreateCompletion(ctx context.Context,
6679
return createAi21Completion(ctx, c.client, modelID, messages, options)
6780
case "amazon":
6881
return createAmazonCompletion(ctx, c.client, modelID, messages, options)
82+
case "nova":
83+
return createNovaCompletion(ctx, c.client, modelID, messages, options)
6984
case "anthropic":
7085
return createAnthropicCompletion(ctx, c.client, modelID, messages, options)
7186
case "cohere":

llms/bedrock/internal/bedrockclient/bedrockclient_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ func TestGetProvider(t *testing.T) {
4646
{
4747
name: "unknown provider",
4848
modelID: "unknown.model",
49-
expected: "",
49+
expected: "unknown",
5050
},
5151
}
5252

0 commit comments

Comments
 (0)