Skip to content

Commit d687eb7

Browse files
yrobla and taskbot authored
Wire optimizer flags into thv vmcp serve (#4940)
Add --optimizer, --optimizer-embedding, --embedding-model, and --embedding-image flags to the serve subcommand, and extend ServeConfig with the four corresponding fields. In Serve(), inject a non-nil OptimizerConfig when either Tier 1 or Tier 2 is active; when Tier 2 is requested, start the TEI container via EmbeddingServiceManager and defer Stop() for clean shutdown. --optimizer-embedding implies --optimizer. Regenerate CLI docs.

Closes #4887

Co-authored-by: taskbot <taskbot@users.noreply.github.com>
1 parent c17cb40 commit d687eb7

5 files changed

Lines changed: 258 additions & 16 deletions

File tree

cmd/thv/app/vmcp.go

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,15 @@ servers from a ToolHive group into a single unified endpoint.`,
3030
// newVMCPServeCommand returns the "vmcp serve" subcommand.
3131
func newVMCPServeCommand() *cobra.Command {
3232
var (
33-
configPath string
34-
group string
35-
host string
36-
port int
37-
enableAudit bool
33+
configPath string
34+
group string
35+
host string
36+
port int
37+
enableAudit bool
38+
enableOptimizer bool
39+
enableEmbedding bool
40+
embeddingModel string
41+
embeddingImage string
3842
)
3943
cmd := &cobra.Command{
4044
Use: "serve",
@@ -51,16 +55,28 @@ configuration file is needed for the common case of aggregating a local group.`,
5155
Args: cobra.NoArgs,
5256
RunE: func(cmd *cobra.Command, _ []string) error {
5357
return vmcpcli.Serve(cmd.Context(), vmcpcli.ServeConfig{
54-
ConfigPath: configPath,
55-
GroupRef: group,
56-
Host: host,
57-
Port: port,
58-
EnableAudit: enableAudit,
58+
ConfigPath: configPath,
59+
GroupRef: group,
60+
Host: host,
61+
Port: port,
62+
EnableAudit: enableAudit,
63+
EnableOptimizer: enableOptimizer,
64+
EnableEmbedding: enableEmbedding,
65+
EmbeddingModel: embeddingModel,
66+
EmbeddingImage: embeddingImage,
5967
})
6068
},
6169
}
6270
cmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to vMCP configuration file")
6371
cmd.Flags().StringVar(&group, "group", "", "ToolHive group name (zero-config quick mode when --config is omitted)")
72+
cmd.Flags().BoolVar(&enableOptimizer, "optimizer", false,
73+
"Enable FTS5 keyword optimizer (Tier 1): exposes find_tool and call_tool instead of all backend tools")
74+
cmd.Flags().BoolVar(&enableEmbedding, "optimizer-embedding", false,
75+
"Enable managed TEI semantic optimizer (Tier 2); implies --optimizer")
76+
cmd.Flags().StringVar(&embeddingModel, "embedding-model", "BAAI/bge-small-en-v1.5",
77+
"HuggingFace model name for semantic search (Tier 2)")
78+
cmd.Flags().StringVar(&embeddingImage, "embedding-image",
79+
"ghcr.io/huggingface/text-embeddings-inference:cpu-latest", "TEI container image (Tier 2)")
6480
cmd.Flags().StringVar(&host, "host", "127.0.0.1", "Host address to bind to")
6581
cmd.Flags().IntVar(&port, "port", 4483, "Port to listen on")
6682
cmd.Flags().BoolVar(&enableAudit, "enable-audit", false, "Enable audit logging with default configuration")

docs/cli/thv_vmcp_serve.md

Lines changed: 10 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/vmcp/cli/embedding_manager.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ const (
3131
// defaultTEIImage is the default HuggingFace Text Embeddings Inference image.
3232
defaultTEIImage = "ghcr.io/huggingface/text-embeddings-inference:cpu-latest"
3333

34+
// DefaultEmbeddingModel is the HuggingFace model used when EmbeddingModel is empty.
35+
DefaultEmbeddingModel = "BAAI/bge-small-en-v1.5"
36+
3437
// teiModelCachePath is the path inside the TEI container where models are cached.
3538
teiModelCachePath = "/data"
3639

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package cli
5+
6+
import (
7+
"context"
8+
"errors"
9+
"testing"
10+
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
13+
14+
vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config"
15+
)
16+
17+
// stubEmbeddingManager is a hand-rolled fake satisfying the embeddingManager
// interface. Tests preload the values Start and Stop should return and then
// inspect the "seen" flags to learn which lifecycle methods were invoked.
type stubEmbeddingManager struct {
	// startURL is the URL handed back by Start.
	startURL string
	// startErr and stopErr are the errors handed back by Start and Stop.
	startErr, stopErr error
	// startSeen and stopSeen flip to true once Start / Stop has been called.
	startSeen, stopSeen bool
}
25+
26+
func (s *stubEmbeddingManager) Start(_ context.Context) (string, error) {
27+
s.startSeen = true
28+
return s.startURL, s.startErr
29+
}
30+
31+
func (s *stubEmbeddingManager) Stop(_ context.Context) error {
32+
s.stopSeen = true
33+
return s.stopErr
34+
}
35+
36+
func TestInjectOptimizerConfig_NeitherTierEnabled(t *testing.T) {
37+
t.Parallel()
38+
39+
vmcpCfg := &vmcpconfig.Config{}
40+
cfg := ServeConfig{EnableOptimizer: false, EnableEmbedding: false}
41+
42+
cleanup, err := injectOptimizerConfig(context.Background(), cfg, vmcpCfg, nil)
43+
44+
require.NoError(t, err)
45+
assert.Nil(t, cleanup)
46+
assert.Nil(t, vmcpCfg.Optimizer, "Optimizer must remain nil when neither tier is enabled")
47+
}
48+
49+
func TestInjectOptimizerConfig_Tier1Only(t *testing.T) {
50+
t.Parallel()
51+
52+
vmcpCfg := &vmcpconfig.Config{}
53+
cfg := ServeConfig{EnableOptimizer: true, EnableEmbedding: false}
54+
55+
cleanup, err := injectOptimizerConfig(context.Background(), cfg, vmcpCfg, nil)
56+
57+
require.NoError(t, err)
58+
assert.Nil(t, cleanup, "Tier 1 does not start TEI — no cleanup needed")
59+
require.NotNil(t, vmcpCfg.Optimizer)
60+
assert.Empty(t, vmcpCfg.Optimizer.EmbeddingService, "Tier 1 must not set an embedding service URL")
61+
}
62+
63+
func TestInjectOptimizerConfig_Tier1_PreservesExistingOptimizerConfig(t *testing.T) {
64+
t.Parallel()
65+
66+
existing := &vmcpconfig.OptimizerConfig{MaxToolsToReturn: 5}
67+
vmcpCfg := &vmcpconfig.Config{Optimizer: existing}
68+
cfg := ServeConfig{EnableOptimizer: true, EnableEmbedding: false}
69+
70+
_, err := injectOptimizerConfig(context.Background(), cfg, vmcpCfg, nil)
71+
72+
require.NoError(t, err)
73+
assert.Same(t, existing, vmcpCfg.Optimizer, "Existing optimizer config must not be replaced")
74+
}
75+
76+
func TestInjectOptimizerConfig_Tier2_SetsEmbeddingURL(t *testing.T) {
77+
t.Parallel()
78+
79+
stub := &stubEmbeddingManager{startURL: "http://127.0.0.1:8080"}
80+
vmcpCfg := &vmcpconfig.Config{}
81+
cfg := ServeConfig{EnableEmbedding: true}
82+
83+
cleanup, err := injectOptimizerConfig(context.Background(), cfg, vmcpCfg, stub)
84+
85+
require.NoError(t, err)
86+
require.NotNil(t, cleanup, "Tier 2 must return a cleanup func")
87+
assert.True(t, stub.startSeen, "Start must be called for Tier 2")
88+
require.NotNil(t, vmcpCfg.Optimizer)
89+
assert.Equal(t, "http://127.0.0.1:8080", vmcpCfg.Optimizer.EmbeddingService)
90+
}
91+
92+
func TestInjectOptimizerConfig_Tier2_ImpliesOptimizer(t *testing.T) {
93+
t.Parallel()
94+
95+
stub := &stubEmbeddingManager{startURL: "http://127.0.0.1:8080"}
96+
vmcpCfg := &vmcpconfig.Config{}
97+
// EnableOptimizer is false — EnableEmbedding alone must activate the optimizer.
98+
cfg := ServeConfig{EnableOptimizer: false, EnableEmbedding: true}
99+
100+
_, err := injectOptimizerConfig(context.Background(), cfg, vmcpCfg, stub)
101+
102+
require.NoError(t, err)
103+
assert.NotNil(t, vmcpCfg.Optimizer, "Tier 2 must activate optimizer even without --optimizer flag")
104+
}
105+
106+
func TestInjectOptimizerConfig_Tier2_StartError(t *testing.T) {
107+
t.Parallel()
108+
109+
startErr := errors.New("docker daemon unavailable")
110+
stub := &stubEmbeddingManager{startErr: startErr}
111+
vmcpCfg := &vmcpconfig.Config{}
112+
cfg := ServeConfig{EnableEmbedding: true}
113+
114+
cleanup, err := injectOptimizerConfig(context.Background(), cfg, vmcpCfg, stub)
115+
116+
require.Error(t, err)
117+
assert.ErrorContains(t, err, "docker daemon unavailable")
118+
assert.Nil(t, cleanup, "No cleanup func must be returned on Start failure")
119+
}
120+
121+
func TestInjectOptimizerConfig_Tier2_NilManagerReturnsError(t *testing.T) {
122+
t.Parallel()
123+
124+
vmcpCfg := &vmcpconfig.Config{}
125+
cfg := ServeConfig{EnableEmbedding: true}
126+
127+
cleanup, err := injectOptimizerConfig(context.Background(), cfg, vmcpCfg, nil)
128+
129+
require.Error(t, err)
130+
assert.ErrorContains(t, err, "embedding manager must not be nil")
131+
assert.Nil(t, cleanup)
132+
}
133+
134+
func TestInjectOptimizerConfig_Tier2_CleanupCallsStop(t *testing.T) {
135+
t.Parallel()
136+
137+
stub := &stubEmbeddingManager{startURL: "http://127.0.0.1:9999"}
138+
vmcpCfg := &vmcpconfig.Config{}
139+
cfg := ServeConfig{EnableEmbedding: true}
140+
141+
cleanup, err := injectOptimizerConfig(context.Background(), cfg, vmcpCfg, stub)
142+
require.NoError(t, err)
143+
require.NotNil(t, cleanup)
144+
145+
cleanup()
146+
assert.True(t, stub.stopSeen, "cleanup func must call Stop on the embedding manager")
147+
}

pkg/vmcp/cli/serve.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
authserverconfig "github.com/stacklok/toolhive/pkg/authserver"
2828
authserverrunner "github.com/stacklok/toolhive/pkg/authserver/runner"
2929
"github.com/stacklok/toolhive/pkg/authserver/server/keys"
30+
"github.com/stacklok/toolhive/pkg/container"
3031
"github.com/stacklok/toolhive/pkg/container/runtime"
3132
"github.com/stacklok/toolhive/pkg/groups"
3233
"github.com/stacklok/toolhive/pkg/migration"
@@ -67,6 +68,18 @@ type ServeConfig struct {
6768
// EnableAudit enables audit logging with default configuration when
6869
// the loaded config does not already define an audit section.
6970
EnableAudit bool
71+
72+
// Optimizer tier selection (Phase 4 — flag-driven).
73+
// EnableOptimizer enables Tier 1 FTS5 keyword search (find_tool / call_tool).
74+
EnableOptimizer bool
75+
// EnableEmbedding enables Tier 2 TEI semantic search; implies EnableOptimizer.
76+
EnableEmbedding bool
77+
// EmbeddingModel is the HuggingFace model name for the managed TEI container.
78+
// Defaults to "BAAI/bge-small-en-v1.5" when empty.
79+
EmbeddingModel string
80+
// EmbeddingImage is the TEI container image.
81+
// Defaults to the CPU TEI image when empty.
82+
EmbeddingImage string
7083
}
7184

7285
// validateQuickModeHost returns an error when the config represents quick mode
@@ -280,6 +293,32 @@ func Serve(ctx context.Context, cfg ServeConfig) error {
280293
return fmt.Errorf("failed to create status reporter: %w", err)
281294
}
282295

296+
// Optimizer wiring — Phase 4: flag-driven Tier 1 (FTS5) and Tier 2 (TEI).
297+
// Build the embedding manager only when Tier 2 is requested, to avoid
298+
// unnecessary Docker / Kubernetes API calls for Tier 0 and Tier 1.
299+
var embMgr embeddingManager
300+
if cfg.EnableEmbedding {
301+
model := cfg.EmbeddingModel
302+
if model == "" {
303+
model = DefaultEmbeddingModel
304+
}
305+
m, err := NewEmbeddingServiceManager(container.NewFactory(), EmbeddingServiceManagerConfig{
306+
Model: model,
307+
Image: cfg.EmbeddingImage,
308+
})
309+
if err != nil {
310+
return fmt.Errorf("failed to create embedding service manager: %w", err)
311+
}
312+
embMgr = m
313+
}
314+
teiCleanup, err := injectOptimizerConfig(ctx, cfg, vmcpCfg, embMgr)
315+
if err != nil {
316+
return err
317+
}
318+
if teiCleanup != nil {
319+
defer teiCleanup()
320+
}
321+
283322
optCfg, err := optimizer.GetAndValidateConfig(vmcpCfg.Optimizer)
284323
if err != nil {
285324
return fmt.Errorf("failed to validate optimizer config: %w", err)
@@ -371,6 +410,39 @@ func Serve(ctx context.Context, cfg ServeConfig) error {
371410
return srv.Start(ctx)
372411
}
373412

413+
// embeddingManager is the narrow slice of *EmbeddingServiceManager that the
// Serve lifecycle depends on. It is declared at the consumer so unit tests can
// inject a stub; production code passes a *EmbeddingServiceManager.
type embeddingManager interface {
	// Start launches the embedding service and returns a string that
	// injectOptimizerConfig stores as the optimizer's EmbeddingService URL.
	Start(ctx context.Context) (string, error)

	// Stop shuts the embedding service down.
	Stop(ctx context.Context) error
}
420+
421+
// injectOptimizerConfig ensures vmcpCfg.Optimizer is non-nil when flag-driven
422+
// optimizer tiers are active, and starts the TEI container when EnableEmbedding
423+
// is true. Returns a non-nil cleanup func only when a TEI container was started;
424+
// the caller must defer it. mgr must be non-nil when cfg.EnableEmbedding is true.
425+
func injectOptimizerConfig(ctx context.Context, cfg ServeConfig, vmcpCfg *config.Config, mgr embeddingManager) (func(), error) {
426+
if !cfg.EnableOptimizer && !cfg.EnableEmbedding {
427+
return nil, nil
428+
}
429+
if vmcpCfg.Optimizer == nil {
430+
vmcpCfg.Optimizer = &config.OptimizerConfig{}
431+
}
432+
if !cfg.EnableEmbedding {
433+
return nil, nil
434+
}
435+
if mgr == nil {
436+
return nil, fmt.Errorf("embedding manager must not be nil when EnableEmbedding is true")
437+
}
438+
teiURL, err := mgr.Start(ctx)
439+
if err != nil {
440+
return nil, fmt.Errorf("failed to start TEI embedding service: %w", err)
441+
}
442+
vmcpCfg.Optimizer.EmbeddingService = teiURL
443+
return func() { _ = mgr.Stop(context.Background()) }, nil
444+
}
445+
374446
// getStatusReportingInterval extracts the status reporting interval from config.
375447
// Returns 0 if not configured, which uses the default interval.
376448
func getStatusReportingInterval(cfg *config.Config) time.Duration {

0 commit comments

Comments
 (0)