Skip to content

Commit 9f940f1

Browse files
authored
fix(pluginhost): keep stream callbacks alive until stream close
Keep RPC streaming executor callback scopes alive until async streams close, detach nested host.model.execute_stream contexts from request cancellation, and clean up the stream bridge on stream completion.
1 parent 907e349 commit 9f940f1

6 files changed

Lines changed: 370 additions & 106 deletions

File tree

internal/pluginhost/host_callbacks.go

Lines changed: 0 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -291,83 +291,6 @@ func (h *Host) callHostModelExecute(ctx context.Context, request []byte) ([]byte
291291
})
292292
}
293293

294-
func (h *Host) callHostModelExecuteStream(ctx context.Context, request []byte) ([]byte, error) {
295-
var req rpcHostModelExecutionRequest
296-
if errUnmarshal := json.Unmarshal(request, &req); errUnmarshal != nil {
297-
return nil, fmt.Errorf("decode host model execution stream request: %w", errUnmarshal)
298-
}
299-
if !req.Stream {
300-
return nil, fmt.Errorf("host.model.execute_stream requires stream=true")
301-
}
302-
executor := h.currentModelExecutor()
303-
if executor == nil {
304-
return nil, fmt.Errorf("host model executor is unavailable")
305-
}
306-
skipPluginID := h.callbackCallerPluginID(ctx, req.HostCallbackID)
307-
ctx = h.resolveCallbackContext(req.HostCallbackID, ctx)
308-
if ctx == nil {
309-
ctx = context.Background()
310-
}
311-
streamCtx, cancel := context.WithCancel(ctx)
312-
stream, errMsg := executor.ExecuteModelStream(streamCtx, modelExecutionRequestFromPlugin(req.HostModelExecutionRequest, skipPluginID))
313-
if errMsg != nil {
314-
cancel()
315-
return nil, modelExecutionError(errMsg)
316-
}
317-
streamID := ""
318-
if h != nil && h.modelStreams != nil {
319-
streamID = h.modelStreams.open(req.HostCallbackID, stream.Chunks, cancel)
320-
}
321-
if streamID == "" {
322-
cancel()
323-
return nil, fmt.Errorf("host model stream bridge is unavailable")
324-
}
325-
if req.HostCallbackID != "" {
326-
h.addCallbackCleanup(req.HostCallbackID, func() {
327-
h.modelStreams.close(streamID)
328-
})
329-
}
330-
return marshalRPCResult(pluginapi.HostModelStreamResponse{
331-
StatusCode: stream.StatusCode,
332-
Headers: cloneHeader(stream.Headers),
333-
StreamID: streamID,
334-
})
335-
}
336-
337-
func (h *Host) callHostModelStreamRead(ctx context.Context, request []byte) ([]byte, error) {
338-
var req pluginapi.HostModelStreamReadRequest
339-
if errUnmarshal := json.Unmarshal(request, &req); errUnmarshal != nil {
340-
return nil, fmt.Errorf("decode host model stream read request: %w", errUnmarshal)
341-
}
342-
if h == nil || h.modelStreams == nil {
343-
return nil, fmt.Errorf("host model stream bridge is unavailable")
344-
}
345-
chunk, done, errRead := h.modelStreams.read(ctx, req.StreamID)
346-
if errRead != nil {
347-
return nil, errRead
348-
}
349-
resp := pluginapi.HostModelStreamReadResponse{
350-
Payload: append([]byte(nil), chunk.Payload...),
351-
Done: done,
352-
}
353-
if chunk.Err != nil {
354-
resp.Error = chunk.Err.Error()
355-
resp.Done = true
356-
}
357-
return marshalRPCResult(resp)
358-
}
359-
360-
func (h *Host) callHostModelStreamClose(request []byte) ([]byte, error) {
361-
var req pluginapi.HostModelStreamCloseRequest
362-
if errUnmarshal := json.Unmarshal(request, &req); errUnmarshal != nil {
363-
return nil, fmt.Errorf("decode host model stream close request: %w", errUnmarshal)
364-
}
365-
if h != nil && h.modelStreams != nil {
366-
h.modelStreams.close(req.StreamID)
367-
}
368-
return marshalRPCResult(rpcEmptyResponse{})
369-
}
370-
371294
func modelExecutionRequestFromPlugin(req pluginapi.HostModelExecutionRequest, skipPluginID string) handlers.ModelExecutionRequest {
372295
return handlers.ModelExecutionRequest{
373296
EntryProtocol: req.EntryProtocol,
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package pluginhost
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
8+
"github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginapi"
9+
)
10+
11+
func (h *Host) callHostModelExecuteStream(ctx context.Context, request []byte) ([]byte, error) {
12+
var req rpcHostModelExecutionRequest
13+
if errUnmarshal := json.Unmarshal(request, &req); errUnmarshal != nil {
14+
return nil, fmt.Errorf("decode host model execution stream request: %w", errUnmarshal)
15+
}
16+
if !req.Stream {
17+
return nil, fmt.Errorf("host.model.execute_stream requires stream=true")
18+
}
19+
executor := h.currentModelExecutor()
20+
if executor == nil {
21+
return nil, fmt.Errorf("host model executor is unavailable")
22+
}
23+
skipPluginID := h.callbackCallerPluginID(ctx, req.HostCallbackID)
24+
callbackCtx := h.resolveCallbackContext(req.HostCallbackID, ctx)
25+
if callbackCtx == nil {
26+
callbackCtx = context.Background()
27+
}
28+
// Detach request cancellation while preserving callback values; callback cleanup owns the model stream lifetime.
29+
streamCtx, cancel := context.WithCancel(context.WithoutCancel(callbackCtx))
30+
stream, errMsg := executor.ExecuteModelStream(streamCtx, modelExecutionRequestFromPlugin(req.HostModelExecutionRequest, skipPluginID))
31+
if errMsg != nil {
32+
cancel()
33+
return nil, modelExecutionError(errMsg)
34+
}
35+
streamID := ""
36+
if h.modelStreams != nil {
37+
streamID = h.modelStreams.open(req.HostCallbackID, stream.Chunks, cancel)
38+
}
39+
if streamID == "" {
40+
cancel()
41+
return nil, fmt.Errorf("host model stream bridge is unavailable")
42+
}
43+
if req.HostCallbackID != "" {
44+
h.addCallbackCleanup(req.HostCallbackID, func() {
45+
h.modelStreams.close(streamID)
46+
})
47+
}
48+
return marshalRPCResult(pluginapi.HostModelStreamResponse{
49+
StatusCode: stream.StatusCode,
50+
Headers: cloneHeader(stream.Headers),
51+
StreamID: streamID,
52+
})
53+
}
54+
55+
func (h *Host) callHostModelStreamRead(ctx context.Context, request []byte) ([]byte, error) {
56+
var req pluginapi.HostModelStreamReadRequest
57+
if errUnmarshal := json.Unmarshal(request, &req); errUnmarshal != nil {
58+
return nil, fmt.Errorf("decode host model stream read request: %w", errUnmarshal)
59+
}
60+
if h == nil || h.modelStreams == nil {
61+
return nil, fmt.Errorf("host model stream bridge is unavailable")
62+
}
63+
chunk, done, errRead := h.modelStreams.read(ctx, req.StreamID)
64+
if errRead != nil {
65+
return nil, errRead
66+
}
67+
resp := pluginapi.HostModelStreamReadResponse{
68+
Payload: append([]byte(nil), chunk.Payload...),
69+
Done: done,
70+
}
71+
if chunk.Err != nil {
72+
resp.Error = chunk.Err.Error()
73+
resp.Done = true
74+
}
75+
return marshalRPCResult(resp)
76+
}
77+
78+
func (h *Host) callHostModelStreamClose(request []byte) ([]byte, error) {
79+
var req pluginapi.HostModelStreamCloseRequest
80+
if errUnmarshal := json.Unmarshal(request, &req); errUnmarshal != nil {
81+
return nil, fmt.Errorf("decode host model stream close request: %w", errUnmarshal)
82+
}
83+
if h != nil && h.modelStreams != nil {
84+
h.modelStreams.close(req.StreamID)
85+
}
86+
return marshalRPCResult(rpcEmptyResponse{})
87+
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package pluginhost
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"net/http"
7+
"testing"
8+
"time"
9+
10+
"github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces"
11+
"github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers"
12+
"github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginabi"
13+
"github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginapi"
14+
)
15+
16+
func TestHostModelExecuteStreamDetachesFromCallbackParentCancel(t *testing.T) {
17+
host := New()
18+
ctxSeen := make(chan context.Context, 1)
19+
host.SetModelExecutor(&fakeHostModelExecutor{
20+
executeModelStream: func(ctx context.Context, req handlers.ModelExecutionRequest) (handlers.ModelExecutionStream, *interfaces.ErrorMessage) {
21+
ctxSeen <- ctx
22+
return handlers.ModelExecutionStream{
23+
StatusCode: http.StatusOK,
24+
Chunks: make(chan handlers.ModelExecutionChunk),
25+
}, nil
26+
},
27+
})
28+
parentCtx, cancelParent := context.WithCancel(context.Background())
29+
callbackID, closeCallback := host.openCallbackContext(parentCtx)
30+
defer closeCallback()
31+
32+
rawReq, errMarshal := json.Marshal(rpcHostModelExecutionRequest{
33+
HostModelExecutionRequest: pluginapi.HostModelExecutionRequest{
34+
EntryProtocol: "openai",
35+
ExitProtocol: "openai",
36+
Model: "model-1",
37+
Stream: true,
38+
Body: []byte(`{"stream":true}`),
39+
},
40+
HostCallbackID: callbackID,
41+
})
42+
if errMarshal != nil {
43+
t.Fatalf("marshal request: %v", errMarshal)
44+
}
45+
rawResp, errCall := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelExecuteStream, rawReq)
46+
if errCall != nil {
47+
t.Fatalf("callFromPlugin() error = %v", errCall)
48+
}
49+
resp, errDecode := decodeRPCEnvelope[pluginapi.HostModelStreamResponse](rawResp)
50+
if errDecode != nil {
51+
t.Fatalf("decode response: %v", errDecode)
52+
}
53+
if resp.StreamID == "" {
54+
t.Fatalf("stream id is empty: %#v", resp)
55+
}
56+
57+
var streamCtx context.Context
58+
select {
59+
case streamCtx = <-ctxSeen:
60+
case <-time.After(time.Second):
61+
t.Fatal("model executor was not called")
62+
}
63+
cancelParent()
64+
select {
65+
case <-streamCtx.Done():
66+
t.Fatal("stream context was canceled by callback parent context")
67+
default:
68+
}
69+
70+
closeCallback()
71+
select {
72+
case <-streamCtx.Done():
73+
case <-time.After(time.Second):
74+
t.Fatal("stream context was not canceled after callback scope closed")
75+
}
76+
}

internal/pluginhost/rpc_client.go

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -377,35 +377,6 @@ func (a *rpcPluginAdapter) Execute(ctx context.Context, req pluginapi.ExecutorRe
377377
})
378378
}
379379

380-
func (a *rpcPluginAdapter) ExecuteStream(ctx context.Context, req pluginapi.ExecutorRequest) (pluginapi.ExecutorStreamResponse, error) {
381-
if a == nil || a.host == nil || a.host.streams == nil {
382-
return pluginapi.ExecutorStreamResponse{}, fmt.Errorf("plugin stream bridge is unavailable")
383-
}
384-
streamID, chunks, cleanup := a.host.streams.open(ctx)
385-
callbackID, closeCallback := a.openHostCallbackContext(ctx)
386-
defer closeCallback()
387-
rpcReq := rpcExecutorRequest{
388-
ExecutorRequest: req,
389-
StreamID: streamID,
390-
HostCallbackID: callbackID,
391-
}
392-
resp, errCall := callPlugin[rpcExecutorStreamResponse](ctx, a.client, pluginabi.MethodExecutorExecuteStream, rpcReq)
393-
if errCall != nil {
394-
cleanup()
395-
return pluginapi.ExecutorStreamResponse{}, errCall
396-
}
397-
if len(resp.Chunks) > 0 {
398-
cleanup()
399-
out := make(chan pluginapi.ExecutorStreamChunk, len(resp.Chunks))
400-
for _, chunk := range resp.Chunks {
401-
out <- chunk
402-
}
403-
close(out)
404-
return pluginapi.ExecutorStreamResponse{Headers: resp.Headers, Chunks: out}, nil
405-
}
406-
return pluginapi.ExecutorStreamResponse{Headers: resp.Headers, Chunks: chunks}, nil
407-
}
408-
409380
func (a *rpcPluginAdapter) CountTokens(ctx context.Context, req pluginapi.ExecutorRequest) (pluginapi.ExecutorResponse, error) {
410381
callbackID, closeCallback := a.openHostCallbackContext(ctx)
411382
defer closeCallback()
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package pluginhost
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"sync"
7+
8+
"github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginabi"
9+
"github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginapi"
10+
)
11+
12+
func (a *rpcPluginAdapter) ExecuteStream(ctx context.Context, req pluginapi.ExecutorRequest) (pluginapi.ExecutorStreamResponse, error) {
13+
if a == nil || a.host == nil || a.host.streams == nil {
14+
return pluginapi.ExecutorStreamResponse{}, fmt.Errorf("plugin stream bridge is unavailable")
15+
}
16+
streamID, chunks, cleanupStream := a.host.streams.open(ctx)
17+
callbackID, closeCallback := a.openHostCallbackContext(ctx)
18+
cleanup := combinedCleanup(cleanupStream, closeCallback)
19+
rpcReq := rpcExecutorRequest{
20+
ExecutorRequest: req,
21+
StreamID: streamID,
22+
HostCallbackID: callbackID,
23+
}
24+
resp, errCall := callPlugin[rpcExecutorStreamResponse](ctx, a.client, pluginabi.MethodExecutorExecuteStream, rpcReq)
25+
if errCall != nil {
26+
cleanup()
27+
return pluginapi.ExecutorStreamResponse{}, errCall
28+
}
29+
if len(resp.Chunks) > 0 {
30+
cleanup()
31+
out := make(chan pluginapi.ExecutorStreamChunk, len(resp.Chunks))
32+
for _, chunk := range resp.Chunks {
33+
out <- chunk
34+
}
35+
close(out)
36+
return pluginapi.ExecutorStreamResponse{Headers: resp.Headers, Chunks: out}, nil
37+
}
38+
// Async streaming plugins can return before they finish emitting chunks, so keep callbacks alive until the stream ends.
39+
return pluginapi.ExecutorStreamResponse{
40+
Headers: resp.Headers,
41+
Chunks: cleanupWhenStreamDone(ctx, chunks, cleanup),
42+
}, nil
43+
}
44+
45+
func combinedCleanup(cleanups ...func()) func() {
46+
var once sync.Once
47+
return func() {
48+
once.Do(func() {
49+
for _, cleanup := range cleanups {
50+
if cleanup != nil {
51+
cleanup()
52+
}
53+
}
54+
})
55+
}
56+
}
57+
58+
func cleanupWhenStreamDone(ctx context.Context, chunks <-chan pluginapi.ExecutorStreamChunk, cleanup func()) <-chan pluginapi.ExecutorStreamChunk {
59+
out := make(chan pluginapi.ExecutorStreamChunk)
60+
go func() {
61+
defer func() {
62+
if cleanup != nil {
63+
cleanup()
64+
}
65+
close(out)
66+
}()
67+
var done <-chan struct{}
68+
if ctx != nil {
69+
done = ctx.Done()
70+
}
71+
for chunk := range chunks {
72+
select {
73+
case out <- chunk:
74+
case <-done:
75+
return
76+
}
77+
}
78+
}()
79+
return out
80+
}

0 commit comments

Comments
 (0)