Skip to content

Commit 827d1ce

Browse files
fateleiWH-2099
andauthored
feat: support plugin poll feature (#744)
* feat: support plugin poll * fix: validate polling payloads --------- Co-authored-by: WH-2099 <wh2099@pm.me>
1 parent 4e09b97 commit 827d1ce

9 files changed

Lines changed: 397 additions & 0 deletions

File tree

internal/core/io_tunnel/access_types/access.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ const (
4545
PLUGIN_ACCESS_ACTION_GET_TEXT_EMBEDDING_NUM_TOKENS PluginAccessAction = "get_text_embedding_num_tokens"
4646
PLUGIN_ACCESS_ACTION_GET_AI_MODEL_SCHEMAS PluginAccessAction = "get_ai_model_schemas"
4747
PLUGIN_ACCESS_ACTION_GET_LLM_NUM_TOKENS PluginAccessAction = "get_llm_num_tokens"
48+
PLUGIN_ACCESS_ACTION_START_POLLING PluginAccessAction = "start_polling"
49+
PLUGIN_ACCESS_ACTION_CHECK_POLLING PluginAccessAction = "check_polling"
4850
PLUGIN_ACCESS_ACTION_INVOKE_AGENT_STRATEGY PluginAccessAction = "invoke_agent_strategy"
4951
PLUGIN_ACCESS_ACTION_GET_AUTHORIZATION_URL PluginAccessAction = "get_authorization_url"
5052
PLUGIN_ACCESS_ACTION_GET_CREDENTIALS PluginAccessAction = "get_credentials"
@@ -81,6 +83,8 @@ func (p PluginAccessAction) IsValid() bool {
8183
p == PLUGIN_ACCESS_ACTION_GET_TEXT_EMBEDDING_NUM_TOKENS ||
8284
p == PLUGIN_ACCESS_ACTION_GET_AI_MODEL_SCHEMAS ||
8385
p == PLUGIN_ACCESS_ACTION_GET_LLM_NUM_TOKENS ||
86+
p == PLUGIN_ACCESS_ACTION_START_POLLING ||
87+
p == PLUGIN_ACCESS_ACTION_CHECK_POLLING ||
8488
p == PLUGIN_ACCESS_ACTION_INVOKE_AGENT_STRATEGY ||
8589
p == PLUGIN_ACCESS_ACTION_GET_AUTHORIZATION_URL ||
8690
p == PLUGIN_ACCESS_ACTION_GET_CREDENTIALS ||
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package io_tunnel
2+
3+
import (
4+
"errors"
5+
6+
"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
7+
"github.com/langgenius/dify-plugin-daemon/pkg/entities/model_entities"
8+
"github.com/langgenius/dify-plugin-daemon/pkg/entities/requests"
9+
)
10+
11+
func StartPolling(
12+
session *session_manager.Session,
13+
request *requests.RequestStartPolling,
14+
) (*model_entities.ModelPollingResult, error) {
15+
return invokeSingleResponse[requests.RequestStartPolling, model_entities.ModelPollingResult](session, request)
16+
}
17+
18+
func CheckPolling(
19+
session *session_manager.Session,
20+
request *requests.RequestCheckPolling,
21+
) (*model_entities.ModelPollingResult, error) {
22+
return invokeSingleResponse[requests.RequestCheckPolling, model_entities.ModelPollingResult](session, request)
23+
}
24+
25+
func invokeSingleResponse[Req any, Rsp any](
26+
session *session_manager.Session,
27+
request *Req,
28+
) (*Rsp, error) {
29+
response, err := GenericInvokePlugin[Req, Rsp](session, request, 1)
30+
if err != nil {
31+
return nil, err
32+
}
33+
defer response.Close()
34+
35+
var (
36+
got bool
37+
item Rsp
38+
)
39+
40+
for response.Next() {
41+
value, readErr := response.Read()
42+
if readErr != nil {
43+
return nil, readErr
44+
}
45+
item = value
46+
got = true
47+
}
48+
49+
if !got {
50+
return nil, errors.New("no polling result received from plugin")
51+
}
52+
53+
return &item, nil
54+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package controllers
2+
3+
import (
4+
"github.com/gin-gonic/gin"
5+
"github.com/langgenius/dify-plugin-daemon/internal/service"
6+
"github.com/langgenius/dify-plugin-daemon/pkg/entities/plugin_entities"
7+
"github.com/langgenius/dify-plugin-daemon/pkg/entities/requests"
8+
)
9+
10+
func StartPolling() gin.HandlerFunc {
11+
type request = plugin_entities.InvokePluginRequest[requests.RequestStartPolling]
12+
13+
return func(c *gin.Context) {
14+
BindPluginDispatchRequest(c, func(itr request) {
15+
service.StartPolling(&itr, c)
16+
})
17+
}
18+
}
19+
20+
func CheckPolling() gin.HandlerFunc {
21+
type request = plugin_entities.InvokePluginRequest[requests.RequestCheckPolling]
22+
23+
return func(c *gin.Context) {
24+
BindPluginDispatchRequest(c, func(itr request) {
25+
service.CheckPolling(&itr, c)
26+
})
27+
}
28+
}

internal/server/http_server.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ func (app *App) pluginDispatchGroup(group *gin.RouterGroup, config *app.Config)
120120
group.Use(app.InitClusterID())
121121

122122
group.POST("/agent_strategy/invoke", controllers.InvokeAgentStrategy(config))
123+
group.POST("/model/polling/start", controllers.StartPolling())
124+
group.POST("/model/polling/check", controllers.CheckPolling())
123125

124126
app.setupGeneratedRoutes(group, config)
125127
}

internal/service/model_polling.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package service
2+
3+
import (
4+
"github.com/gin-gonic/gin"
5+
"github.com/langgenius/dify-plugin-daemon/internal/core/io_tunnel"
6+
"github.com/langgenius/dify-plugin-daemon/internal/core/io_tunnel/access_types"
7+
"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
8+
"github.com/langgenius/dify-plugin-daemon/internal/types/exception"
9+
"github.com/langgenius/dify-plugin-daemon/pkg/entities"
10+
"github.com/langgenius/dify-plugin-daemon/pkg/entities/plugin_entities"
11+
"github.com/langgenius/dify-plugin-daemon/pkg/entities/requests"
12+
)
13+
14+
func StartPolling(
15+
r *plugin_entities.InvokePluginRequest[requests.RequestStartPolling],
16+
ctx *gin.Context,
17+
) {
18+
session, err := createSession(
19+
r,
20+
access_types.PLUGIN_ACCESS_TYPE_MODEL,
21+
access_types.PLUGIN_ACCESS_ACTION_START_POLLING,
22+
ctx.GetString("cluster_id"),
23+
ctx.Request.Context(),
24+
)
25+
if err != nil {
26+
ctx.JSON(500, exception.InternalServerError(err).ToResponse())
27+
return
28+
}
29+
defer session.Close(session_manager.CloseSessionPayload{IgnoreCache: false})
30+
31+
resp, err := io_tunnel.StartPolling(session, &r.Data)
32+
if err != nil {
33+
ctx.JSON(500, exception.InvokePluginError(err).ToResponse())
34+
return
35+
}
36+
ctx.JSON(200, entities.NewSuccessResponse(resp))
37+
}
38+
39+
func CheckPolling(
40+
r *plugin_entities.InvokePluginRequest[requests.RequestCheckPolling],
41+
ctx *gin.Context,
42+
) {
43+
session, err := createSession(
44+
r,
45+
access_types.PLUGIN_ACCESS_TYPE_MODEL,
46+
access_types.PLUGIN_ACCESS_ACTION_CHECK_POLLING,
47+
ctx.GetString("cluster_id"),
48+
ctx.Request.Context(),
49+
)
50+
if err != nil {
51+
ctx.JSON(500, exception.InternalServerError(err).ToResponse())
52+
return
53+
}
54+
defer session.Close(session_manager.CloseSessionPayload{IgnoreCache: false})
55+
56+
resp, err := io_tunnel.CheckPolling(session, &r.Data)
57+
if err != nil {
58+
ctx.JSON(500, exception.InvokePluginError(err).ToResponse())
59+
return
60+
}
61+
62+
ctx.JSON(200, entities.NewSuccessResponse(resp))
63+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package model_entities
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"strings"
7+
8+
"github.com/go-playground/validator/v10"
9+
"github.com/langgenius/dify-plugin-daemon/pkg/validators"
10+
)
11+
12+
type PollingStatus string
13+
14+
const (
15+
PollingStatusRunning PollingStatus = "running"
16+
PollingStatusSucceeded PollingStatus = "succeeded"
17+
PollingStatusFailed PollingStatus = "failed"
18+
)
19+
20+
type ModelPollingResult struct {
21+
Status PollingStatus `json:"status" validate:"required,oneof=running succeeded failed"`
22+
PluginState map[string]any `json:"plugin_state,omitempty"`
23+
Result json.RawMessage `json:"result,omitempty"`
24+
Error string `json:"error,omitempty"`
25+
NextCheckAfterSeconds *int `json:"next_check_after_seconds,omitempty"`
26+
ExpiresAfterSeconds *int `json:"expires_after_seconds,omitempty"`
27+
MaxAttempts *int `json:"max_attempts,omitempty"`
28+
}
29+
30+
func init() {
31+
validators.GlobalEntitiesValidator.RegisterStructValidation(
32+
validateModelPollingResult,
33+
ModelPollingResult{},
34+
)
35+
}
36+
37+
func validateModelPollingResult(sl validator.StructLevel) {
38+
result, ok := sl.Current().Interface().(ModelPollingResult)
39+
if !ok {
40+
return
41+
}
42+
43+
switch result.Status {
44+
case PollingStatusRunning:
45+
if len(result.PluginState) == 0 {
46+
sl.ReportError(result.PluginState, "plugin_state", "PluginState", "required_for_running", "")
47+
}
48+
case PollingStatusSucceeded:
49+
if isEmptyJSON(result.Result) {
50+
sl.ReportError(result.Result, "result", "Result", "required_for_succeeded", "")
51+
}
52+
case PollingStatusFailed:
53+
if strings.TrimSpace(result.Error) == "" {
54+
sl.ReportError(result.Error, "error", "Error", "required_for_failed", "")
55+
}
56+
}
57+
}
58+
59+
func isEmptyJSON(data json.RawMessage) bool {
60+
trimmed := bytes.TrimSpace(data)
61+
return len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null"))
62+
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package model_entities
2+
3+
import (
4+
"encoding/json"
5+
"testing"
6+
7+
"github.com/langgenius/dify-plugin-daemon/pkg/validators"
8+
)
9+
10+
func TestModelPollingResultValidatesStatus(t *testing.T) {
11+
nextCheckAfterSeconds := 10
12+
result := ModelPollingResult{
13+
Status: PollingStatus("running"),
14+
PluginState: map[string]any{"task_id": "task-1"},
15+
NextCheckAfterSeconds: &nextCheckAfterSeconds,
16+
}
17+
18+
if err := validators.GlobalEntitiesValidator.Struct(result); err != nil {
19+
t.Fatalf("validate polling result: %v", err)
20+
}
21+
22+
data, err := json.Marshal(result)
23+
if err != nil {
24+
t.Fatalf("marshal polling result: %v", err)
25+
}
26+
if !json.Valid(data) {
27+
t.Fatalf("invalid json: %s", data)
28+
}
29+
}
30+
31+
func TestModelPollingResultRejectsUnknownStatus(t *testing.T) {
32+
result := ModelPollingResult{
33+
Status: PollingStatus("pending"),
34+
}
35+
36+
if err := validators.GlobalEntitiesValidator.Struct(result); err == nil {
37+
t.Fatal("expected unknown polling status validation error")
38+
}
39+
}
40+
41+
func TestModelPollingResultRequiresRunningState(t *testing.T) {
42+
result := ModelPollingResult{
43+
Status: PollingStatus("running"),
44+
}
45+
46+
if err := validators.GlobalEntitiesValidator.Struct(result); err == nil {
47+
t.Fatal("expected missing plugin_state validation error")
48+
}
49+
}
50+
51+
func TestModelPollingResultRequiresSucceededResult(t *testing.T) {
52+
result := ModelPollingResult{
53+
Status: PollingStatus("succeeded"),
54+
Result: json.RawMessage("null"),
55+
}
56+
57+
if err := validators.GlobalEntitiesValidator.Struct(result); err == nil {
58+
t.Fatal("expected missing result validation error")
59+
}
60+
}
61+
62+
func TestModelPollingResultRequiresFailedError(t *testing.T) {
63+
result := ModelPollingResult{
64+
Status: PollingStatus("failed"),
65+
Error: " ",
66+
}
67+
68+
if err := validators.GlobalEntitiesValidator.Struct(result); err == nil {
69+
t.Fatal("expected missing error validation error")
70+
}
71+
}
72+
73+
func TestModelPollingResultAcceptsTerminalPayloads(t *testing.T) {
74+
succeeded := ModelPollingResult{
75+
Status: PollingStatus("succeeded"),
76+
Result: json.RawMessage(`{"text":"done"}`),
77+
}
78+
if err := validators.GlobalEntitiesValidator.Struct(succeeded); err != nil {
79+
t.Fatalf("validate succeeded result: %v", err)
80+
}
81+
82+
failed := ModelPollingResult{
83+
Status: PollingStatus("failed"),
84+
Error: "provider failed",
85+
}
86+
if err := validators.GlobalEntitiesValidator.Struct(failed); err != nil {
87+
t.Fatalf("validate failed result: %v", err)
88+
}
89+
}

pkg/entities/requests/model.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ type InvokeLLMSchema struct {
2121
PromptMessages []model_entities.PromptMessage `json:"prompt_messages" validate:"omitempty"`
2222
Tools []model_entities.PromptMessageTool `json:"tools" validate:"omitempty,dive"`
2323
Stop []string `json:"stop" validate:"omitempty"`
24+
JSONSchema map[string]any `json:"json_schema" validate:"omitempty"`
2425
Stream bool `json:"stream"`
2526
}
2627

@@ -182,3 +183,20 @@ type RequestGetAIModelSchema struct {
182183

183184
ModelType model_entities.ModelType `json:"model_type" validate:"required,model_type"`
184185
}
186+
187+
type RequestStartPolling struct {
188+
RequestInvokeLLM
189+
190+
WorkflowRunID string `json:"workflow_run_id" validate:"required"`
191+
NodeID string `json:"node_id" validate:"required"`
192+
}
193+
194+
type RequestCheckPolling struct {
195+
BaseRequestInvokeModel
196+
Credentials
197+
198+
ModelType model_entities.ModelType `json:"model_type" validate:"required,model_type,eq=llm"`
199+
WorkflowRunID string `json:"workflow_run_id" validate:"required"`
200+
NodeID string `json:"node_id" validate:"required"`
201+
PluginState map[string]any `json:"plugin_state" validate:"required,min=1"`
202+
}

0 commit comments

Comments
 (0)