@@ -37,6 +37,19 @@ type Event struct {
3737 Thinking string
3838}
3939
40+ // AgentHooks provides optional callbacks for lifecycle events in the agent loop.
41+ // All fields are optional; nil functions are silently skipped.
42+ type AgentHooks struct {
43+ // BeforeTurn is called before each provider completion turn. turn is 1-indexed.
44+ BeforeTurn func (turn int , messages []Message )
45+ // AfterTurn is called after each provider completion turn with the response.
46+ AfterTurn func (turn int , response string )
47+ // OnToolStart is called before each tool execution.
48+ OnToolStart func (toolName string , args map [string ]string )
49+ // OnToolEnd is called after each tool execution with the result and any error.
50+ OnToolEnd func (toolName string , result string , err error )
51+ }
52+
4053// Agent is the core reasoning loop.
4154type Agent struct {
4255 provider Provider
@@ -60,6 +73,9 @@ type Agent struct {
6073 // Input filters.
6174 inputFilters []InputFilter
6275
76+ // Lifecycle hooks.
77+ hooks AgentHooks
78+
6379 // Concurrency state — tracks whether a streaming operation is in progress.
6480 mu sync.Mutex
6581 isStreaming bool
@@ -73,6 +89,10 @@ type Agent struct {
7389
7490 // cacheConfig controls prompt caching behaviour.
7591 cacheConfig CacheConfig
92+
93+ // mcpClients holds MCP server connections owned by this agent.
94+ // They are shut down when Close() is called.
95+ mcpClients []* mcp.McpClient
7696}
7797
7898// New creates a new Agent.
@@ -192,16 +212,26 @@ func (a *Agent) executeToolsParallel(ctx context.Context, calls []ToolCall, iter
192212 })
193213 tool , ok := a .tools [c .Tool ]
194214 if ! ok {
195- results [i ] = indexedResult {call : c , result : fmt .Sprintf ("unknown tool: %s" , c .Tool ), isError : true , unknown : true }
215+ res := fmt .Sprintf ("unknown tool: %s" , c .Tool )
216+ if a .hooks .OnToolEnd != nil {
217+ a .hooks .OnToolEnd (c .Tool , res , fmt .Errorf ("unknown tool: %s" , c .Tool ))
218+ }
219+ results [i ] = indexedResult {call : c , result : res , isError : true , unknown : true }
196220 emitFn (Event {Type : string (EventToolExecutionEnd ), ToolName : c .Tool , Result : results [i ].result , IsError : true })
197221 return
198222 }
223+ if a .hooks .OnToolStart != nil {
224+ a .hooks .OnToolStart (c .Tool , c .Args )
225+ }
199226 res , err := tool .Execute (ctx , c .Args )
200227 isErr := false
201228 if err != nil {
202229 res = fmt .Sprintf ("ERROR: %s\n Output: %s" , err .Error (), res )
203230 isErr = true
204231 }
232+ if a .hooks .OnToolEnd != nil {
233+ a .hooks .OnToolEnd (c .Tool , res , err )
234+ }
205235 results [i ] = indexedResult {call : c , result : res , isError : isErr }
206236 emitFn (Event {Type : string (EventToolExecutionEnd ), ToolName : c .Tool , Result : res , IsError : isErr })
207237 }(idx , call )
@@ -258,13 +288,21 @@ func (a *Agent) executeSingleTool(ctx context.Context, call ToolCall, iteration
258288 })
259289 a .logger .Info ("executing tool" , "tool" , call .Tool , "args" , call .Args )
260290
291+ if a .hooks .OnToolStart != nil {
292+ a .hooks .OnToolStart (call .Tool , call .Args )
293+ }
294+
261295 result , err := tool .Execute (ctx , call .Args )
262296 isError := false
263297 if err != nil {
264298 result = fmt .Sprintf ("ERROR: %s\n Output: %s" , err .Error (), result )
265299 isError = true
266300 }
267301
302+ if a .hooks .OnToolEnd != nil {
303+ a .hooks .OnToolEnd (call .Tool , result , err )
304+ }
305+
268306 emitFn (Event {
269307 Type : string (EventToolExecutionEnd ),
270308 ToolName : call .Tool ,
@@ -309,12 +347,34 @@ func (a *Agent) Run(ctx context.Context, systemPrompt, userMessage string, emitF
309347 // Context compaction check before provider call.
310348 messages = a .maybeCompact (messages , emit )
311349
312- response , err := a .provider .Complete (ctx , messages , opts )
350+ if a .hooks .BeforeTurn != nil {
351+ a .hooks .BeforeTurn (i + 1 , messages )
352+ }
353+
354+ var (
355+ response string
356+ err error
357+ )
358+ if ts , ok := a .provider .(TokenStreamer ); ok {
359+ response , err = RetryWithResult (ctx , DefaultRetryConfig , func () (string , error ) {
360+ return ts .CompleteStream (ctx , messages , opts , func (token string ) {
361+ emit (Event {Type : string (EventTokenUpdate ), Content : token })
362+ })
363+ })
364+ } else {
365+ response , err = RetryWithResult (ctx , DefaultRetryConfig , func () (string , error ) {
366+ return a .provider .Complete (ctx , messages , opts )
367+ })
368+ }
313369 if err != nil {
314370 emit (Event {Type : string (EventError ), Content : err .Error (), IsError : true })
315371 return "" , fmt .Errorf ("provider error at step %d: %w" , i + 1 , err )
316372 }
317373
374+ if a .hooks .AfterTurn != nil {
375+ a .hooks .AfterTurn (i + 1 , response )
376+ }
377+
318378 emit (Event {Type : string (EventMessageUpdate ), Content : response })
319379
320380 messages = append (messages , Message {
@@ -489,12 +549,29 @@ func (a *Agent) WithInputFilter(f InputFilter) *Agent {
489549 return a
490550}
491551
552+ // WithHooks sets lifecycle hook callbacks on the agent.
553+ func (a * Agent ) WithHooks (h AgentHooks ) * Agent {
554+ a .hooks = h
555+ return a
556+ }
557+
492558// WithCacheConfig enables prompt caching with the given configuration.
493559func (a * Agent ) WithCacheConfig (cfg CacheConfig ) * Agent {
494560 a .cacheConfig = cfg
495561 return a
496562}
497563
564+ // WithCacheEnabled is a convenience builder that enables or disables prompt
565+ // caching using the DefaultCacheConfig when enabled is true.
566+ func (a * Agent ) WithCacheEnabled (enabled bool ) * Agent {
567+ if enabled {
568+ a .cacheConfig = DefaultCacheConfig ()
569+ } else {
570+ a .cacheConfig = CacheConfig {}
571+ }
572+ return a
573+ }
574+
498575// WithMcpServerStdio connects to an MCP server via stdio (spawns a child process),
499576// performs the initialize handshake, and registers all advertised tools.
500577// Returns an error if the server fails to start or initialize.
@@ -529,9 +606,42 @@ func (a *Agent) registerMcpTools(ctx context.Context, adapter *mcp.ToolAdapter)
529606 Execute : execute ,
530607 }
531608 }
609+ // Track the client so Close() can shut it down.
610+ if client := adapter .Client (); client != nil {
611+ a .mu .Lock ()
612+ a .mcpClients = append (a .mcpClients , client )
613+ a .mu .Unlock ()
614+ }
532615 return a , nil
533616}
534617
618+ // Close shuts down any MCP server connections owned by this agent,
619+ // cancels any running operation, and waits for it to finish.
620+ // Safe to call multiple times.
621+ func (a * Agent ) Close () error {
622+ a .Reset () // cancel + drain pending work
623+
624+ a .mu .Lock ()
625+ clients := a .mcpClients
626+ a .mcpClients = nil
627+ a .mu .Unlock ()
628+
629+ var errs []error
630+ for _ , c := range clients {
631+ if err := c .Close (); err != nil {
632+ errs = append (errs , err )
633+ }
634+ }
635+ if len (errs ) > 0 {
636+ msgs := make ([]string , len (errs ))
637+ for i , e := range errs {
638+ msgs [i ] = e .Error ()
639+ }
640+ return fmt .Errorf ("mcp close errors: %s" , strings .Join (msgs , "; " ))
641+ }
642+ return nil
643+ }
644+
535645// WithOpenApiFile loads an OpenAPI spec from a JSON file and registers its operations as tools.
536646func (a * Agent ) WithOpenApiFile (path string , cfg openapi.Config ) (* Agent , error ) {
537647 adapter , err := openapi .FromFile (path , cfg )
@@ -713,12 +823,34 @@ func (a *Agent) PromptMessages(ctx context.Context, messages []Message) chan Eve
713823 // Context compaction check before provider call.
714824 fullMessages = a .maybeCompact (fullMessages , emitFn )
715825
716- response , err := a .provider .Complete (loopCtx , fullMessages , opts )
717- if err != nil {
718- emitFn (Event {Type : string (EventError ), Content : err .Error (), IsError : true })
826+ if a .hooks .BeforeTurn != nil {
827+ a .hooks .BeforeTurn (i + 1 , fullMessages )
828+ }
829+
830+ var (
831+ response string
832+ turnErr error
833+ )
834+ if ts , ok := a .provider .(TokenStreamer ); ok {
835+ response , turnErr = RetryWithResult (loopCtx , DefaultRetryConfig , func () (string , error ) {
836+ return ts .CompleteStream (loopCtx , fullMessages , opts , func (token string ) {
837+ emitFn (Event {Type : string (EventTokenUpdate ), Content : token })
838+ })
839+ })
840+ } else {
841+ response , turnErr = RetryWithResult (loopCtx , DefaultRetryConfig , func () (string , error ) {
842+ return a .provider .Complete (loopCtx , fullMessages , opts )
843+ })
844+ }
845+ if turnErr != nil {
846+ emitFn (Event {Type : string (EventError ), Content : turnErr .Error (), IsError : true })
719847 break
720848 }
721849
850+ if a .hooks .AfterTurn != nil {
851+ a .hooks .AfterTurn (i + 1 , response )
852+ }
853+
722854 emitFn (Event {Type : string (EventMessageUpdate ), Content : response })
723855
724856 fullMessages = append (fullMessages , Message {Role : "assistant" , Content : response })
0 commit comments