Skip to content

Commit 9b52b4e

Browse files
author
NGUYEN Duc Trung
committed
feat: support custom tools in subagents
1 parent dd42d42 commit 9b52b4e

File tree

7 files changed

+1102
-22
lines changed

7 files changed

+1102
-22
lines changed

go/README.md

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,58 @@ safeLookup := copilot.DefineTool("safe_lookup", "A read-only lookup that needs n
368368
safeLookup.SkipPermission = true
369369
```
370370

371+
### Custom Tools with Subagents
372+
373+
When a session is configured with both custom tools and custom agents (subagents), the
374+
subagents can invoke the parent session's custom tools. The SDK automatically routes
375+
tool calls from child sessions back to the parent session's tool handlers.
376+
377+
#### Tool Access Control
378+
379+
The `Tools` field on `CustomAgentConfig` controls which custom tools each subagent can access:
380+
381+
| `Tools` value | Behavior |
382+
|---------------|----------|
383+
| `nil` (default) | Subagent can access **all** custom tools registered on the parent session |
384+
| `[]string{}` (empty) | Subagent cannot access **any** custom tools |
385+
| `[]string{"tool_a", "tool_b"}` | Subagent can only access the listed tools |
386+
387+
#### Example
388+
389+
```go
390+
session, err := client.CreateSession(ctx, &copilot.SessionConfig{
391+
Tools: []copilot.Tool{
392+
copilot.DefineTool("save_output", "Saves output to storage",
393+
func(params SaveParams, inv copilot.ToolInvocation) (string, error) {
394+
// Handle tool call — works for both direct and subagent invocations
395+
return saveToStorage(params.Content)
396+
}),
397+
copilot.DefineTool("get_data", "Retrieves data from storage",
398+
func(params GetParams, inv copilot.ToolInvocation) (string, error) {
399+
return getData(params.Key)
400+
}),
401+
},
402+
CustomAgents: []copilot.CustomAgentConfig{
403+
{
404+
Name: "researcher",
405+
Description: "Researches topics and saves findings",
406+
Tools: []string{"save_output"}, // Can only use save_output, not get_data
407+
Prompt: "You are a research assistant. Save your findings using save_output.",
408+
},
409+
{
410+
Name: "analyst",
411+
Description: "Analyzes data from storage",
412+
Tools: nil, // Can access ALL custom tools
413+
Prompt: "You are a data analyst.",
414+
},
415+
},
416+
})
417+
```
418+
419+
When `researcher` is invoked as a subagent, it can call `save_output` but not `get_data`.
420+
When `analyst` is invoked, it can call both tools. If a subagent attempts to use a tool
421+
not in its allowlist, the SDK returns a `"Tool '{name}' is not supported by this client instance."` response to the LLM.
422+
371423
## Streaming
372424

373425
Enable streaming to receive assistant response chunks as they're generated:

go/client.go

Lines changed: 208 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ import (
5353

5454
const noResultPermissionV2Error = "permission handlers cannot return 'no-result' when connected to a protocol v2 server"
5555

56+
// subagentInstance represents a single active subagent launch.
57+
type subagentInstance struct {
58+
agentName string
59+
toolCallID string
60+
childSessionID string // empty until child session ID is known
61+
startedAt time.Time
62+
}
63+
5664
// Client manages the connection to the Copilot CLI server and provides session management.
5765
//
5866
// The Client can either spawn a CLI server process or connect to an existing server.
@@ -81,6 +89,22 @@ type Client struct {
8189
state ConnectionState
8290
sessions map[string]*Session
8391
sessionsMux sync.Mutex
92+
93+
// childToParent maps childSessionID → parentSessionID.
94+
// Populated exclusively from authoritative protocol signals.
95+
// Protected by sessionsMux.
96+
childToParent map[string]string
97+
98+
// childToAgent maps childSessionID → agentName.
99+
// Used for allowlist enforcement. Populated alongside childToParent.
100+
// Protected by sessionsMux.
101+
childToAgent map[string]string
102+
103+
// subagentInstances tracks active subagent launches per parent session.
104+
// Key: parentSessionID → map of toolCallID → subagentInstance.
105+
// Protected by sessionsMux.
106+
subagentInstances map[string]map[string]*subagentInstance
107+
84108
isExternalServer bool
85109
conn net.Conn // stores net.Conn for external TCP connections
86110
useStdio bool // resolved value from options
@@ -129,8 +153,11 @@ func NewClient(options *ClientOptions) *Client {
129153
client := &Client{
130154
options: opts,
131155
state: StateDisconnected,
132-
sessions: make(map[string]*Session),
133-
actualHost: "localhost",
156+
sessions: make(map[string]*Session),
157+
childToParent: make(map[string]string),
158+
childToAgent: make(map[string]string),
159+
subagentInstances: make(map[string]map[string]*subagentInstance),
160+
actualHost: "localhost",
134161
isExternalServer: false,
135162
useStdio: true,
136163
autoStart: true, // default
@@ -346,6 +373,9 @@ func (c *Client) Stop() error {
346373

347374
c.sessionsMux.Lock()
348375
c.sessions = make(map[string]*Session)
376+
c.childToParent = make(map[string]string)
377+
c.childToAgent = make(map[string]string)
378+
c.subagentInstances = make(map[string]map[string]*subagentInstance)
349379
c.sessionsMux.Unlock()
350380

351381
c.startStopMux.Lock()
@@ -597,6 +627,12 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses
597627
// events emitted by the CLI (e.g. session.start) are not dropped.
598628
session := newSession(sessionID, c.client, "")
599629

630+
session.customAgents = config.CustomAgents
631+
session.onDestroy = func() {
632+
c.sessionsMux.Lock()
633+
c.removeChildMappingsForParentLocked(session.SessionID)
634+
c.sessionsMux.Unlock()
635+
}
600636
session.registerTools(config.Tools)
601637
session.registerPermissionHandler(config.OnPermissionRequest)
602638
if config.OnUserInputRequest != nil {
@@ -736,6 +772,12 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string,
736772
// events emitted by the CLI (e.g. session.start) are not dropped.
737773
session := newSession(sessionID, c.client, "")
738774

775+
session.customAgents = config.CustomAgents
776+
session.onDestroy = func() {
777+
c.sessionsMux.Lock()
778+
c.removeChildMappingsForParentLocked(session.SessionID)
779+
c.sessionsMux.Unlock()
780+
}
739781
session.registerTools(config.Tools)
740782
session.registerPermissionHandler(config.OnPermissionRequest)
741783
if config.OnUserInputRequest != nil {
@@ -896,6 +938,7 @@ func (c *Client) DeleteSession(ctx context.Context, sessionID string) error {
896938
// Remove from local sessions map if present
897939
c.sessionsMux.Lock()
898940
delete(c.sessions, sessionID)
941+
c.removeChildMappingsForParentLocked(sessionID)
899942
c.sessionsMux.Unlock()
900943

901944
return nil
@@ -1536,21 +1579,160 @@ func (c *Client) handleSessionEvent(req sessionEventRequest) {
15361579
c.sessionsMux.Unlock()
15371580

15381581
if ok {
1582+
// Intercept subagent lifecycle events for child tracking
1583+
c.handleSubagentEvent(req.SessionID, req.Event)
15391584
session.dispatchEvent(req.Event)
15401585
}
15411586
}
15421587

1588+
// handleSubagentEvent intercepts subagent lifecycle events to manage child session tracking.
1589+
func (c *Client) handleSubagentEvent(parentSessionID string, event SessionEvent) {
1590+
switch event.Type {
1591+
case SessionEventTypeSubagentStarted:
1592+
c.onSubagentStarted(parentSessionID, event)
1593+
case SessionEventTypeSubagentCompleted, SessionEventTypeSubagentFailed:
1594+
c.onSubagentEnded(parentSessionID, event)
1595+
}
1596+
}
1597+
1598+
// onSubagentStarted handles a subagent.started event by creating a subagent instance
1599+
// and mapping the child session to its parent.
1600+
func (c *Client) onSubagentStarted(parentSessionID string, event SessionEvent) {
1601+
toolCallID := derefStr(event.Data.ToolCallID)
1602+
agentName := derefStr(event.Data.AgentName)
1603+
childSessionID := derefStr(event.Data.RemoteSessionID)
1604+
1605+
c.sessionsMux.Lock()
1606+
defer c.sessionsMux.Unlock()
1607+
1608+
// Track instance by toolCallID (unique per launch)
1609+
if c.subagentInstances[parentSessionID] == nil {
1610+
c.subagentInstances[parentSessionID] = make(map[string]*subagentInstance)
1611+
}
1612+
c.subagentInstances[parentSessionID][toolCallID] = &subagentInstance{
1613+
agentName: agentName,
1614+
toolCallID: toolCallID,
1615+
childSessionID: childSessionID,
1616+
startedAt: event.Timestamp,
1617+
}
1618+
1619+
// Eagerly map child→parent and child→agent
1620+
if childSessionID != "" {
1621+
c.childToParent[childSessionID] = parentSessionID
1622+
c.childToAgent[childSessionID] = agentName
1623+
}
1624+
}
1625+
1626+
// onSubagentEnded handles subagent.completed and subagent.failed events
1627+
// by removing the subagent instance. Child-to-parent mappings are NOT removed
1628+
// here because in-flight requests may still arrive after the subagent completes.
1629+
func (c *Client) onSubagentEnded(parentSessionID string, event SessionEvent) {
1630+
toolCallID := derefStr(event.Data.ToolCallID)
1631+
1632+
c.sessionsMux.Lock()
1633+
defer c.sessionsMux.Unlock()
1634+
1635+
if instances, ok := c.subagentInstances[parentSessionID]; ok {
1636+
delete(instances, toolCallID)
1637+
if len(instances) == 0 {
1638+
delete(c.subagentInstances, parentSessionID)
1639+
}
1640+
}
1641+
}
1642+
1643+
// derefStr safely dereferences a string pointer, returning "" if nil.
1644+
func derefStr(s *string) string {
1645+
if s == nil {
1646+
return ""
1647+
}
1648+
return *s
1649+
}
1650+
1651+
// resolveSession looks up a session by ID. If the ID is not a directly
1652+
// registered session, it checks whether it is a known child session and
1653+
// returns the parent session instead.
1654+
//
1655+
// Returns (session, isChild, error). isChild=true means the request came
1656+
// from a child session and was resolved via parent lineage.
1657+
//
1658+
// Lock contract: acquires and releases sessionsMux internally.
1659+
// Does NOT hold sessionsMux when returning.
1660+
func (c *Client) resolveSession(sessionID string) (*Session, bool, error) {
1661+
c.sessionsMux.Lock()
1662+
// Direct lookup
1663+
if session, ok := c.sessions[sessionID]; ok {
1664+
c.sessionsMux.Unlock()
1665+
return session, false, nil
1666+
}
1667+
// Child→parent lookup (authoritative mapping only)
1668+
parentID, isChild := c.childToParent[sessionID]
1669+
if !isChild {
1670+
c.sessionsMux.Unlock()
1671+
return nil, false, fmt.Errorf("unknown session %s", sessionID)
1672+
}
1673+
session, ok := c.sessions[parentID]
1674+
c.sessionsMux.Unlock()
1675+
if !ok {
1676+
return nil, false, fmt.Errorf("parent session %s for child %s not found", parentID, sessionID)
1677+
}
1678+
return session, true, nil
1679+
}
1680+
1681+
// removeChildMappingsForParentLocked removes all child mappings for a parent session.
1682+
// MUST be called with sessionsMux held.
1683+
func (c *Client) removeChildMappingsForParentLocked(parentSessionID string) {
1684+
for childID, parentID := range c.childToParent {
1685+
if parentID == parentSessionID {
1686+
delete(c.childToParent, childID)
1687+
delete(c.childToAgent, childID)
1688+
}
1689+
}
1690+
delete(c.subagentInstances, parentSessionID)
1691+
}
1692+
1693+
// isToolAllowedForChild checks whether a tool is in the allowlist for the agent
1694+
// that owns the given child session.
1695+
func (c *Client) isToolAllowedForChild(childSessionID, toolName string) bool {
1696+
c.sessionsMux.Lock()
1697+
agentName, ok := c.childToAgent[childSessionID]
1698+
c.sessionsMux.Unlock()
1699+
if !ok {
1700+
return false // unknown child → deny
1701+
}
1702+
1703+
session, _, _ := c.resolveSession(childSessionID)
1704+
if session == nil {
1705+
return false
1706+
}
1707+
1708+
agentConfig := session.getAgentConfig(agentName)
1709+
if agentConfig == nil {
1710+
return false // agent not found → deny
1711+
}
1712+
1713+
// nil Tools = all tools allowed
1714+
if agentConfig.Tools == nil {
1715+
return true
1716+
}
1717+
1718+
// Explicit list — check membership
1719+
for _, t := range agentConfig.Tools {
1720+
if t == toolName {
1721+
return true
1722+
}
1723+
}
1724+
return false
1725+
}
1726+
15431727
// handleUserInputRequest handles a user input request from the CLI server.
15441728
func (c *Client) handleUserInputRequest(req userInputRequest) (*userInputResponse, *jsonrpc2.Error) {
15451729
if req.SessionID == "" || req.Question == "" {
15461730
return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid user input request payload"}
15471731
}
15481732

1549-
c.sessionsMux.Lock()
1550-
session, ok := c.sessions[req.SessionID]
1551-
c.sessionsMux.Unlock()
1552-
if !ok {
1553-
return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)}
1733+
session, _, err := c.resolveSession(req.SessionID)
1734+
if err != nil {
1735+
return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()}
15541736
}
15551737

15561738
response, err := session.handleUserInputRequest(UserInputRequest{
@@ -1571,11 +1753,9 @@ func (c *Client) handleHooksInvoke(req hooksInvokeRequest) (map[string]any, *jso
15711753
return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid hooks invoke payload"}
15721754
}
15731755

1574-
c.sessionsMux.Lock()
1575-
session, ok := c.sessions[req.SessionID]
1576-
c.sessionsMux.Unlock()
1577-
if !ok {
1578-
return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)}
1756+
session, _, err := c.resolveSession(req.SessionID)
1757+
if err != nil {
1758+
return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()}
15791759
}
15801760

15811761
output, err := session.handleHooksInvoke(req.Type, req.Input)
@@ -1646,11 +1826,19 @@ func (c *Client) handleToolCallRequestV2(req toolCallRequestV2) (*toolCallRespon
16461826
return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid tool call payload"}
16471827
}
16481828

1649-
c.sessionsMux.Lock()
1650-
session, ok := c.sessions[req.SessionID]
1651-
c.sessionsMux.Unlock()
1652-
if !ok {
1653-
return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)}
1829+
session, isChild, err := c.resolveSession(req.SessionID)
1830+
if err != nil {
1831+
return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()}
1832+
}
1833+
1834+
// For child sessions, enforce tool allowlist
1835+
if isChild && !c.isToolAllowedForChild(req.SessionID, req.ToolName) {
1836+
return &toolCallResponseV2{Result: ToolResult{
1837+
TextResultForLLM: fmt.Sprintf("Tool '%s' is not supported by this client instance.", req.ToolName),
1838+
ResultType: "failure",
1839+
Error: fmt.Sprintf("tool '%s' not supported", req.ToolName),
1840+
ToolTelemetry: map[string]any{},
1841+
}}, nil
16541842
}
16551843

16561844
handler, ok := session.getToolHandler(req.ToolName)
@@ -1692,11 +1880,9 @@ func (c *Client) handlePermissionRequestV2(req permissionRequestV2) (*permission
16921880
return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid permission request payload"}
16931881
}
16941882

1695-
c.sessionsMux.Lock()
1696-
session, ok := c.sessions[req.SessionID]
1697-
c.sessionsMux.Unlock()
1698-
if !ok {
1699-
return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)}
1883+
session, _, err := c.resolveSession(req.SessionID)
1884+
if err != nil {
1885+
return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()}
17001886
}
17011887

17021888
handler := session.getPermissionHandler()

0 commit comments

Comments
 (0)