Skip to content

Commit 6ea2cc2

Browse files
committed
fix: security, concurrency, and correctness hardening
- tools.go: regex-based command filtering, mutex-protected globals, safeJoin fix, grep -- separator
- agent.go: ErrUnknownTool sentinel, toolsMu for thread-safe map access, panic recovery in goroutines, shared maxIterations constant
- types.go: ProviderRegistry thread safety, correct CacheHitRate formula
- sub_agent.go: mutex-protected SubAgentPool
- nvidia.go: Complete() now respects CompletionOptions
- bedrock.go: remove incorrect X-Amz-Target header, remove dead init()
- gemini.go: move API key from URL query to X-Goog-Api-Key header
- vertex.go: remove confusing dead credential file read
- context.go: ceiling division in EstimateTokens, fix slice aliasing in compaction
- openaicompat.go: proper URL hostname check instead of substring match
- openapi/adapter.go: filter path/query params from request body
- mcp/client.go: handle json.Marshal errors
- mcp/transport.go: Wait() on process to prevent zombies
- openairesponses.go: add 120s HTTP timeout
- retry.go: context check before first attempt, nil guard in RetryableError
- sse.go: return error instead of silently breaking on read errors
- mock.go: range over string instead of strings.Split
- azure.go: reuse buffer in SSEDecoder
- provider.go: list valid providers in error message
1 parent 8340f36 commit 6ea2cc2

21 files changed

Lines changed: 235 additions & 107 deletions

agent.go

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package iteragent
66
import (
77
"context"
88
"encoding/json"
9+
"errors"
910
"fmt"
1011
"log/slog"
1112
"strings"
@@ -15,6 +16,11 @@ import (
1516
"github.com/GrayCodeAI/iteragent/openapi"
1617
)
1718

19+
// ErrUnknownTool is returned when a tool call references a tool that is not registered.
20+
var ErrUnknownTool = errors.New("unknown tool")
21+
22+
const maxAgentIterations = 20
23+
1824
// ToolCall represents a tool invocation.
1925
type ToolCall struct {
2026
Tool string `json:"tool"`
@@ -57,6 +63,7 @@ type AgentHooks struct {
5763
type Agent struct {
5864
provider Provider
5965
tools map[string]Tool
66+
toolsMu sync.RWMutex
6067
logger *slog.Logger
6168
Events chan Event
6269
SystemPrompt string
@@ -185,7 +192,7 @@ func (a *Agent) executeToolsSequential(ctx context.Context, calls []ToolCall, it
185192
var toolResults strings.Builder
186193
for _, call := range calls {
187194
result, isError := a.executeSingleTool(ctx, call, iteration, emitFn)
188-
if isError && result == fmt.Sprintf("unknown tool: %s", call.Tool) {
195+
if isError && errors.Is(isErrorToErr(isError, result), ErrUnknownTool) {
189196
toolResults.WriteString(fmt.Sprintf("Tool %s: %s\n", call.Tool, result))
190197
} else {
191198
toolResults.WriteString(fmt.Sprintf("Tool %s result:\n%s\n\n", call.Tool, result))
@@ -194,6 +201,17 @@ func (a *Agent) executeToolsSequential(ctx context.Context, calls []ToolCall, it
194201
return toolResults.String()
195202
}
196203

204+
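// isErrorToErr converts an (isError, result) pair back into an error: nil when isError is false, ErrUnknownTool when result starts with "unknown tool:", otherwise a new error carrying result.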
205+
func isErrorToErr(isError bool, result string) error {
206+
if !isError {
207+
return nil
208+
}
209+
if strings.HasPrefix(result, "unknown tool:") {
210+
return ErrUnknownTool
211+
}
212+
return errors.New(result)
213+
}
214+
197215
func (a *Agent) executeToolsParallel(ctx context.Context, calls []ToolCall, iteration int, emitFn func(Event)) string {
198216
type indexedResult struct {
199217
call ToolCall
@@ -213,11 +231,13 @@ func (a *Agent) executeToolsParallel(ctx context.Context, calls []ToolCall, iter
213231
ToolName: c.Tool,
214232
ToolCallID: fmt.Sprintf("%s-%d-%d", c.Tool, iteration, i),
215233
})
234+
a.toolsMu.RLock()
216235
tool, ok := a.tools[c.Tool]
236+
a.toolsMu.RUnlock()
217237
if !ok {
218238
res := fmt.Sprintf("unknown tool: %s", c.Tool)
219239
if a.hooks.OnToolEnd != nil {
220-
a.hooks.OnToolEnd(c.Tool, res, fmt.Errorf("unknown tool: %s", c.Tool))
240+
a.hooks.OnToolEnd(c.Tool, res, ErrUnknownTool)
221241
}
222242
results[i] = indexedResult{call: c, result: res, isError: true, unknown: true}
223243
emitFn(Event{Type: string(EventToolExecutionEnd), ToolName: c.Tool, Result: results[i].result, IsError: true})
@@ -272,7 +292,9 @@ func (a *Agent) executeToolsBatched(ctx context.Context, calls []ToolCall, itera
272292
// executeSingleTool runs one tool call and emits start/end events.
273293
// Returns (result string, isError bool).
274294
func (a *Agent) executeSingleTool(ctx context.Context, call ToolCall, iteration int, emitFn func(Event)) (string, bool) {
295+
a.toolsMu.RLock()
275296
tool, ok := a.tools[call.Tool]
297+
a.toolsMu.RUnlock()
276298
if !ok {
277299
result := fmt.Sprintf("unknown tool: %s", call.Tool)
278300
emitFn(Event{Type: string(EventToolExecutionEnd), ToolName: call.Tool, Result: result, IsError: true})
@@ -330,10 +352,7 @@ func (a *Agent) Run(ctx context.Context, systemPrompt, userMessage string, emitF
330352
}
331353
userMessage = filtered
332354

333-
allTools := make([]Tool, 0, len(a.tools))
334-
for _, t := range a.tools {
335-
allTools = append(allTools, t)
336-
}
355+
allTools := a.GetTools()
337356

338357
messages := []Message{
339358
{Role: "system", Content: systemPrompt + "\n\n" + ToolDescriptions(allTools)},
@@ -342,8 +361,7 @@ func (a *Agent) Run(ctx context.Context, systemPrompt, userMessage string, emitF
342361

343362
opts := a.completionOpts()
344363

345-
const maxIterations = 20
346-
for i := 0; i < maxIterations; i++ {
364+
for i := 0; i < maxAgentIterations; i++ {
347365
a.logger.Info("agent iteration", "step", i+1)
348366
emit(Event{Type: string(EventTurnStart), Content: fmt.Sprintf("turn %d", i+1)})
349367

@@ -401,7 +419,7 @@ func (a *Agent) Run(ctx context.Context, systemPrompt, userMessage string, emitF
401419
emit(Event{Type: string(EventTurnEnd), Content: ""})
402420
}
403421

404-
return "", fmt.Errorf("agent exceeded max iterations (%d)", maxIterations)
422+
return "", fmt.Errorf("agent exceeded max iterations (%d)", maxAgentIterations)
405423
}
406424

407425
func (a *Agent) emit(e Event) {
@@ -517,11 +535,15 @@ func (a *Agent) WithTools(tools []Tool) *Agent {
517535
for _, t := range tools {
518536
toolMap[t.Name] = t
519537
}
538+
a.toolsMu.Lock()
520539
a.tools = toolMap
540+
a.toolsMu.Unlock()
521541
return a
522542
}
523543

524544
func (a *Agent) GetTools() []Tool {
545+
a.toolsMu.RLock()
546+
defer a.toolsMu.RUnlock()
525547
tools := make([]Tool, 0, len(a.tools))
526548
for _, t := range a.tools {
527549
tools = append(tools, t)
@@ -530,7 +552,9 @@ func (a *Agent) GetTools() []Tool {
530552
}
531553

532554
func (a *Agent) AddTool(tool Tool) *Agent {
555+
a.toolsMu.Lock()
533556
a.tools[tool.Name] = tool
557+
a.toolsMu.Unlock()
534558
return a
535559
}
536560

@@ -601,6 +625,7 @@ func (a *Agent) registerMcpTools(ctx context.Context, adapter *mcp.ToolAdapter)
601625
if err != nil {
602626
return nil, fmt.Errorf("list mcp tools: %w", err)
603627
}
628+
a.toolsMu.Lock()
604629
for _, t := range tools {
605630
execute := t.Execute
606631
a.tools[t.Name] = Tool{
@@ -609,6 +634,7 @@ func (a *Agent) registerMcpTools(ctx context.Context, adapter *mcp.ToolAdapter)
609634
Execute: execute,
610635
}
611636
}
637+
a.toolsMu.Unlock()
612638
// Track the client so Close() can shut it down.
613639
if client := adapter.Client(); client != nil {
614640
a.mu.Lock()
@@ -668,6 +694,7 @@ func (a *Agent) registerOpenApiTools(adapter *openapi.Adapter) (*Agent, error) {
668694
if err != nil {
669695
return nil, fmt.Errorf("list openapi tools: %w", err)
670696
}
697+
a.toolsMu.Lock()
671698
for _, t := range tools {
672699
execute := t.Execute
673700
a.tools[t.Name] = Tool{
@@ -676,6 +703,7 @@ func (a *Agent) registerOpenApiTools(adapter *openapi.Adapter) (*Agent, error) {
676703
Execute: execute,
677704
}
678705
}
706+
a.toolsMu.Unlock()
679707
return a, nil
680708
}
681709

@@ -731,6 +759,9 @@ func (a *Agent) Prompt(ctx context.Context, text string) chan Event {
731759
a.pendingWg.Add(1)
732760
go func() {
733761
defer func() {
762+
if r := recover(); r != nil {
763+
emitFn(Event{Type: string(EventError), Content: fmt.Sprintf("panic: %v", r), IsError: true})
764+
}
734765
a.pendingWg.Done()
735766
a.mu.Lock()
736767
a.isStreaming = false
@@ -776,6 +807,9 @@ func (a *Agent) PromptMessages(ctx context.Context, messages []Message) chan Eve
776807
a.pendingWg.Add(1)
777808
go func() {
778809
defer func() {
810+
if r := recover(); r != nil {
811+
emitFn(Event{Type: string(EventError), Content: fmt.Sprintf("panic: %v", r), IsError: true})
812+
}
779813
a.pendingWg.Done()
780814
a.mu.Lock()
781815
a.isStreaming = false
@@ -812,7 +846,7 @@ func (a *Agent) PromptMessages(ctx context.Context, messages []Message) chan Eve
812846

813847
opts := a.completionOpts()
814848

815-
for i := 0; i < 20; i++ {
849+
for i := 0; i < maxAgentIterations; i++ {
816850
// Check for cancellation before each turn.
817851
select {
818852
case <-loopCtx.Done():

azure.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -232,21 +232,21 @@ func (p *AzureOpenAIProvider) Stream(ctx context.Context, config StreamConfig, m
232232

233233
type SSEDecoder struct {
234234
reader io.Reader
235+
buf []byte
235236
}
236237

237238
func NewSSEDecoder(reader io.Reader) *SSEDecoder {
238-
return &SSEDecoder{reader: reader}
239+
return &SSEDecoder{reader: reader, buf: make([]byte, 1024)}
239240
}
240241

241242
func (d *SSEDecoder) Decode() (StreamEvent, error) {
242243
var line string
243244
for {
244-
buf := make([]byte, 1024)
245-
n, err := d.reader.Read(buf)
245+
n, err := d.reader.Read(d.buf)
246246
if n == 0 || err != nil {
247247
return StreamEvent{}, err
248248
}
249-
line = string(buf[:n])
249+
line = string(d.buf[:n])
250250
if strings.HasPrefix(line, "data:") {
251251
break
252252
}

bedrock.go

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"fmt"
1010
"io"
1111
"net/http"
12-
"os"
1312
"strings"
1413
"time"
1514
)
@@ -51,12 +50,8 @@ func (p *BedrockProvider) Complete(ctx context.Context, messages []Message, opts
5150
if msg.Role == "system" {
5251
systemPrompts = append(systemPrompts, map[string]string{"text": msg.Content})
5352
} else {
54-
role := msg.Role
55-
if role == "assistant" {
56-
role = "assistant"
57-
}
5853
convMessages = append(convMessages, map[string]interface{}{
59-
"role": role,
54+
"role": msg.Role,
6055
"content": []map[string]string{
6156
{"text": msg.Content},
6257
},
@@ -225,20 +220,14 @@ func (p *BedrockProvider) Stream(ctx context.Context, config StreamConfig, messa
225220
func (p *BedrockProvider) signRequest(req *http.Request, payload string) {
226221
now := time.Now().UTC()
227222
date := now.Format("20060102T150405Z")
228-
amzDate := os.Getenv("AWS_AMZ_DATE")
229-
if amzDate != "" {
230-
date = amzDate
231-
}
232223

233224
region := p.config.Region
234225
service := "bedrock"
235226

236227
req.Header.Set("X-Amz-Date", date)
237-
req.Header.Set("X-Amz-Target", "AmazonBedrockRuntime.Converse")
238228

239229
host := req.Host
240230
if host == "" {
241-
// Strip port if present, use full hostname (not just first segment).
242231
host = req.URL.Hostname()
243232
}
244233

@@ -248,9 +237,8 @@ func (p *BedrockProvider) signRequest(req *http.Request, payload string) {
248237
"content-type:application/json",
249238
fmt.Sprintf("host:%s", host),
250239
fmt.Sprintf("x-amz-date:%s", date),
251-
fmt.Sprintf("x-amz-target:AmazonBedrockRuntime.Converse"),
252240
}
253-
signedHeaders := "content-type;host;x-amz-date;x-amz-target"
241+
signedHeaders := "content-type;host;x-amz-date"
254242

255243
canonicalRequest := strings.Join([]string{
256244
"POST",
@@ -291,8 +279,3 @@ func hmacSHA256(key []byte, data string) []byte {
291279
h.Write([]byte(data))
292280
return h.Sum(nil)
293281
}
294-
295-
func init() {
296-
registry := NewProviderRegistry()
297-
registry.Register(ProtocolBedrock, &BedrockProvider{})
298-
}

context.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ import (
1111
const charsPerToken = 4
1212

1313
func EstimateTokens(text string) int {
14-
return len(text) / charsPerToken
14+
if len(text) == 0 {
15+
return 0
16+
}
17+
return (len(text) + charsPerToken - 1) / charsPerToken
1518
}
1619

1720
func EstimateMessageTokens(msg Message) int {
@@ -343,7 +346,7 @@ func CompactMessagesTiered(messages []Message, cfg ContextConfig) []Message {
343346

344347
head := compacted[:keepFirst]
345348
tail := compacted[len(compacted)-keepRecent:]
346-
return append(head, tail...)
349+
return append(head[:len(head):len(head)], tail...)
347350
}
348351

349352
type MessageSummary struct {

context_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,10 +259,10 @@ func TestContextTracker_CacheHitRate(t *testing.T) {
259259
}
260260
ct.UpdateWithRealUsage(usage)
261261

262-
// CacheHitRate = CacheRead / (InputTokens + CacheRead + CacheWrite) = 40/100 = 0.4
262+
// CacheHitRate = CacheRead / InputTokens = 40/60 ≈ 0.667
263263
got := ct.CacheHitRate()
264-
if got < 0.39 || got > 0.41 {
265-
t.Errorf("expected ~0.4 cache hit rate, got %f", got)
264+
if got < 0.66 || got > 0.68 {
265+
t.Errorf("expected ~0.667 cache hit rate, got %f", got)
266266
}
267267
}
268268

gemini.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,13 @@ func (p *geminiProvider) CompleteStream(ctx context.Context, messages []Message,
117117
}
118118

119119
streamURL := fmt.Sprintf(
120-
"https://generativelanguage.googleapis.com/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
121-
p.cfg.Model, p.cfg.APIKey)
120+
"https://generativelanguage.googleapis.com/v1beta/models/%s:streamGenerateContent?alt=sse",
121+
p.cfg.Model)
122122

123123
var full strings.Builder
124124
sseClient := NewSSEClient()
125-
err = sseClient.Stream(ctx, streamURL, nil, body, func(e SSEEvent) {
125+
headers := map[string]string{"X-Goog-Api-Key": p.cfg.APIKey}
126+
err = sseClient.Stream(ctx, streamURL, headers, body, func(e SSEEvent) {
126127
if token, ok := ParseGeminiSSE(e.Data); ok && token != "" {
127128
full.WriteString(token)
128129
if onToken != nil {
@@ -151,12 +152,13 @@ func (p *geminiProvider) Complete(ctx context.Context, messages []Message, opts
151152
return "", fmt.Errorf("marshal request: %w", err)
152153
}
153154

154-
url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:generateContent?key=%s", p.cfg.Model, p.cfg.APIKey)
155+
url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:generateContent", p.cfg.Model)
155156
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
156157
if err != nil {
157158
return "", fmt.Errorf("create request: %w", err)
158159
}
159160
req.Header.Set("Content-Type", "application/json")
161+
req.Header.Set("X-Goog-Api-Key", p.cfg.APIKey)
160162

161163
resp, err := p.client.Do(req)
162164
if err != nil {

0 commit comments

Comments
 (0)