Skip to content

Commit 538e341

Browse files
committed
feat(plugin, api): prevent plugin recursion on host model callbacks, enable targeted interceptor skipping
- Updated host model callback logic to skip originating plugin's interceptors during nested model executions. - Added `SkipInterceptorPluginID` field to plugin API structs for controlling interceptor bypass behavior. - Introduced supporting logic in host API handlers, plugin host registry, and callback contexts to identify and skip specific plugins. - Enhanced unit tests across plugin host, API handlers, and execution paths to verify interceptor skipping behavior and plugin isolation. - Revised documentation to clarify non-recursive behavior of host model callbacks and the use of `SkipInterceptorPluginID`.
1 parent 8e39db2 commit 538e341

20 files changed

Lines changed: 472 additions & 83 deletions

examples/plugin/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ plugins:
4343
4444
`host-model-callback` declares the Management API capability and exposes a browser resource named `Host Model Callback`. The resource calls `host.model.execute` for non-streaming requests and `host.model.execute_stream` plus `host.model.stream_read` for streaming requests. It demonstrates explicit stream close with `host.model.stream_close` and an `implicit_close=true` option for RPC-scope host cleanup.
4545

46+
When the resource forwards its `host_callback_id`, CPA identifies the plugin that initiated the host model callback and skips that same plugin's interceptors for the nested execution. This makes host model callbacks non-recursive for the caller while allowing other plugins to intercept the nested request.
47+
4648
```yaml
4749
plugins:
4850
configs:

examples/plugin/README_CN.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ plugins:
4343
4444
`host-model-callback` 声明 Management API 能力,并暴露名为 `Host Model Callback` 的浏览器资源。该资源在非流式请求中调用 `host.model.execute`,在流式请求中调用 `host.model.execute_stream` 和 `host.model.stream_read`。它演示了通过 `host.model.stream_close` 显式关闭流,也提供 `implicit_close=true` 用于演示 RPC 作用域结束时的宿主隐式清理。
4545

46+
当该资源转发自身收到的 `host_callback_id` 时,CPA 会识别发起宿主模型回调的插件,并在嵌套模型执行中跳过同一个插件的拦截器。因此宿主模型回调不会递归调用发起插件自身,但其他已启用插件仍可拦截这次嵌套请求。
47+
4648
```yaml
4749
plugins:
4850
configs:

examples/plugin/host-model-callback/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,12 @@ By default, streaming mode explicitly closes the host-owned stream with `host.mo
115115

116116
When `implicit_close=true` is set, the plugin intentionally skips the explicit close call. CPA injects `host_callback_id` into the `management.handle` request, and this example forwards that callback ID to `host.model.execute_stream` so the host can close the stream when the `management.handle` RPC callback scope returns. This mode exists only to demonstrate host cleanup behavior; normal plugin code should explicitly close streams it opens.
117117

118+
## Recursion Guard
119+
120+
This example forwards the `host_callback_id` received from `management.handle` when it calls `host.model.execute` or `host.model.execute_stream`. CPA uses that callback scope to identify the plugin that initiated the host model callback and skips that same plugin's request, response, and stream interceptors for the nested model execution.
121+
122+
Host model callbacks are therefore not recursive for the caller. Other enabled plugins can still intercept the nested request.
123+
118124
## Billing and Usage
119125

120126
The callback uses the existing CPA model executor path. Usage collection, request accounting, and billing metadata are handled by the same executor and usage reporter path as normal proxied requests. The callback layer does not bill twice and does not create an additional usage record by itself.

examples/plugin/host-model-callback/go/main.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,9 @@ func executeOnce(opts runOptions) (pluginapi.HostModelExecutionResponse, error)
427427
if errBody != nil {
428428
return pluginapi.HostModelExecutionResponse{}, errBody
429429
}
430+
// Forward HostCallbackID so the host skips this plugin's interceptors on the
431+
// nested model execution. Host model callbacks do not recursively call the
432+
// originating plugin's interceptor chain.
430433
result, errCall := callHost(pluginabi.MethodHostModelExecute, hostModelExecutionRequest{
431434
HostModelExecutionRequest: pluginapi.HostModelExecutionRequest{
432435
EntryProtocol: opts.EntryProtocol,
@@ -456,6 +459,9 @@ func executeStream(opts runOptions) (data streamPageData) {
456459
data.Error = errBody.Error()
457460
return data
458461
}
462+
// Forward HostCallbackID so the host skips this plugin's interceptors on the
463+
// nested model execution. Host model callbacks do not recursively call the
464+
// originating plugin's interceptor chain.
459465
result, errCall := callHost(pluginabi.MethodHostModelExecuteStream, hostModelExecutionRequest{
460466
HostModelExecutionRequest: pluginapi.HostModelExecutionRequest{
461467
EntryProtocol: opts.EntryProtocol,

internal/pluginhost/abi.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@ type pluginClient interface {
1414
}
1515

1616
type pluginLoader interface {
17-
Open(path string, host *Host) (pluginClient, error)
17+
Open(file pluginFile, host *Host) (pluginClient, error)
1818
}

internal/pluginhost/adapters.go

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -569,25 +569,34 @@ func (h *Host) callStreamChunkInterceptor(ctx context.Context, pluginID string,
569569
}
570570

571571
func (h *Host) InterceptRequestBeforeAuth(ctx context.Context, req pluginapi.RequestInterceptRequest) pluginapi.RequestInterceptResponse {
572+
return h.InterceptRequestBeforeAuthExcept(ctx, req, "")
573+
}
574+
575+
func (h *Host) InterceptRequestBeforeAuthExcept(ctx context.Context, req pluginapi.RequestInterceptRequest, skipPluginID string) pluginapi.RequestInterceptResponse {
572576
return h.interceptRequest(ctx, req, "RequestInterceptor.InterceptRequestBeforeAuth", func(interceptor pluginapi.RequestInterceptor, ctx context.Context, req pluginapi.RequestInterceptRequest) (pluginapi.RequestInterceptResponse, error) {
573577
return interceptor.InterceptRequestBeforeAuth(ctx, req)
574-
})
578+
}, skipPluginID)
575579
}
576580

577581
func (h *Host) InterceptRequestAfterAuth(ctx context.Context, req pluginapi.RequestInterceptRequest) pluginapi.RequestInterceptResponse {
582+
return h.InterceptRequestAfterAuthExcept(ctx, req, "")
583+
}
584+
585+
func (h *Host) InterceptRequestAfterAuthExcept(ctx context.Context, req pluginapi.RequestInterceptRequest, skipPluginID string) pluginapi.RequestInterceptResponse {
578586
return h.interceptRequest(ctx, req, "RequestInterceptor.InterceptRequestAfterAuth", func(interceptor pluginapi.RequestInterceptor, ctx context.Context, req pluginapi.RequestInterceptRequest) (pluginapi.RequestInterceptResponse, error) {
579587
return interceptor.InterceptRequestAfterAuth(ctx, req)
580-
})
588+
}, skipPluginID)
581589
}
582590

583-
func (h *Host) interceptRequest(ctx context.Context, req pluginapi.RequestInterceptRequest, method string, invoke func(pluginapi.RequestInterceptor, context.Context, pluginapi.RequestInterceptRequest) (pluginapi.RequestInterceptResponse, error)) pluginapi.RequestInterceptResponse {
591+
func (h *Host) interceptRequest(ctx context.Context, req pluginapi.RequestInterceptRequest, method string, invoke func(pluginapi.RequestInterceptor, context.Context, pluginapi.RequestInterceptRequest) (pluginapi.RequestInterceptResponse, error), skipPluginID string) pluginapi.RequestInterceptResponse {
584592
current := pluginapi.RequestInterceptResponse{
585593
Headers: cloneHeader(req.Headers),
586594
Body: bytes.Clone(req.Body),
587595
}
596+
skipPluginID = strings.TrimSpace(skipPluginID)
588597
for _, record := range h.Snapshot().records {
589598
interceptor := record.plugin.Capabilities.RequestInterceptor
590-
if h.isPluginFused(record.id) || interceptor == nil {
599+
if h.isPluginFused(record.id) || interceptor == nil || record.id == skipPluginID {
591600
continue
592601
}
593602
nextReq := req
@@ -607,13 +616,18 @@ func (h *Host) interceptRequest(ctx context.Context, req pluginapi.RequestInterc
607616
}
608617

609618
func (h *Host) InterceptResponse(ctx context.Context, req pluginapi.ResponseInterceptRequest) pluginapi.ResponseInterceptResponse {
619+
return h.InterceptResponseExcept(ctx, req, "")
620+
}
621+
622+
func (h *Host) InterceptResponseExcept(ctx context.Context, req pluginapi.ResponseInterceptRequest, skipPluginID string) pluginapi.ResponseInterceptResponse {
610623
current := pluginapi.ResponseInterceptResponse{
611624
Headers: cloneHeader(req.ResponseHeaders),
612625
Body: bytes.Clone(req.Body),
613626
}
627+
skipPluginID = strings.TrimSpace(skipPluginID)
614628
for _, record := range h.Snapshot().records {
615629
interceptor := record.plugin.Capabilities.ResponseInterceptor
616-
if h.isPluginFused(record.id) || interceptor == nil {
630+
if h.isPluginFused(record.id) || interceptor == nil || record.id == skipPluginID {
617631
continue
618632
}
619633
nextReq := req
@@ -634,13 +648,18 @@ func (h *Host) InterceptResponse(ctx context.Context, req pluginapi.ResponseInte
634648
}
635649

636650
func (h *Host) InterceptStreamChunk(ctx context.Context, req pluginapi.StreamChunkInterceptRequest) pluginapi.StreamChunkInterceptResponse {
651+
return h.InterceptStreamChunkExcept(ctx, req, "")
652+
}
653+
654+
func (h *Host) InterceptStreamChunkExcept(ctx context.Context, req pluginapi.StreamChunkInterceptRequest, skipPluginID string) pluginapi.StreamChunkInterceptResponse {
637655
current := pluginapi.StreamChunkInterceptResponse{
638656
Headers: cloneHeader(req.ResponseHeaders),
639657
Body: bytes.Clone(req.Body),
640658
}
659+
skipPluginID = strings.TrimSpace(skipPluginID)
641660
for _, record := range h.Snapshot().records {
642661
interceptor := record.plugin.Capabilities.StreamChunkInterceptor
643-
if h.isPluginFused(record.id) || interceptor == nil || current.DropChunk {
662+
if h.isPluginFused(record.id) || interceptor == nil || current.DropChunk || record.id == skipPluginID {
644663
continue
645664
}
646665
nextReq := req

internal/pluginhost/adapters_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1342,6 +1342,81 @@ func TestInterceptRequestAfterAuthPassesTargetFormat(t *testing.T) {
13421342
}
13431343
}
13441344

1345+
func TestInterceptorsSkipExceptedPlugin(t *testing.T) {
1346+
originCalls := 0
1347+
otherCalls := 0
1348+
host := newHostWithRecords(
1349+
capabilityRecord{
1350+
id: "origin",
1351+
priority: 20,
1352+
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
1353+
RequestInterceptor: requestInterceptorFunc(func(ctx context.Context, req pluginapi.RequestInterceptRequest) (pluginapi.RequestInterceptResponse, error) {
1354+
originCalls++
1355+
return pluginapi.RequestInterceptResponse{Body: append(req.Body, []byte("|origin-request")...)}, nil
1356+
}),
1357+
ResponseInterceptor: responseInterceptorFunc{
1358+
interceptResponse: func(ctx context.Context, req pluginapi.ResponseInterceptRequest) (pluginapi.ResponseInterceptResponse, error) {
1359+
originCalls++
1360+
return pluginapi.ResponseInterceptResponse{Body: append(req.Body, []byte("|origin-response")...)}, nil
1361+
},
1362+
},
1363+
StreamChunkInterceptor: responseInterceptorFunc{
1364+
interceptStreamChunk: func(ctx context.Context, req pluginapi.StreamChunkInterceptRequest) (pluginapi.StreamChunkInterceptResponse, error) {
1365+
originCalls++
1366+
return pluginapi.StreamChunkInterceptResponse{Body: append(req.Body, []byte("|origin-stream")...)}, nil
1367+
},
1368+
},
1369+
}},
1370+
},
1371+
capabilityRecord{
1372+
id: "other",
1373+
priority: 10,
1374+
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
1375+
RequestInterceptor: requestInterceptorFunc(func(ctx context.Context, req pluginapi.RequestInterceptRequest) (pluginapi.RequestInterceptResponse, error) {
1376+
otherCalls++
1377+
return pluginapi.RequestInterceptResponse{Body: append(req.Body, []byte("|other-request")...)}, nil
1378+
}),
1379+
ResponseInterceptor: responseInterceptorFunc{
1380+
interceptResponse: func(ctx context.Context, req pluginapi.ResponseInterceptRequest) (pluginapi.ResponseInterceptResponse, error) {
1381+
otherCalls++
1382+
return pluginapi.ResponseInterceptResponse{Body: append(req.Body, []byte("|other-response")...)}, nil
1383+
},
1384+
},
1385+
StreamChunkInterceptor: responseInterceptorFunc{
1386+
interceptStreamChunk: func(ctx context.Context, req pluginapi.StreamChunkInterceptRequest) (pluginapi.StreamChunkInterceptResponse, error) {
1387+
otherCalls++
1388+
return pluginapi.StreamChunkInterceptResponse{Body: append(req.Body, []byte("|other-stream")...)}, nil
1389+
},
1390+
},
1391+
}},
1392+
},
1393+
)
1394+
1395+
reqOut := host.InterceptRequestBeforeAuthExcept(context.Background(), pluginapi.RequestInterceptRequest{Body: []byte("body")}, "origin")
1396+
afterOut := host.InterceptRequestAfterAuthExcept(context.Background(), pluginapi.RequestInterceptRequest{Body: []byte("body")}, "origin")
1397+
respOut := host.InterceptResponseExcept(context.Background(), pluginapi.ResponseInterceptRequest{Body: []byte("body")}, "origin")
1398+
streamOut := host.InterceptStreamChunkExcept(context.Background(), pluginapi.StreamChunkInterceptRequest{Body: []byte("body")}, "origin")
1399+
1400+
if originCalls != 0 {
1401+
t.Fatalf("origin plugin calls = %d, want 0", originCalls)
1402+
}
1403+
if otherCalls != 4 {
1404+
t.Fatalf("other plugin calls = %d, want 4", otherCalls)
1405+
}
1406+
if string(reqOut.Body) != "body|other-request" {
1407+
t.Fatalf("request body = %q, want body|other-request", reqOut.Body)
1408+
}
1409+
if string(afterOut.Body) != "body|other-request" {
1410+
t.Fatalf("after-auth request body = %q, want body|other-request", afterOut.Body)
1411+
}
1412+
if string(respOut.Body) != "body|other-response" {
1413+
t.Fatalf("response body = %q, want body|other-response", respOut.Body)
1414+
}
1415+
if string(streamOut.Body) != "body|other-stream" {
1416+
t.Fatalf("stream body = %q, want body|other-stream", streamOut.Body)
1417+
}
1418+
}
1419+
13451420
func TestResponseInterceptorsChainAndStreamHistory(t *testing.T) {
13461421
var seenHistory [][]byte
13471422
var sawSecondResponse bool

internal/pluginhost/callback_contexts.go

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package pluginhost
33
import (
44
"context"
55
"strconv"
6+
"strings"
67
"sync"
78
"sync/atomic"
89
)
@@ -14,24 +15,27 @@ type callbackContextRegistry struct {
1415
}
1516

1617
type callbackContextEntry struct {
17-
ctx context.Context
18-
cleanup []func()
18+
ctx context.Context
19+
pluginID string
20+
cleanup []func()
1921
}
2022

2123
func newCallbackContextRegistry() *callbackContextRegistry {
2224
return &callbackContextRegistry{contexts: make(map[string]callbackContextEntry)}
2325
}
2426

25-
func (r *callbackContextRegistry) open(ctx context.Context) (string, func()) {
27+
func (r *callbackContextRegistry) open(ctx context.Context, pluginID string) (string, func()) {
2628
if r == nil {
2729
return "", func() {}
2830
}
2931
if ctx == nil {
3032
ctx = context.Background()
3133
}
34+
pluginID = strings.TrimSpace(pluginID)
35+
ctx = withHostCallbackPluginID(ctx, pluginID)
3236
id := strconv.FormatUint(r.next.Add(1), 10)
3337
r.mu.Lock()
34-
r.contexts[id] = callbackContextEntry{ctx: ctx}
38+
r.contexts[id] = callbackContextEntry{ctx: ctx, pluginID: pluginID}
3539
r.mu.Unlock()
3640

3741
var once sync.Once
@@ -52,6 +56,16 @@ func (r *callbackContextRegistry) open(ctx context.Context) (string, func()) {
5256
}
5357
}
5458

59+
func (r *callbackContextRegistry) pluginID(id string) string {
60+
if r == nil || id == "" {
61+
return ""
62+
}
63+
r.mu.RLock()
64+
entry := r.contexts[id]
65+
r.mu.RUnlock()
66+
return strings.TrimSpace(entry.pluginID)
67+
}
68+
5569
func (r *callbackContextRegistry) addCleanup(id string, cleanup func()) bool {
5670
if r == nil || id == "" || cleanup == nil {
5771
return false
@@ -87,10 +101,14 @@ func (r *callbackContextRegistry) resolve(id string, fallback context.Context) c
87101
}
88102

89103
func (h *Host) openCallbackContext(ctx context.Context) (string, func()) {
104+
return h.openCallbackContextForPlugin(ctx, "")
105+
}
106+
107+
func (h *Host) openCallbackContextForPlugin(ctx context.Context, pluginID string) (string, func()) {
90108
if h == nil || h.callbackContexts == nil {
91109
return "", func() {}
92110
}
93-
return h.callbackContexts.open(ctx)
111+
return h.callbackContexts.open(ctx, pluginID)
94112
}
95113

96114
func (h *Host) addCallbackCleanup(id string, cleanup func()) bool {
@@ -112,3 +130,10 @@ func (h *Host) resolveCallbackContext(id string, fallback context.Context) conte
112130
}
113131
return h.callbackContexts.resolve(id, fallback)
114132
}
133+
134+
func (h *Host) callbackContextPluginID(id string) string {
135+
if h == nil || h.callbackContexts == nil {
136+
return ""
137+
}
138+
return h.callbackContexts.pluginID(id)
139+
}

internal/pluginhost/host.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ func (h *Host) ApplyConfig(ctx context.Context, cfg *config.Config) {
186186
}
187187

188188
func (h *Host) loadLocked(file pluginFile) (*loadedPlugin, error) {
189-
client, errOpen := h.loader.Open(file.Path, h)
189+
client, errOpen := h.loader.Open(file, h)
190190
if errOpen != nil {
191191
return nil, errOpen
192192
}

0 commit comments

Comments
 (0)