Skip to content

Commit fc73239

Browse files
committed
test: add comprehensive test suite across all packages
New test files: - anthropic_test.go — Name, SSE parsing edge cases, header forwarding - azure_test.go — Complete via httptest, headers, URL construction, SSEDecoder - gemini_test.go — Name, SSE parsing, streaming integration - openaicompat_test.go — Complete via httptest, body fields, temperature, max_tokens - context_test.go — ContextTracker, ExecutionTracker, concurrent access - retry_test.go — IsRetryable with HTTP codes, context errors, provider phrases - skills_test.go — LoadSkills deduplication, frontmatter parsing - sse_test.go — edge cases: empty body, multiline data, ordering, SSEResponse - sub_agent_test.go — SubAgent, WithMaxTurns, SubAgentPool - mcp/transport_test.go — HTTPTransport JSON-RPC, headers, incremental IDs - mcp/tool_adapter_test.go — GetTools, Execute, JSON arg conversion, ConnectHTTP - openapi/adapter_test.go — ParseSpec, OperationFilter, GetTools, callOperation Modified source: - agent.go — WithCacheEnabled builder - skills.go — LoadSkills deduplication (last-dir-wins) - sub_agent.go — WithMaxTurns builder, copy MaxTurns from config - retry.go — strengthen IsRetryable with HTTP status codes and provider phrases
1 parent 9852fd5 commit fc73239

35 files changed

Lines changed: 5837 additions & 164 deletions

.golangci.yml

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
run:
2+
timeout: 5m
3+
4+
linters:
5+
enable:
6+
- errcheck # check unchecked errors
7+
- gosimple # simplification suggestions
8+
- govet # go vet checks
9+
- ineffassign # detect unused assignments
10+
- staticcheck # comprehensive static analysis
11+
- unused # detect unused code
12+
- gofmt # enforce formatting
13+
- goimports # enforce import ordering
14+
- misspell # fix common spelling mistakes
15+
- revive # opinionated linter (replaces golint)
16+
- bodyclose # ensure http response bodies are closed
17+
- contextcheck # check context propagation
18+
- noctx # disallow http.NewRequest without context
19+
20+
linters-settings:
21+
revive:
22+
rules:
23+
- name: exported
24+
disabled: true # don't require doc comments on every export
25+
- name: unused-parameter
26+
disabled: true # too noisy for callback-heavy code
27+
28+
errcheck:
29+
# Ignore common intentional unchecked errors.
30+
ignore: fmt:.*,io/fs:PathError
31+
32+
issues:
33+
exclude-rules:
34+
# Test files: relax some rules.
35+
- path: "_test\\.go"
36+
linters:
37+
- errcheck
38+
- bodyclose
39+
# Example files: formatting/import issues are expected.
40+
- path: "examples/"
41+
linters:
42+
- goimports
43+
- revive

agent.go

Lines changed: 137 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,19 @@ type Event struct {
3737
Thinking string
3838
}
3939

40+
// AgentHooks provides optional callbacks for lifecycle events in the agent loop.
41+
// All fields are optional; nil functions are silently skipped.
42+
type AgentHooks struct {
43+
// BeforeTurn is called before each provider completion turn. turn is 1-indexed.
44+
BeforeTurn func(turn int, messages []Message)
45+
// AfterTurn is called after each provider completion turn with the response.
46+
AfterTurn func(turn int, response string)
47+
// OnToolStart is called before each tool execution.
48+
OnToolStart func(toolName string, args map[string]string)
49+
// OnToolEnd is called after each tool execution with the result and any error.
50+
OnToolEnd func(toolName string, result string, err error)
51+
}
52+
4053
// Agent is the core reasoning loop.
4154
type Agent struct {
4255
provider Provider
@@ -60,6 +73,9 @@ type Agent struct {
6073
// Input filters.
6174
inputFilters []InputFilter
6275

76+
// Lifecycle hooks.
77+
hooks AgentHooks
78+
6379
// Concurrency state — tracks whether a streaming operation is in progress.
6480
mu sync.Mutex
6581
isStreaming bool
@@ -73,6 +89,10 @@ type Agent struct {
7389

7490
// cacheConfig controls prompt caching behaviour.
7591
cacheConfig CacheConfig
92+
93+
// mcpClients holds MCP server connections owned by this agent.
94+
// They are shut down when Close() is called.
95+
mcpClients []*mcp.McpClient
7696
}
7797

7898
// New creates a new Agent.
@@ -192,16 +212,26 @@ func (a *Agent) executeToolsParallel(ctx context.Context, calls []ToolCall, iter
192212
})
193213
tool, ok := a.tools[c.Tool]
194214
if !ok {
195-
results[i] = indexedResult{call: c, result: fmt.Sprintf("unknown tool: %s", c.Tool), isError: true, unknown: true}
215+
res := fmt.Sprintf("unknown tool: %s", c.Tool)
216+
if a.hooks.OnToolEnd != nil {
217+
a.hooks.OnToolEnd(c.Tool, res, fmt.Errorf("unknown tool: %s", c.Tool))
218+
}
219+
results[i] = indexedResult{call: c, result: res, isError: true, unknown: true}
196220
emitFn(Event{Type: string(EventToolExecutionEnd), ToolName: c.Tool, Result: results[i].result, IsError: true})
197221
return
198222
}
223+
if a.hooks.OnToolStart != nil {
224+
a.hooks.OnToolStart(c.Tool, c.Args)
225+
}
199226
res, err := tool.Execute(ctx, c.Args)
200227
isErr := false
201228
if err != nil {
202229
res = fmt.Sprintf("ERROR: %s\nOutput: %s", err.Error(), res)
203230
isErr = true
204231
}
232+
if a.hooks.OnToolEnd != nil {
233+
a.hooks.OnToolEnd(c.Tool, res, err)
234+
}
205235
results[i] = indexedResult{call: c, result: res, isError: isErr}
206236
emitFn(Event{Type: string(EventToolExecutionEnd), ToolName: c.Tool, Result: res, IsError: isErr})
207237
}(idx, call)
@@ -258,13 +288,21 @@ func (a *Agent) executeSingleTool(ctx context.Context, call ToolCall, iteration
258288
})
259289
a.logger.Info("executing tool", "tool", call.Tool, "args", call.Args)
260290

291+
if a.hooks.OnToolStart != nil {
292+
a.hooks.OnToolStart(call.Tool, call.Args)
293+
}
294+
261295
result, err := tool.Execute(ctx, call.Args)
262296
isError := false
263297
if err != nil {
264298
result = fmt.Sprintf("ERROR: %s\nOutput: %s", err.Error(), result)
265299
isError = true
266300
}
267301

302+
if a.hooks.OnToolEnd != nil {
303+
a.hooks.OnToolEnd(call.Tool, result, err)
304+
}
305+
268306
emitFn(Event{
269307
Type: string(EventToolExecutionEnd),
270308
ToolName: call.Tool,
@@ -309,12 +347,34 @@ func (a *Agent) Run(ctx context.Context, systemPrompt, userMessage string, emitF
309347
// Context compaction check before provider call.
310348
messages = a.maybeCompact(messages, emit)
311349

312-
response, err := a.provider.Complete(ctx, messages, opts)
350+
if a.hooks.BeforeTurn != nil {
351+
a.hooks.BeforeTurn(i+1, messages)
352+
}
353+
354+
var (
355+
response string
356+
err error
357+
)
358+
if ts, ok := a.provider.(TokenStreamer); ok {
359+
response, err = RetryWithResult(ctx, DefaultRetryConfig, func() (string, error) {
360+
return ts.CompleteStream(ctx, messages, opts, func(token string) {
361+
emit(Event{Type: string(EventTokenUpdate), Content: token})
362+
})
363+
})
364+
} else {
365+
response, err = RetryWithResult(ctx, DefaultRetryConfig, func() (string, error) {
366+
return a.provider.Complete(ctx, messages, opts)
367+
})
368+
}
313369
if err != nil {
314370
emit(Event{Type: string(EventError), Content: err.Error(), IsError: true})
315371
return "", fmt.Errorf("provider error at step %d: %w", i+1, err)
316372
}
317373

374+
if a.hooks.AfterTurn != nil {
375+
a.hooks.AfterTurn(i+1, response)
376+
}
377+
318378
emit(Event{Type: string(EventMessageUpdate), Content: response})
319379

320380
messages = append(messages, Message{
@@ -489,12 +549,29 @@ func (a *Agent) WithInputFilter(f InputFilter) *Agent {
489549
return a
490550
}
491551

552+
// WithHooks sets lifecycle hook callbacks on the agent.
553+
func (a *Agent) WithHooks(h AgentHooks) *Agent {
554+
a.hooks = h
555+
return a
556+
}
557+
492558
// WithCacheConfig enables prompt caching with the given configuration.
493559
func (a *Agent) WithCacheConfig(cfg CacheConfig) *Agent {
494560
a.cacheConfig = cfg
495561
return a
496562
}
497563

564+
// WithCacheEnabled is a convenience builder that enables or disables prompt
565+
// caching using the DefaultCacheConfig when enabled is true.
566+
func (a *Agent) WithCacheEnabled(enabled bool) *Agent {
567+
if enabled {
568+
a.cacheConfig = DefaultCacheConfig()
569+
} else {
570+
a.cacheConfig = CacheConfig{}
571+
}
572+
return a
573+
}
574+
498575
// WithMcpServerStdio connects to an MCP server via stdio (spawns a child process),
499576
// performs the initialize handshake, and registers all advertised tools.
500577
// Returns an error if the server fails to start or initialize.
@@ -529,9 +606,42 @@ func (a *Agent) registerMcpTools(ctx context.Context, adapter *mcp.ToolAdapter)
529606
Execute: execute,
530607
}
531608
}
609+
// Track the client so Close() can shut it down.
610+
if client := adapter.Client(); client != nil {
611+
a.mu.Lock()
612+
a.mcpClients = append(a.mcpClients, client)
613+
a.mu.Unlock()
614+
}
532615
return a, nil
533616
}
534617

618+
// Close shuts down any MCP server connections owned by this agent,
619+
// cancels any running operation, and waits for it to finish.
620+
// Safe to call multiple times.
621+
func (a *Agent) Close() error {
622+
a.Reset() // cancel + drain pending work
623+
624+
a.mu.Lock()
625+
clients := a.mcpClients
626+
a.mcpClients = nil
627+
a.mu.Unlock()
628+
629+
var errs []error
630+
for _, c := range clients {
631+
if err := c.Close(); err != nil {
632+
errs = append(errs, err)
633+
}
634+
}
635+
if len(errs) > 0 {
636+
msgs := make([]string, len(errs))
637+
for i, e := range errs {
638+
msgs[i] = e.Error()
639+
}
640+
return fmt.Errorf("mcp close errors: %s", strings.Join(msgs, "; "))
641+
}
642+
return nil
643+
}
644+
535645
// WithOpenApiFile loads an OpenAPI spec from a JSON file and registers its operations as tools.
536646
func (a *Agent) WithOpenApiFile(path string, cfg openapi.Config) (*Agent, error) {
537647
adapter, err := openapi.FromFile(path, cfg)
@@ -713,12 +823,34 @@ func (a *Agent) PromptMessages(ctx context.Context, messages []Message) chan Eve
713823
// Context compaction check before provider call.
714824
fullMessages = a.maybeCompact(fullMessages, emitFn)
715825

716-
response, err := a.provider.Complete(loopCtx, fullMessages, opts)
717-
if err != nil {
718-
emitFn(Event{Type: string(EventError), Content: err.Error(), IsError: true})
826+
if a.hooks.BeforeTurn != nil {
827+
a.hooks.BeforeTurn(i+1, fullMessages)
828+
}
829+
830+
var (
831+
response string
832+
turnErr error
833+
)
834+
if ts, ok := a.provider.(TokenStreamer); ok {
835+
response, turnErr = RetryWithResult(loopCtx, DefaultRetryConfig, func() (string, error) {
836+
return ts.CompleteStream(loopCtx, fullMessages, opts, func(token string) {
837+
emitFn(Event{Type: string(EventTokenUpdate), Content: token})
838+
})
839+
})
840+
} else {
841+
response, turnErr = RetryWithResult(loopCtx, DefaultRetryConfig, func() (string, error) {
842+
return a.provider.Complete(loopCtx, fullMessages, opts)
843+
})
844+
}
845+
if turnErr != nil {
846+
emitFn(Event{Type: string(EventError), Content: turnErr.Error(), IsError: true})
719847
break
720848
}
721849

850+
if a.hooks.AfterTurn != nil {
851+
a.hooks.AfterTurn(i+1, response)
852+
}
853+
722854
emitFn(Event{Type: string(EventMessageUpdate), Content: response})
723855

724856
fullMessages = append(fullMessages, Message{Role: "assistant", Content: response})

0 commit comments

Comments
 (0)