From abb126eb5889615bafa24d34df23a0e5e53b58dd Mon Sep 17 00:00:00 2001 From: mahoushoujyo-eee <1653545106@qq.com> Date: Fri, 24 Apr 2026 23:37:56 +0800 Subject: [PATCH] fix: use default retriever router --- flow/retriever/router/router.go | 2 +- flow/retriever/router/router_test.go | 44 ++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/flow/retriever/router/router.go b/flow/retriever/router/router.go index f0363388e..da96cd2d0 100644 --- a/flow/retriever/router/router.go +++ b/flow/retriever/router/router.go @@ -97,7 +97,7 @@ func NewRetriever(ctx context.Context, config *Config) (retriever.Retriever, err return &routerRetriever{ retrievers: config.Retrievers, - router: config.Router, + router: router, fusionFunc: fusion, }, nil } diff --git a/flow/retriever/router/router_test.go b/flow/retriever/router/router_test.go index edf1af1e7..c8309c08b 100644 --- a/flow/retriever/router/router_test.go +++ b/flow/retriever/router/router_test.go @@ -20,6 +20,7 @@ import ( "context" "reflect" "strings" + "sync" "testing" "github.com/cloudwego/eino/callbacks" @@ -54,6 +55,20 @@ func (m *mockRetriever) GetType() string { return "Mock" } +type recordingRetriever struct { + name string + mu *sync.Mutex + called map[string]int +} + +func (r *recordingRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) { + r.mu.Lock() + r.called[r.name]++ + r.mu.Unlock() + + return []*schema.Document{{ID: r.name}}, nil +} + func TestRouterRetriever(t *testing.T) { ctx := context.Background() r, err := NewRetriever(ctx, &Config{ @@ -111,6 +126,35 @@ func TestRouterRetriever(t *testing.T) { } } +func TestRouterRetrieverDefaultRouter(t *testing.T) { + ctx := context.Background() + mu := &sync.Mutex{} + called := map[string]int{} + r, err := NewRetriever(ctx, &Config{ + Retrievers: map[string]retriever.Retriever{ + "1": &recordingRetriever{name: "1", mu: mu, called: called}, + "2": &recordingRetriever{name: "2", mu: mu, called: called}, + "3": &recordingRetriever{name: "3", mu: mu, called: called}, + }, + }) + if err != nil { + t.Fatal(err) + } + + result, err := r.Retrieve(ctx, "query") + if err != nil { + t.Fatal(err) + } + if len(result) != 3 { + t.Fatalf("expected 3 results, got %d", len(result)) + } + for _, name := range []string{"1", "2", "3"} { + if called[name] != 1 { + t.Fatalf("expected retriever %s to be called once, got %d", name, called[name]) + } + } +} + func TestRRF(t *testing.T) { doc1 := &schema.Document{ID: "1"} doc2 := &schema.Document{ID: "2"}