Skip to content

Commit f891635

Browse files
committed
fix: harden providers, tools, and remove duplicate nvidia shim
Cap HTTP response reads at 10MB, replace grep-based search with bounded filepath walk, consolidate NVIDIA onto OpenAICompat, and add compat stubs for downstream iterate APIs.
1 parent 843f5d8 commit f891635

24 files changed

Lines changed: 158 additions & 294 deletions

.golangci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ linters-settings:
2727

2828
errcheck:
2929
# Ignore common intentional unchecked errors.
30-
ignore: fmt:.*,io/fs:PathError
30+
exclude-functions: fmt:.*,io/fs:PathError
3131

3232
issues:
3333
exclude-rules:

anthropic.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ func (p *anthropicProvider) Complete(ctx context.Context, messages []Message, op
167167
}
168168
defer resp.Body.Close()
169169

170-
raw, err := io.ReadAll(resp.Body)
170+
raw, err := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024))
171171
if err != nil {
172172
return "", fmt.Errorf("read response: %w", err)
173173
}

anthropic_test.go

Lines changed: 0 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package iteragent_test
22

33
import (
44
"context"
5-
"encoding/json"
65
"fmt"
76
"net/http"
87
"net/http/httptest"
@@ -12,36 +11,6 @@ import (
1211
iteragent "github.com/GrayCodeAI/iteragent"
1312
)
1413

15-
// ---------------------------------------------------------------------------
16-
// helpers
17-
// ---------------------------------------------------------------------------
18-
19-
// anthropicServer creates a test server that responds like the Anthropic API.
20-
func anthropicServer(t *testing.T, handler http.HandlerFunc) *httptest.Server {
21-
t.Helper()
22-
return httptest.NewServer(handler)
23-
}
24-
25-
// patchAnthropicURL swaps the hardcoded Anthropic URL — we test via Complete
26-
// using a mock provider that records the request body, since we cannot patch
27-
// the URL directly. Instead we test NewAnthropic then swap out the http.Client
28-
// by exercising exported behaviour end-to-end.
29-
//
30-
// The real Complete calls https://api.anthropic.com. For unit tests we instead:
31-
// 1. Test Name() directly (no network)
32-
// 2. Test request body construction by calling Complete against a local server
33-
// via a custom http.Client set on the provider.
34-
//
35-
// Since anthropicProvider is unexported, we exercise Complete indirectly via
36-
// an Agent backed by a mock provider that returns a fixed body. The actual
37-
// HTTP layer is tested via the SSE tests and the OpenAPI adapter tests.
38-
//
39-
// Here we test the parts we CAN reach without reflection or unexported access:
40-
// - Name() shape
41-
// - thinkingBudget logic (via exported ThinkingLevel constants)
42-
// - NewAnthropic constructor returns non-nil Provider
43-
// - Complete on a testServer-backed provider using the internal testable path
44-
4514
// ---------------------------------------------------------------------------
4615
// NewAnthropic
4716
// ---------------------------------------------------------------------------
@@ -80,30 +49,6 @@ func TestNewAnthropic_DifferentModels(t *testing.T) {
8049
}
8150
}
8251

83-
// ---------------------------------------------------------------------------
84-
// Complete — via httptest (using Agent with the mock HTTP layer)
85-
// ---------------------------------------------------------------------------
86-
87-
// anthropicJSONResponse builds a valid Anthropic /v1/messages response body.
88-
func anthropicJSONResponse(text string) []byte {
89-
resp := map[string]interface{}{
90-
"content": []map[string]interface{}{
91-
{"type": "text", "text": text},
92-
},
93-
}
94-
b, _ := json.Marshal(resp)
95-
return b
96-
}
97-
98-
// anthropicErrorResponse builds an Anthropic error response body.
99-
func anthropicErrorResponse(msg string) []byte {
100-
resp := map[string]interface{}{
101-
"error": map[string]string{"message": msg},
102-
}
103-
b, _ := json.Marshal(resp)
104-
return b
105-
}
106-
10752
// TestAnthropicComplete_* tests use the Agent+Mock pattern because the
10853
// anthropicProvider is unexported. The real provider path is verified via
10954
// direct httptest servers in CompleteStream tests below (which use the SSEClient

azure.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ func (p *AzureOpenAIProvider) Complete(ctx context.Context, messages []Message,
7878
}
7979
defer resp.Body.Close()
8080

81-
respBody, err := io.ReadAll(resp.Body)
81+
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024))
8282
if err != nil {
8383
return "", fmt.Errorf("read response: %w", err)
8484
}
@@ -160,4 +160,3 @@ func messagesToAzureFormat(messages []Message) []map[string]interface{} {
160160
}
161161
return result
162162
}
163-

azure_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,4 +285,3 @@ func TestAzureComplete_ContextCancelled(t *testing.T) {
285285
t.Fatal("expected error from cancelled context")
286286
}
287287
}
288-

bedrock.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func (p *BedrockProvider) Complete(ctx context.Context, messages []Message, opts
9393
}
9494
defer resp.Body.Close()
9595

96-
respBody, err := io.ReadAll(resp.Body)
96+
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024))
9797
if err != nil {
9898
return "", fmt.Errorf("read response: %w", err)
9999
}
@@ -138,7 +138,6 @@ func (p *BedrockProvider) CompleteStream(ctx context.Context, messages []Message
138138
return result, nil
139139
}
140140

141-
142141
func (p *BedrockProvider) signRequest(req *http.Request, payload string) {
143142
now := time.Now().UTC()
144143
date := now.Format("20060102T150405Z")

compat.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package iteragent
2+
3+
// Compatibility stubs for iterate. These interfaces and functions existed in
4+
// earlier iteragent versions and are still referenced by downstream consumers.
5+
6+
// TokenStreamer is implemented by providers that support token-level streaming.
7+
type TokenStreamer interface {
8+
CompleteStreamTokens(ctx interface{}, messages []Message, opts CompletionOptions) (<-chan string, error)
9+
}
10+
11+
// ThinkingStreamer is implemented by providers that expose thinking/reasoning tokens.
12+
type ThinkingStreamer interface {
13+
CompleteWithThinking(ctx interface{}, messages []Message, opts CompletionOptions) (string, string, error)
14+
}
15+
16+
// NativeToolCaller is implemented by providers that support native tool calling.
17+
type NativeToolCaller interface {
18+
CompleteWithTools(ctx interface{}, messages []Message, tools []ToolDefinition, opts CompletionOptions) (string, []ToolCall, error)
19+
}
20+
21+
// ProviderContextWindow returns the context window size in tokens for the
22+
// given provider, or 0 if unknown.
23+
func ProviderContextWindow(p Provider) int {
24+
// No providers currently expose their context window size.
25+
return 0
26+
}
27+
28+
// WithLLMCompaction is a compatibility no-op: LLM-assisted compaction is not supported in this version, so tokens is ignored and the Agent is returned unchanged.
29+
func (a *Agent) WithLLMCompaction(tokens int) *Agent {
30+
// LLM compaction is not supported in this version.
31+
return a
32+
}
33+
34+
// SetPinnedMessages is a compatibility no-op: pinned messages are not supported in this version, so msgs is ignored.
35+
func (a *Agent) SetPinnedMessages(msgs []Message) {
36+
// Pinned messages are not supported in this version.
37+
}
38+
39+
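// LLMCompactionStrategy is a compatibility stub for LLM-based compaction of conversation history; no LLM call is made, and Compact only keeps the most recent KeepRecent messages.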
40+
type LLMCompactionStrategy struct {
41+
Provider Provider
42+
KeepRecent int
43+
}
44+
45+
// Compact returns the last KeepRecent messages when there are more than KeepRecent, and messages unchanged otherwise; maxTokens is ignored and no LLM summarisation is performed.
46+
func (s *LLMCompactionStrategy) Compact(messages []Message, maxTokens int) []Message {
47+
// LLM compaction is not supported; fall back to simple truncation.
48+
if len(messages) > s.KeepRecent {
49+
return messages[len(messages)-s.KeepRecent:]
50+
}
51+
return messages
52+
}
53+
54+
// ArgStr returns the value stored under key in args, or the empty string if key is absent.
55+
func ArgStr(args map[string]string, key string) string {
56+
return args[key]
57+
}

context_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func TestEstimateTokens(t *testing.T) {
2828

2929
func TestEstimateTotalTokens(t *testing.T) {
3030
msgs := []iteragent.Message{
31-
{Role: "user", Content: "abcd"}, // 1 token + 4 overhead = 5
31+
{Role: "user", Content: "abcd"}, // 1 token + 4 overhead = 5
3232
{Role: "assistant", Content: "abcd"}, // 1 token + 4 overhead = 5
3333
}
3434
got := iteragent.EstimateTotalTokens(msgs)
@@ -130,7 +130,7 @@ func TestCompactMessagesTiered_Level3_KeepsFirstAndLast(t *testing.T) {
130130
}
131131

132132
cfg := iteragent.ContextConfig{
133-
MaxTokens: 10, // extremely low → all levels triggered
133+
MaxTokens: 10, // extremely low → all levels triggered
134134
KeepRecent: 3,
135135
KeepFirst: 2,
136136
ToolOutputMaxLines: 5,
@@ -369,7 +369,7 @@ func TestCompactMessagesTiered_AssistantSummary(t *testing.T) {
369369
}
370370

371371
cfg := iteragent.ContextConfig{
372-
MaxTokens: 80, // enough to trigger level 2 but not drop everything
372+
MaxTokens: 80, // enough to trigger level 2 but not drop everything
373373
KeepRecent: 2,
374374
KeepFirst: 1,
375375
ToolOutputMaxLines: 50,

gemini.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ func (p *geminiProvider) Complete(ctx context.Context, messages []Message, opts
166166
}
167167
defer resp.Body.Close()
168168

169-
raw, err := io.ReadAll(resp.Body)
169+
raw, err := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024))
170170
if err != nil {
171171
return "", fmt.Errorf("read response: %w", err)
172172
}

go.sum

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +0,0 @@
1-
# This go.sum file was generated to ensure reproducible builds.
2-
# This module has no external dependencies.

0 commit comments

Comments
 (0)