Skip to content

Commit c826cfa

Browse files
author
NGUYEN Duc Trung
committed
feat: support custom tools in subagents
1 parent e3638da commit c826cfa

File tree

7 files changed

+1102
-22
lines changed

7 files changed

+1102
-22
lines changed

go/README.md

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,58 @@ safeLookup := copilot.DefineTool("safe_lookup", "A read-only lookup that needs n
359359
safeLookup.SkipPermission = true
360360
```
361361

362+
### Custom Tools with Subagents
363+
364+
When a session is configured with both custom tools and custom agents (subagents), the
365+
subagents can invoke the parent session's custom tools. The SDK automatically routes
366+
tool calls from child sessions back to the parent session's tool handlers.
367+
368+
#### Tool Access Control
369+
370+
The `Tools` field on `CustomAgentConfig` controls which custom tools each subagent can access:
371+
372+
| `Tools` value | Behavior |
373+
|---------------|----------|
374+
| `nil` (default) | Subagent can access **all** custom tools registered on the parent session |
375+
| `[]string{}` (empty) | Subagent cannot access **any** custom tools |
376+
| `[]string{"tool_a", "tool_b"}` | Subagent can only access the listed tools |
377+
378+
#### Example
379+
380+
```go
381+
session, err := client.CreateSession(ctx, &copilot.SessionConfig{
382+
Tools: []copilot.Tool{
383+
copilot.DefineTool("save_output", "Saves output to storage",
384+
func(params SaveParams, inv copilot.ToolInvocation) (string, error) {
385+
// Handle tool call — works for both direct and subagent invocations
386+
return saveToStorage(params.Content)
387+
}),
388+
copilot.DefineTool("get_data", "Retrieves data from storage",
389+
func(params GetParams, inv copilot.ToolInvocation) (string, error) {
390+
return getData(params.Key)
391+
}),
392+
},
393+
CustomAgents: []copilot.CustomAgentConfig{
394+
{
395+
Name: "researcher",
396+
Description: "Researches topics and saves findings",
397+
Tools: []string{"save_output"}, // Can only use save_output, not get_data
398+
Prompt: "You are a research assistant. Save your findings using save_output.",
399+
},
400+
{
401+
Name: "analyst",
402+
Description: "Analyzes data from storage",
403+
Tools: nil, // Can access ALL custom tools
404+
Prompt: "You are a data analyst.",
405+
},
406+
},
407+
})
408+
```
409+
410+
When `researcher` is invoked as a subagent, it can call `save_output` but not `get_data`.
411+
When `analyst` is invoked, it can call both tools. If a subagent attempts to use a tool
412+
not in its allowlist, the SDK returns a `"Tool '{name}' is not supported by this client instance."` response to the LLM.
413+
362414
## Streaming
363415

364416
Enable streaming to receive assistant response chunks as they're generated:

go/client.go

Lines changed: 208 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ import (
5353

5454
const noResultPermissionV2Error = "permission handlers cannot return 'no-result' when connected to a protocol v2 server"
5555

56+
// subagentInstance represents a single active subagent launch.
57+
type subagentInstance struct {
58+
agentName string
59+
toolCallID string
60+
childSessionID string // empty until child session ID is known
61+
startedAt time.Time
62+
}
63+
5664
// Client manages the connection to the Copilot CLI server and provides session management.
5765
//
5866
// The Client can either spawn a CLI server process or connect to an existing server.
@@ -81,6 +89,22 @@ type Client struct {
8189
state ConnectionState
8290
sessions map[string]*Session
8391
sessionsMux sync.Mutex
92+
93+
// childToParent maps childSessionID → parentSessionID.
94+
// Populated exclusively from authoritative protocol signals.
95+
// Protected by sessionsMux.
96+
childToParent map[string]string
97+
98+
// childToAgent maps childSessionID → agentName.
99+
// Used for allowlist enforcement. Populated alongside childToParent.
100+
// Protected by sessionsMux.
101+
childToAgent map[string]string
102+
103+
// subagentInstances tracks active subagent launches per parent session.
104+
// Key: parentSessionID → map of toolCallID → subagentInstance.
105+
// Protected by sessionsMux.
106+
subagentInstances map[string]map[string]*subagentInstance
107+
84108
isExternalServer bool
85109
conn net.Conn // stores net.Conn for external TCP connections
86110
useStdio bool // resolved value from options
@@ -129,8 +153,11 @@ func NewClient(options *ClientOptions) *Client {
129153
client := &Client{
130154
options: opts,
131155
state: StateDisconnected,
132-
sessions: make(map[string]*Session),
133-
actualHost: "localhost",
156+
sessions: make(map[string]*Session),
157+
childToParent: make(map[string]string),
158+
childToAgent: make(map[string]string),
159+
subagentInstances: make(map[string]map[string]*subagentInstance),
160+
actualHost: "localhost",
134161
isExternalServer: false,
135162
useStdio: true,
136163
autoStart: true, // default
@@ -346,6 +373,9 @@ func (c *Client) Stop() error {
346373

347374
c.sessionsMux.Lock()
348375
c.sessions = make(map[string]*Session)
376+
c.childToParent = make(map[string]string)
377+
c.childToAgent = make(map[string]string)
378+
c.subagentInstances = make(map[string]map[string]*subagentInstance)
349379
c.sessionsMux.Unlock()
350380

351381
c.startStopMux.Lock()
@@ -586,6 +616,12 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses
586616
// events emitted by the CLI (e.g. session.start) are not dropped.
587617
session := newSession(sessionID, c.client, "")
588618

619+
session.customAgents = config.CustomAgents
620+
session.onDestroy = func() {
621+
c.sessionsMux.Lock()
622+
c.removeChildMappingsForParentLocked(session.SessionID)
623+
c.sessionsMux.Unlock()
624+
}
589625
session.registerTools(config.Tools)
590626
session.registerPermissionHandler(config.OnPermissionRequest)
591627
if config.OnUserInputRequest != nil {
@@ -707,6 +743,12 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string,
707743
// events emitted by the CLI (e.g. session.start) are not dropped.
708744
session := newSession(sessionID, c.client, "")
709745

746+
session.customAgents = config.CustomAgents
747+
session.onDestroy = func() {
748+
c.sessionsMux.Lock()
749+
c.removeChildMappingsForParentLocked(session.SessionID)
750+
c.sessionsMux.Unlock()
751+
}
710752
session.registerTools(config.Tools)
711753
session.registerPermissionHandler(config.OnPermissionRequest)
712754
if config.OnUserInputRequest != nil {
@@ -860,6 +902,7 @@ func (c *Client) DeleteSession(ctx context.Context, sessionID string) error {
860902
// Remove from local sessions map if present
861903
c.sessionsMux.Lock()
862904
delete(c.sessions, sessionID)
905+
c.removeChildMappingsForParentLocked(sessionID)
863906
c.sessionsMux.Unlock()
864907

865908
return nil
@@ -1500,21 +1543,160 @@ func (c *Client) handleSessionEvent(req sessionEventRequest) {
15001543
c.sessionsMux.Unlock()
15011544

15021545
if ok {
1546+
// Intercept subagent lifecycle events for child tracking
1547+
c.handleSubagentEvent(req.SessionID, req.Event)
15031548
session.dispatchEvent(req.Event)
15041549
}
15051550
}
15061551

1552+
// handleSubagentEvent intercepts subagent lifecycle events to manage child session tracking.
1553+
func (c *Client) handleSubagentEvent(parentSessionID string, event SessionEvent) {
1554+
switch event.Type {
1555+
case SessionEventTypeSubagentStarted:
1556+
c.onSubagentStarted(parentSessionID, event)
1557+
case SessionEventTypeSubagentCompleted, SessionEventTypeSubagentFailed:
1558+
c.onSubagentEnded(parentSessionID, event)
1559+
}
1560+
}
1561+
1562+
// onSubagentStarted handles a subagent.started event by creating a subagent instance
1563+
// and mapping the child session to its parent.
1564+
func (c *Client) onSubagentStarted(parentSessionID string, event SessionEvent) {
1565+
toolCallID := derefStr(event.Data.ToolCallID)
1566+
agentName := derefStr(event.Data.AgentName)
1567+
childSessionID := derefStr(event.Data.RemoteSessionID)
1568+
1569+
c.sessionsMux.Lock()
1570+
defer c.sessionsMux.Unlock()
1571+
1572+
// Track instance by toolCallID (unique per launch)
1573+
if c.subagentInstances[parentSessionID] == nil {
1574+
c.subagentInstances[parentSessionID] = make(map[string]*subagentInstance)
1575+
}
1576+
c.subagentInstances[parentSessionID][toolCallID] = &subagentInstance{
1577+
agentName: agentName,
1578+
toolCallID: toolCallID,
1579+
childSessionID: childSessionID,
1580+
startedAt: event.Timestamp,
1581+
}
1582+
1583+
// Eagerly map child→parent and child→agent
1584+
if childSessionID != "" {
1585+
c.childToParent[childSessionID] = parentSessionID
1586+
c.childToAgent[childSessionID] = agentName
1587+
}
1588+
}
1589+
1590+
// onSubagentEnded handles subagent.completed and subagent.failed events
1591+
// by removing the subagent instance. Child-to-parent mappings are NOT removed
1592+
// here because in-flight requests may still arrive after the subagent completes.
1593+
func (c *Client) onSubagentEnded(parentSessionID string, event SessionEvent) {
1594+
toolCallID := derefStr(event.Data.ToolCallID)
1595+
1596+
c.sessionsMux.Lock()
1597+
defer c.sessionsMux.Unlock()
1598+
1599+
if instances, ok := c.subagentInstances[parentSessionID]; ok {
1600+
delete(instances, toolCallID)
1601+
if len(instances) == 0 {
1602+
delete(c.subagentInstances, parentSessionID)
1603+
}
1604+
}
1605+
}
1606+
1607+
// derefStr safely dereferences a string pointer, returning "" if nil.
1608+
func derefStr(s *string) string {
1609+
if s == nil {
1610+
return ""
1611+
}
1612+
return *s
1613+
}
1614+
1615+
// resolveSession looks up a session by ID. If the ID is not a directly
1616+
// registered session, it checks whether it is a known child session and
1617+
// returns the parent session instead.
1618+
//
1619+
// Returns (session, isChild, error). isChild=true means the request came
1620+
// from a child session and was resolved via parent lineage.
1621+
//
1622+
// Lock contract: acquires and releases sessionsMux internally.
1623+
// Does NOT hold sessionsMux when returning.
1624+
func (c *Client) resolveSession(sessionID string) (*Session, bool, error) {
1625+
c.sessionsMux.Lock()
1626+
// Direct lookup
1627+
if session, ok := c.sessions[sessionID]; ok {
1628+
c.sessionsMux.Unlock()
1629+
return session, false, nil
1630+
}
1631+
// Child→parent lookup (authoritative mapping only)
1632+
parentID, isChild := c.childToParent[sessionID]
1633+
if !isChild {
1634+
c.sessionsMux.Unlock()
1635+
return nil, false, fmt.Errorf("unknown session %s", sessionID)
1636+
}
1637+
session, ok := c.sessions[parentID]
1638+
c.sessionsMux.Unlock()
1639+
if !ok {
1640+
return nil, false, fmt.Errorf("parent session %s for child %s not found", parentID, sessionID)
1641+
}
1642+
return session, true, nil
1643+
}
1644+
1645+
// removeChildMappingsForParentLocked removes all child mappings for a parent session.
1646+
// MUST be called with sessionsMux held.
1647+
func (c *Client) removeChildMappingsForParentLocked(parentSessionID string) {
1648+
for childID, parentID := range c.childToParent {
1649+
if parentID == parentSessionID {
1650+
delete(c.childToParent, childID)
1651+
delete(c.childToAgent, childID)
1652+
}
1653+
}
1654+
delete(c.subagentInstances, parentSessionID)
1655+
}
1656+
1657+
// isToolAllowedForChild checks whether a tool is in the allowlist for the agent
1658+
// that owns the given child session.
1659+
func (c *Client) isToolAllowedForChild(childSessionID, toolName string) bool {
1660+
c.sessionsMux.Lock()
1661+
agentName, ok := c.childToAgent[childSessionID]
1662+
c.sessionsMux.Unlock()
1663+
if !ok {
1664+
return false // unknown child → deny
1665+
}
1666+
1667+
session, _, _ := c.resolveSession(childSessionID)
1668+
if session == nil {
1669+
return false
1670+
}
1671+
1672+
agentConfig := session.getAgentConfig(agentName)
1673+
if agentConfig == nil {
1674+
return false // agent not found → deny
1675+
}
1676+
1677+
// nil Tools = all tools allowed
1678+
if agentConfig.Tools == nil {
1679+
return true
1680+
}
1681+
1682+
// Explicit list — check membership
1683+
for _, t := range agentConfig.Tools {
1684+
if t == toolName {
1685+
return true
1686+
}
1687+
}
1688+
return false
1689+
}
1690+
15071691
// handleUserInputRequest handles a user input request from the CLI server.
15081692
func (c *Client) handleUserInputRequest(req userInputRequest) (*userInputResponse, *jsonrpc2.Error) {
15091693
if req.SessionID == "" || req.Question == "" {
15101694
return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid user input request payload"}
15111695
}
15121696

1513-
c.sessionsMux.Lock()
1514-
session, ok := c.sessions[req.SessionID]
1515-
c.sessionsMux.Unlock()
1516-
if !ok {
1517-
return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)}
1697+
session, _, err := c.resolveSession(req.SessionID)
1698+
if err != nil {
1699+
return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()}
15181700
}
15191701

15201702
response, err := session.handleUserInputRequest(UserInputRequest{
@@ -1535,11 +1717,9 @@ func (c *Client) handleHooksInvoke(req hooksInvokeRequest) (map[string]any, *jso
15351717
return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid hooks invoke payload"}
15361718
}
15371719

1538-
c.sessionsMux.Lock()
1539-
session, ok := c.sessions[req.SessionID]
1540-
c.sessionsMux.Unlock()
1541-
if !ok {
1542-
return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)}
1720+
session, _, err := c.resolveSession(req.SessionID)
1721+
if err != nil {
1722+
return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()}
15431723
}
15441724

15451725
output, err := session.handleHooksInvoke(req.Type, req.Input)
@@ -1610,11 +1790,19 @@ func (c *Client) handleToolCallRequestV2(req toolCallRequestV2) (*toolCallRespon
16101790
return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid tool call payload"}
16111791
}
16121792

1613-
c.sessionsMux.Lock()
1614-
session, ok := c.sessions[req.SessionID]
1615-
c.sessionsMux.Unlock()
1616-
if !ok {
1617-
return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)}
1793+
session, isChild, err := c.resolveSession(req.SessionID)
1794+
if err != nil {
1795+
return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()}
1796+
}
1797+
1798+
// For child sessions, enforce tool allowlist
1799+
if isChild && !c.isToolAllowedForChild(req.SessionID, req.ToolName) {
1800+
return &toolCallResponseV2{Result: ToolResult{
1801+
TextResultForLLM: fmt.Sprintf("Tool '%s' is not supported by this client instance.", req.ToolName),
1802+
ResultType: "failure",
1803+
Error: fmt.Sprintf("tool '%s' not supported", req.ToolName),
1804+
ToolTelemetry: map[string]any{},
1805+
}}, nil
16181806
}
16191807

16201808
handler, ok := session.getToolHandler(req.ToolName)
@@ -1656,11 +1844,9 @@ func (c *Client) handlePermissionRequestV2(req permissionRequestV2) (*permission
16561844
return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid permission request payload"}
16571845
}
16581846

1659-
c.sessionsMux.Lock()
1660-
session, ok := c.sessions[req.SessionID]
1661-
c.sessionsMux.Unlock()
1662-
if !ok {
1663-
return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)}
1847+
session, _, err := c.resolveSession(req.SessionID)
1848+
if err != nil {
1849+
return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()}
16641850
}
16651851

16661852
handler := session.getPermissionHandler()

0 commit comments

Comments
 (0)