diff --git a/cmd/gateway.go b/cmd/gateway.go index 0ebb2a899c..2c7e982a75 100644 --- a/cmd/gateway.go +++ b/cmd/gateway.go @@ -333,7 +333,7 @@ func runGateway() { httpapi.InitGatewayToken(cfg.Gateway.Token) exportTokenStore := httpapi.InitExportTokenStore() defer exportTokenStore.Stop() - agentsH, skillsH, tracesH, mcpH, channelInstancesH, providersH, builtinToolsH, pendingMessagesH, teamEventsH, secureCLIH, secureCLIGrantH, mcpUserCredsH := wireHTTP(pgStores, cfg.Agents.Defaults.Workspace, dataDir, bundledSkillsDir, msgBus, toolsReg, providerRegistry, modelReg, permPE.IsOwner, gatewayAddr, mcpToolLister) + agentsH, skillsH, tracesH, mcpH, channelInstancesH, providersH, builtinToolsH, pendingMessagesH, teamEventsH, secureCLIH, secureCLIGrantH, mcpUserCredsH := wireHTTP(pgStores, cfg.Agents.Defaults.Workspace, dataDir, bundledSkillsDir, msgBus, toolsReg, providerRegistry, modelReg, permPE.IsOwner, gatewayAddr, cfg.Gateway.Token, mcpToolLister) // Wire dependencies for system prompt preview parity. if agentsH != nil { diff --git a/cmd/gateway_agents.go b/cmd/gateway_agents.go index 0dcbda6bdb..22485b5ea0 100644 --- a/cmd/gateway_agents.go +++ b/cmd/gateway_agents.go @@ -134,6 +134,29 @@ func buildEmbeddingProvider( "provider", dbp.Name, "requested", es.Dimensions, "required", store.RequiredMemoryEmbeddingDimensions) } + // Gemini native provider — uses its own embedding API (not OpenAI-compatible). 
+ if dbp.ProviderType == store.ProviderGeminiNative { + apiKey := dbp.APIKey + if providerReg != nil { + if regProv, regErr := providerReg.Get(context.Background(), dbp.Name); regErr == nil { + if gp, ok := regProv.(interface{ APIKey() string }); ok && gp.APIKey() != "" { + apiKey = gp.APIKey() + } + } + } + if apiKey == "" { + slog.Warn("gemini embedding provider has no API key", "name", dbp.Name) + return nil + } + if model == "" { + model = memory.GeminiDefaultEmbeddingModel + } + ep := memory.NewGeminiEmbeddingProvider(dbp.Name, apiKey, apiBase, model) + ep.WithDimensions(dims) + slog.Info("gemini embedding provider configured", "name", dbp.Name, "model", model, "dims", dims) + return ep + } + // Try registry first for the actual API key / base (handles runtime-registered providers) if providerReg != nil { if regProv, regErr := providerReg.Get(context.Background(), dbp.Name); regErr == nil { diff --git a/cmd/gateway_http_handlers.go b/cmd/gateway_http_handlers.go index 4ddb0e52b6..c6ecc2fb0c 100644 --- a/cmd/gateway_http_handlers.go +++ b/cmd/gateway_http_handlers.go @@ -9,7 +9,7 @@ import ( ) // wireHTTP creates HTTP handlers (agents + skills + traces + MCP + channel instances + providers + builtin tools + pending messages). 
-func wireHTTP(stores *store.Stores, defaultWorkspace, dataDir, bundledSkillsDir string, msgBus *bus.MessageBus, toolsReg *tools.Registry, providerReg *providers.Registry, modelReg providers.ModelRegistry, isOwner func(string) bool, gatewayAddr string, mcpToolLister httpapi.MCPToolLister) (*httpapi.AgentsHandler, *httpapi.SkillsHandler, *httpapi.TracesHandler, *httpapi.MCPHandler, *httpapi.ChannelInstancesHandler, *httpapi.ProvidersHandler, *httpapi.BuiltinToolsHandler, *httpapi.PendingMessagesHandler, *httpapi.TeamEventsHandler, *httpapi.SecureCLIHandler, *httpapi.SecureCLIGrantHandler, *httpapi.MCPUserCredentialsHandler) { +func wireHTTP(stores *store.Stores, defaultWorkspace, dataDir, bundledSkillsDir string, msgBus *bus.MessageBus, toolsReg *tools.Registry, providerReg *providers.Registry, modelReg providers.ModelRegistry, isOwner func(string) bool, gatewayAddr, gatewayToken string, mcpToolLister httpapi.MCPToolLister) (*httpapi.AgentsHandler, *httpapi.SkillsHandler, *httpapi.TracesHandler, *httpapi.MCPHandler, *httpapi.ChannelInstancesHandler, *httpapi.ProvidersHandler, *httpapi.BuiltinToolsHandler, *httpapi.PendingMessagesHandler, *httpapi.TeamEventsHandler, *httpapi.SecureCLIHandler, *httpapi.SecureCLIGrantHandler, *httpapi.MCPUserCredentialsHandler) { var agentsH *httpapi.AgentsHandler var skillsH *httpapi.SkillsHandler var tracesH *httpapi.TracesHandler @@ -70,6 +70,12 @@ func wireHTTP(stores *store.Stores, defaultWorkspace, dataDir, bundledSkillsDir if stores.MCP != nil { providersH.SetMCPServerLookup(buildMCPServerLookup(stores.MCP)) } + acpMCPData = buildACPMCPData(gatewayAddr, gatewayToken, stores.MCP) + providersH.SetProviderReloadFn(func(p *store.LLMProviderData) { + // Unregister the previous instance first (same order as the wireExtras hot-reload path) so a stale process pool is not left behind on update. + providerReg.Unregister(p.Name) + registerACPFromDB(providerReg, *p) + }) if stores.Tracing != nil { providersH.SetTracingStore(stores.Tracing) } diff --git a/cmd/gateway_managed.go b/cmd/gateway_managed.go index 5d2c111ba2..ffa1a0fd1c 100644 --- a/cmd/gateway_managed.go +++ b/cmd/gateway_managed.go @@ -680,6 +680,8 @@ 
func wireExtras( // Unregister old instance (closes ProcessPool) then re-register providerReg.Unregister(p.Name) if p.Enabled { + acpMCPData = buildACPMCPData(loopbackAddr(appCfg.Gateway.Host, appCfg.Gateway.Port), + appCfg.Gateway.Token, stores.MCP) registerACPFromDB(providerReg, *p) } }) diff --git a/cmd/gateway_providers.go b/cmd/gateway_providers.go index b174ef44e7..ec75ec68a4 100644 --- a/cmd/gateway_providers.go +++ b/cmd/gateway_providers.go @@ -19,6 +19,36 @@ import ( "github.com/nextlevelbuilder/goclaw/internal/tools" ) +// acpMCPData is the package-level MCP bridge config consumed by +// registerACPFromConfig and registerACPFromDB. Callers populate it via +// buildACPMCPData before invoking either register* — keeping these functions +// at their original 2-arg signatures (registry + cfg/p) and making +// hot-reload paths idempotent (they refresh this var before re-registering). +// +// Single-process gateway scope means a package-level var is acceptable here: +// gateway addr/token/MCPStore are fixed for the lifetime of the binary, and +// the four set-sites (startup config, startup DB iteration, two hot-reload +// closures) all derive identical values. +var acpMCPData *providers.MCPConfigData + +// buildACPMCPData assembles the MCP bridge config consumed by ACP providers. +// Returns nil when no gateway addr is available, which makes downstream +// settings.MCPData nil and the ACP provider skip MCP server injection. +// mcpStore is optional — when non-nil, the AgentMCPLookup closure is attached +// so per-agent MCP servers are surfaced to the ACP subprocess at session/new +// time (DB-registered providers only; config-based providers run without +// per-agent MCP). 
+func buildACPMCPData(gatewayAddr, gatewayToken string, mcpStore store.MCPServerStore) *providers.MCPConfigData { + if gatewayAddr == "" { + return nil + } + data := providers.BuildCLIMCPConfigData(nil, gatewayAddr, gatewayToken) + if mcpStore != nil { + data.AgentMCPLookup = buildMCPServerLookup(mcpStore) + } + return data +} + // loopbackAddr normalizes a gateway address for local connections. // CLI processes on the same machine can't connect to 0.0.0.0 on some OSes. func loopbackAddr(host string, port int) string { @@ -29,6 +59,7 @@ func loopbackAddr(host string, port int) string { } func registerProviders(registry *providers.Registry, cfg *config.Config, modelReg providers.ModelRegistry) { + gatewayAddr := loopbackAddr(cfg.Gateway.Host, cfg.Gateway.Port) if cfg.Providers.Anthropic.APIKey != "" { registry.Register(providers.NewAnthropicProvider(cfg.Providers.Anthropic.APIKey, providers.WithAnthropicBaseURL(cfg.Providers.Anthropic.APIBase), @@ -188,7 +219,6 @@ func registerProviders(registry *providers.Registry, cfg *config.Config, modelRe opts = append(opts, providers.WithClaudeCLIPermMode(cfg.Providers.ClaudeCLI.PermMode)) } // Build per-session MCP config: external MCP servers + GoClaw bridge - gatewayAddr := loopbackAddr(cfg.Gateway.Host, cfg.Gateway.Port) mcpData := providers.BuildCLIMCPConfigData(cfg.Tools.McpServers, gatewayAddr, cfg.Gateway.Token) opts = append(opts, providers.WithClaudeCLIMCPConfigData(mcpData)) // Enable GoClaw security hooks (shell deny patterns, path restrictions) @@ -200,6 +230,7 @@ func registerProviders(registry *providers.Registry, cfg *config.Config, modelRe // ACP provider (config-based) — orchestrates any ACP-compatible agent binary if cfg.Providers.ACP.Binary != "" { + acpMCPData = buildACPMCPData(gatewayAddr, cfg.Gateway.Token, nil) registerACPFromConfig(registry, cfg.Providers.ACP) } } @@ -276,6 +307,7 @@ func registerProvidersFromDB(registry *providers.Registry, provStore store.Provi slog.Warn("failed to load providers 
from DB", "error", err) return } + acpMCPData = buildACPMCPData(gatewayAddr, gatewayToken, mcpStore) for _, p := range dbProviders { // Claude CLI doesn't need API key if !p.Enabled { @@ -411,35 +443,43 @@ func registerProvidersFromDB(registry *providers.Registry, provStore store.Provi } // registerACPFromConfig registers an ACP provider from config file settings. +// All ACP options consume one shared *providers.ACPSettings populated from cfg; +// per-binary defaults (e.g. gemini's --include-directories) are applied inside +// the relevant With* option in the providers package. The MCP bridge config +// is read from the package-level acpMCPData (set by callers via +// buildACPMCPData before invocation). func registerACPFromConfig(registry *providers.Registry, cfg config.ACPConfig) { if _, err := exec.LookPath(cfg.Binary); err != nil { slog.Warn("acp: binary not found, skipping", "binary", cfg.Binary, "error", err) return } - idleTTL := 5 * time.Minute - if cfg.IdleTTL != "" { - if d, err := time.ParseDuration(cfg.IdleTTL); err == nil { - idleTTL = d - } - } - workDir := cfg.WorkDir - if workDir == "" { - workDir = defaultACPWorkDir() - } - var opts []providers.ACPOption - if cfg.Model != "" { - opts = append(opts, providers.WithACPModel(cfg.Model)) - } - if cfg.PermMode != "" { - opts = append(opts, providers.WithACPPermMode(cfg.PermMode)) + settings := &providers.ACPSettings{ + Binary: cfg.Binary, + Args: cfg.Args, + Model: cfg.Model, + PermMode: cfg.PermMode, + IdleTTL: cfg.IdleTTL, + WorkDir: cfg.WorkDir, + MCPData: acpMCPData, } registry.Register(providers.NewACPProvider( - cfg.Binary, cfg.Args, workDir, idleTTL, tools.DefaultDenyPatterns(), opts..., + settings.Binary, settings.Args, settings.WorkDirOrDefault(), + settings.IdleTTLOrDefault(5*time.Minute), + tools.DefaultDenyPatterns(), + providers.WithACPModel(settings), + providers.WithACPPermMode(settings), + providers.WithACPMCPConfigData(settings), + providers.WithIncludeDirectories(settings), )) - 
slog.Info("registered provider", "name", "acp", "binary", cfg.Binary) + slog.Info("registered provider", "name", "acp", "binary", cfg.Binary, "args", cfg.Args) } -// registerACPFromDB registers an ACP provider from a DB provider row. +// registerACPFromDB registers an ACP provider from a DB row. +// Called at startup (via registerProvidersFromDB) and on hot-reload. +// DB JSONB unmarshals directly into providers.ACPSettings — the shared struct's +// json tags match the historic schema (args, idle_ttl, perm_mode, work_dir, +// include_directories). The MCP bridge config is read from the package-level +// acpMCPData (set by callers via buildACPMCPData before invocation). func registerACPFromDB(registry *providers.Registry, p store.LLMProviderData) { binary := p.APIBase // repurpose api_base as binary path if binary == "" { @@ -454,37 +494,27 @@ func registerACPFromDB(registry *providers.Registry, p store.LLMProviderData) { slog.Warn("acp: binary not found, skipping", "binary", binary, "error", err) return } - // Parse settings JSONB for extra config - var settings struct { - Args []string `json:"args"` - IdleTTL string `json:"idle_ttl"` - PermMode string `json:"perm_mode"` - WorkDir string `json:"work_dir"` + settings := &providers.ACPSettings{ + Name: p.Name, + Binary: binary, + Model: p.Name, // historical: provider name doubles as default agent/model } if p.Settings != nil { - if err := json.Unmarshal(p.Settings, &settings); err != nil { + if err := json.Unmarshal(p.Settings, settings); err != nil { slog.Warn("acp: invalid settings JSON, using defaults", "name", p.Name, "error", err) } } - idleTTL := 5 * time.Minute - if settings.IdleTTL != "" { - if d, err := time.ParseDuration(settings.IdleTTL); err == nil { - idleTTL = d - } - } - workDir := settings.WorkDir - if workDir == "" { - workDir = defaultACPWorkDir() - } + settings.MCPData = acpMCPData registry.RegisterForTenant(p.TenantID, providers.NewACPProvider( - binary, settings.Args, workDir, idleTTL, 
tools.DefaultDenyPatterns(), - providers.WithACPName(p.Name), - providers.WithACPModel(p.Name), + settings.Binary, settings.Args, settings.WorkDirOrDefault(), + settings.IdleTTLOrDefault(5*time.Minute), + tools.DefaultDenyPatterns(), + providers.WithACPName(settings), + providers.WithACPModel(settings), + providers.WithACPPermMode(settings), + providers.WithACPMCPConfigData(settings), + providers.WithIncludeDirectories(settings), )) slog.Info("registered provider from DB", "name", p.Name, "type", "acp") } -// defaultACPWorkDir returns the default workspace directory for ACP agents. -func defaultACPWorkDir() string { - return filepath.Join(config.ResolvedDataDirFromEnv(), "acp-workspaces") -} diff --git a/cmd/providers_cmd.go b/cmd/providers_cmd.go index 4ad8581e8c..77e988f79a 100644 --- a/cmd/providers_cmd.go +++ b/cmd/providers_cmd.go @@ -181,7 +181,7 @@ func runProvidersAdd() { if providerID != "" { verify, err := promptConfirm("Verify connection now?", true) if err == nil && verify { - runProviderVerify(providerID, "") + runProviderVerify(providerID) } } } @@ -266,41 +266,45 @@ func providersDeleteCmd() *cobra.Command { } func providersVerifyCmd() *cobra.Command { - var modelFlag string - cmd := &cobra.Command{ + return &cobra.Command{ Use: "verify ", - Short: "Verify provider connectivity (ping) or a specific model", - Long: "Without --model: pings the provider (registered + reachable check).\nWith --model: sends a small chat request to validate the model alias.", + Short: "Verify provider connectivity and list models", Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { requireRunningGatewayHTTP() - runProviderVerify(args[0], modelFlag) + runProviderVerify(args[0]) }, } - cmd.Flags().StringVar(&modelFlag, "model", "", "model alias to verify (omit for connectivity ping)") - return cmd } -func runProviderVerify(providerID, model string) { +func runProviderVerify(providerID string) { fmt.Print("Verifying provider... 
") - var body any - if model != "" { - body = map[string]string{"model": model} - } - resp, err := gatewayHTTPPost("/v1/providers/"+url.PathEscape(providerID)+"/verify", body) + resp, err := gatewayHTTPPost("/v1/providers/"+url.PathEscape(providerID)+"/verify", nil) if err != nil { fmt.Printf("FAILED\n %v\n", err) return } - if valid, _ := resp["valid"].(bool); valid { + + if ok, _ := resp["success"].(bool); ok { fmt.Println("OK") - return - } - msg, _ := resp["error"].(string) - if msg == "" { - msg = "verification failed" + // Show available models + raw, _ := json.Marshal(resp["models"]) + var models []httpProviderModel + if json.Unmarshal(raw, &models) == nil && len(models) > 0 { + fmt.Printf(" Available models: %d\n", len(models)) + limit := 10 + for i, m := range models { + if i >= limit { + fmt.Printf(" ... and %d more\n", len(models)-limit) + break + } + fmt.Printf(" - %s\n", m.ID) + } + } + } else { + msg, _ := resp["error"].(string) + fmt.Printf("FAILED\n %s\n", msg) } - fmt.Printf("FAILED\n %s\n", msg) } // defaultBaseURL returns the default API base URL for a provider type. diff --git a/cmd/setup_provider.go b/cmd/setup_provider.go index 0e9ffbee48..9a5fd2eedc 100644 --- a/cmd/setup_provider.go +++ b/cmd/setup_provider.go @@ -1,6 +1,7 @@ package cmd import ( + "encoding/json" "fmt" "net/url" "os" @@ -94,7 +95,7 @@ func addProvider() { providerID, _ := resp["id"].(string) fmt.Printf(" Provider %q created.\n", name) - // Auto-verify (ping mode — empty body) + // Auto-verify if providerID != "" { fmt.Print(" Verifying... 
") verifyResp, err := gatewayHTTPPost("/v1/providers/"+url.PathEscape(providerID)+"/verify", nil) @@ -102,13 +103,18 @@ func addProvider() { fmt.Printf("FAILED (%v)\n", err) return } - if valid, _ := verifyResp["valid"].(bool); valid { + if ok, _ := verifyResp["success"].(bool); ok { fmt.Println("OK") + raw, _ := json.Marshal(verifyResp["models"]) + var models []httpProviderModel + if json.Unmarshal(raw, &models) == nil && len(models) > 0 { + fmt.Printf(" %d models available.\n", len(models)) + } } else { msg, _ := verifyResp["error"].(string) if msg == "" { msg = "verification failed" } fmt.Printf("FAILED (%s)\n", msg) fmt.Println(" You can update the API key later with 'goclaw providers update'.") } diff --git a/internal/http/provider_verify.go b/internal/http/provider_verify.go index 48e4e591af..5ffc03a075 100644 --- a/internal/http/provider_verify.go +++ b/internal/http/provider_verify.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "errors" - "io" "net/http" "os/exec" "path/filepath" @@ -18,14 +17,6 @@ import ( "github.com/nextlevelbuilder/goclaw/internal/store" ) -// HandleVerifyProviderForTest invokes the verify handler directly without auth -// middleware. Integration tests must inject the desired tenant_id into the -// request context before calling. Production code MUST go through RegisterRoutes -// so the auth/locale/tenant pipeline runs first. -func (h *ProvidersHandler) HandleVerifyProviderForTest(w http.ResponseWriter, r *http.Request) { - h.handleVerifyProvider(w, r) -} - // handleVerifyProvider tests a provider+model combination with a minimal LLM call. // // POST /v1/providers/{id}/verify @@ -42,17 +33,14 @@ func (h *ProvidersHandler) handleVerifyProvider(w http.ResponseWriter, r *http.R var req struct { Model string `json:"model"` } - // Empty body == ping mode (connectivity check only). Truncated/malformed - // JSON still returns 400. io.EOF on Decode unambiguously means no body; - // io.ErrUnexpectedEOF is what truncated JSON returns. 
if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, 1<<16)).Decode(&req); err != nil { - if !errors.Is(err, io.EOF) { - writeJSON(w, http.StatusBadRequest, map[string]string{"error": i18n.T(locale, i18n.MsgInvalidJSON)}) - return - } - // empty body — req.Model stays "" → pingMode below + writeJSON(w, http.StatusBadRequest, map[string]string{"error": i18n.T(locale, i18n.MsgInvalidJSON)}) + return + } + if req.Model == "" { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": i18n.T(locale, i18n.MsgRequired, "model")}) + return } - pingMode := req.Model == "" // Look up provider record from DB to get the provider name p, err := h.store.GetProvider(r.Context(), id) @@ -63,10 +51,6 @@ func (h *ProvidersHandler) handleVerifyProvider(w http.ResponseWriter, r *http.R // ACP: verify binary exists on the server (no LLM call needed) if p.ProviderType == store.ProviderACP { - if pingMode { - writeJSON(w, http.StatusOK, map[string]any{"valid": true}) - return - } binary := p.APIBase if binary == "" { binary = "claude" @@ -86,10 +70,6 @@ func (h *ProvidersHandler) handleVerifyProvider(w http.ResponseWriter, r *http.R // Claude CLI: validate model alias locally (no LLM call needed) if p.ProviderType == "claude_cli" { - if pingMode { - writeJSON(w, http.StatusOK, map[string]any{"valid": true}) - return - } validModels := map[string]bool{"sonnet": true, "opus": true, "haiku": true} if validModels[req.Model] { writeJSON(w, http.StatusOK, map[string]any{"valid": true}) @@ -112,11 +92,6 @@ func (h *ProvidersHandler) handleVerifyProvider(w http.ResponseWriter, r *http.R return } - if pingMode { - writeJSON(w, http.StatusOK, map[string]any{"valid": true}) - return - } - // Non-chat models (image/video generation) can't be verified via Chat API. // Accept them if the provider is reachable (already validated above). 
if isNonChatModel(req.Model) { diff --git a/internal/http/providers.go b/internal/http/providers.go index ff8fa17718..31e38b7e79 100644 --- a/internal/http/providers.go +++ b/internal/http/providers.go @@ -35,9 +35,10 @@ type ProvidersHandler struct { cliMu sync.Mutex // serializes Claude CLI provider create to prevent duplicates msgBus *bus.MessageBus sysConfigStore store.SystemConfigStore - tracingStore store.TracingStore // optional: for provider-scoped pool activity - agents store.AgentCRUDStore // optional: for provider pool activity agent lookup - modelReg providers.ModelRegistry // optional: forward-compat model resolver for Anthropic + tracingStore store.TracingStore // optional: for provider-scoped pool activity + agents store.AgentCRUDStore // optional: for provider pool activity agent lookup + modelReg providers.ModelRegistry // optional: forward-compat model resolver for Anthropic + providerReloadFn func(*store.LLMProviderData) // optional: hot-reload process-based providers without restart } // NewProvidersHandler creates a handler for provider management endpoints. @@ -84,6 +85,12 @@ func (h *ProvidersHandler) SetModelRegistry(r providers.ModelRegistry) { h.modelReg = r } +// SetProviderReloadFn sets a callback invoked when a process-based provider (ACP) +// is created or updated, so the change takes effect without a gateway restart. +func (h *ProvidersHandler) SetProviderReloadFn(fn func(*store.LLMProviderData)) { + h.providerReloadFn = fn +} + // resolveAPIBase returns the provider's api_base, falling back to config/env if empty. 
// For Ollama/OllamaCloud providers, applies a safety-net normalization: if the stored // value is missing the /v1 suffix (pre-existing record before write-time normalization), @@ -160,9 +167,11 @@ func (h *ProvidersHandler) registerInMemory(p *store.LLMProviderData) { if h.providerReg == nil || !p.Enabled { return } - // ACP agents don't need an API key — skip in-memory registration - // (ACP providers are registered via gateway_providers.go on startup or restart) + // ACP providers use a process pool; delegate to the reload callback wired from cmd/. if p.ProviderType == store.ProviderACP { + if h.providerReloadFn != nil { + h.providerReloadFn(p) + } return } // Claude CLI doesn't need an API key — register immediately diff --git a/internal/memory/embedding_gemini.go b/internal/memory/embedding_gemini.go new file mode 100644 index 0000000000..d965fb4000 --- /dev/null +++ b/internal/memory/embedding_gemini.go @@ -0,0 +1,163 @@ +package memory + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +const ( + geminiEmbeddingBatchSize = 100 // batchEmbedContents API limit + geminiEmbeddingAPIBase = "https://generativelanguage.googleapis.com/v1beta" + // GeminiDefaultEmbeddingModel is the default model for Gemini embeddings. + // text-embedding-004/005 cap at 768 dims; gemini-embedding-2 supports higher dims (incl. 1536). + GeminiDefaultEmbeddingModel = "gemini-embedding-2" +) + +// GeminiEmbeddingProvider implements EmbeddingProvider using the Google Generative Language API. +type GeminiEmbeddingProvider struct { + name string + model string + apiKey string + apiBase string + dims int + client *http.Client +} + +// NewGeminiEmbeddingProvider creates an embedding provider backed by the Gemini API. +// apiBase may be empty (defaults to generativelanguage.googleapis.com). 
+func NewGeminiEmbeddingProvider(name, apiKey, apiBase, model string) *GeminiEmbeddingProvider { + if apiBase == "" { + apiBase = geminiEmbeddingAPIBase + } + if model == "" { + model = GeminiDefaultEmbeddingModel + } + return &GeminiEmbeddingProvider{ + name: name, + model: model, + apiKey: apiKey, + apiBase: strings.TrimRight(apiBase, "/"), + dims: 0, + client: &http.Client{Timeout: 60 * time.Second}, + } +} + +// WithDimensions sets the outputDimensionality sent to the API. +// Must match the pgvector column size (RequiredMemoryEmbeddingDimensions = 1536). +func (p *GeminiEmbeddingProvider) WithDimensions(d int) *GeminiEmbeddingProvider { + p.dims = d + return p +} + +func (p *GeminiEmbeddingProvider) Name() string { return p.name } +func (p *GeminiEmbeddingProvider) Model() string { return p.model } + +func (p *GeminiEmbeddingProvider) Embed(ctx context.Context, texts []string) ([][]float32, error) { + if len(texts) == 0 { + return nil, nil + } + results := make([][]float32, len(texts)) + for start := 0; start < len(texts); start += geminiEmbeddingBatchSize { + end := min(start+geminiEmbeddingBatchSize, len(texts)) + batch, err := p.embedBatch(ctx, texts[start:end]) + if err != nil { + return nil, fmt.Errorf("gemini embedding batch [%d:%d]: %w", start, end, err) + } + for i, emb := range batch { + results[start+i] = emb + } + } + return results, nil +} + +// modelName returns the fully-qualified model name required by the Gemini API. +// e.g. 
"gemini-embedding-exp-03-07" → "models/gemini-embedding-exp-03-07" +func (p *GeminiEmbeddingProvider) modelName() string { + if strings.HasPrefix(p.model, "models/") { + return p.model + } + return "models/" + p.model +} + +func (p *GeminiEmbeddingProvider) embedBatch(ctx context.Context, texts []string) ([][]float32, error) { + type contentPart struct { + Text string `json:"text"` + } + type content struct { + Parts []contentPart `json:"parts"` + } + type embedRequest struct { + Model string `json:"model"` + Content content `json:"content"` + OutputDimensionality *int `json:"outputDimensionality,omitempty"` + } + type batchRequest struct { + Requests []embedRequest `json:"requests"` + } + + reqs := make([]embedRequest, len(texts)) + for i, t := range texts { + r := embedRequest{ + Model: p.modelName(), + Content: content{Parts: []contentPart{{Text: t}}}, + } + if p.dims > 0 { + d := p.dims + r.OutputDimensionality = &d + } + reqs[i] = r + } + + body, err := json.Marshal(batchRequest{Requests: reqs}) + if err != nil { + return nil, fmt.Errorf("marshal: %w", err) + } + + endpoint := fmt.Sprintf("%s/models/%s:batchEmbedContents", p.apiBase, p.model) + // Use the unqualified model name in the URL path. 
+ if strings.HasPrefix(p.model, "models/") { + endpoint = fmt.Sprintf("%s/%s:batchEmbedContents", p.apiBase, p.model) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-goog-api-key", p.apiKey) + + resp, err := p.client.Do(req) + if err != nil { + return nil, fmt.Errorf("http: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + return nil, fmt.Errorf("gemini embedding API %d: %s", resp.StatusCode, string(b)) + } + + var result struct { + Embeddings []struct { + Values []float32 `json:"values"` + } `json:"embeddings"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("decode: %w", err) + } + if len(result.Embeddings) != len(texts) { + return nil, fmt.Errorf("gemini embedding count mismatch: got %d, want %d", len(result.Embeddings), len(texts)) + } + + out := make([][]float32, len(texts)) + for i, e := range result.Embeddings { + out[i] = e.Values + } + return out, nil +} diff --git a/internal/providers/acp/acp_gemini_test.go b/internal/providers/acp/acp_gemini_test.go index b4324a68a6..a2fc8cdaf7 100644 --- a/internal/providers/acp/acp_gemini_test.go +++ b/internal/providers/acp/acp_gemini_test.go @@ -27,7 +27,7 @@ func TestGeminiProtocolMapping(t *testing.T) { t.Fatalf("Spawn failed: %v", err) } - sid, err := proc.NewSession(ctx) + sid, err := proc.NewSession(ctx, "") if err != nil { t.Fatalf("NewSession failed: %v", err) } diff --git a/internal/providers/acp/process.go b/internal/providers/acp/process.go index 7e34de5152..7dbe1cd982 100644 --- a/internal/providers/acp/process.go +++ b/internal/providers/acp/process.go @@ -19,14 +19,16 @@ type ACPProcess struct { cmd *exec.Cmd conn *Conn - agentCaps AgentCaps - workDir string - lastActive time.Time - inUse 
atomic.Int32 // >0 means at least one prompt is active — reaper must skip - mu sync.Mutex - ctx context.Context - cancel context.CancelFunc - exited chan struct{} // closed when process exits + agentCaps AgentCaps + workDir string + mcpServersFn func(context.Context) []McpServer // invoked on every session/new + session/load + promptTimeout time.Duration // overrides promptInactivityTimeout when non-zero + lastActive time.Time + inUse atomic.Int32 // >0 means at least one prompt is active — reaper must skip + mu sync.Mutex + ctx context.Context + cancel context.CancelFunc + exited chan struct{} // closed when process exits // updateFns routes session/update notifications to the correct active prompt. updateFns map[string]func(SessionUpdate) @@ -38,6 +40,11 @@ func (p *ACPProcess) AgentCaps() AgentCaps { return p.agentCaps } +// WorkDir returns the process pool's base work directory. Callers building +// per-session workspaces should join a session-specific segment under this +// path and pass the result as the cwd argument to NewSession/LoadSession. +func (p *ACPProcess) WorkDir() string { return p.workDir } + // registerUpdateFn registers a callback for session/update notifications on sessionID. func (p *ACPProcess) registerUpdateFn(sid string, fn func(SessionUpdate)) { p.updateMu.Lock() @@ -107,16 +114,18 @@ func (p *ACPProcess) dispatchUpdate(update SessionUpdate) { // Typically a single shared process is used (poolKey = binary identifier), // and multiple ACP sessions are multiplexed over it. 
type ProcessPool struct { - processes sync.Map // poolKey → *ACPProcess - spawnMu sync.Map // poolKey → *sync.Mutex — prevents concurrent spawn - agentBinary string - agentArgs []string - workDir string - idleTTL time.Duration - mu sync.RWMutex // protects toolHandler - toolHandler RequestHandler - done chan struct{} - closeOnce sync.Once + processes sync.Map // poolKey → *ACPProcess + spawnMu sync.Map // poolKey → *sync.Mutex — prevents concurrent spawn + agentBinary string + agentArgs []string + workDir string + mcpServersFn func(context.Context) []McpServer // resolved per session/new + session/load + idleTTL time.Duration + promptTimeout time.Duration + mu sync.RWMutex // protects toolHandler, mcpServersFn, promptTimeout + toolHandler RequestHandler + done chan struct{} + closeOnce sync.Once } // NewProcessPool creates a pool that spawns ACP agents as subprocesses. @@ -132,6 +141,37 @@ func NewProcessPool(binary string, args []string, workDir string, idleTTL time.D return pp } +// SetMcpServersFunc configures the callback used to build the MCP server list +// on every session/new and session/load request. The callback receives the +// request context (with agent/tenant IDs) so it can return per-agent servers +// resolved from the MCP store. Must be called before GetOrSpawn; spawned +// processes inherit the current value at spawn time. +func (pp *ProcessPool) SetMcpServersFunc(fn func(context.Context) []McpServer) { + pp.mu.Lock() + defer pp.mu.Unlock() + pp.mcpServersFn = fn +} + +func (pp *ProcessPool) getMcpServersFn() func(context.Context) []McpServer { + pp.mu.RLock() + defer pp.mu.RUnlock() + return pp.mcpServersFn +} + +// SetPromptTimeout sets the inactivity timeout used by Prompt() watchdogs in +// newly spawned processes. Existing processes are not affected. 
+func (pp *ProcessPool) SetPromptTimeout(d time.Duration) { + pp.mu.Lock() + defer pp.mu.Unlock() + pp.promptTimeout = d +} + +func (pp *ProcessPool) getPromptTimeout() time.Duration { + pp.mu.RLock() + defer pp.mu.RUnlock() + return pp.promptTimeout +} + // SetToolHandler sets the agent→client request handler (tool bridge). // Must be called before any GetOrSpawn calls. func (pp *ProcessPool) SetToolHandler(h RequestHandler) { @@ -176,7 +216,9 @@ func (pp *ProcessPool) spawn(ctx context.Context, poolKey string) (*ACPProcess, cmd := exec.CommandContext(procCtx, pp.agentBinary, pp.agentArgs...) cmd.Dir = pp.workDir - cmd.Env = filterACPEnv(os.Environ()) + cmd.Env = append(filterACPEnv(os.Environ()), + "GEMINI_TELEMETRY_ENABLED=false", + ) cmd.SysProcAttr = sysProcAttr() stdinPipe, err := cmd.StdinPipe() @@ -198,18 +240,20 @@ func (pp *ProcessPool) spawn(ctx context.Context, poolKey string) (*ACPProcess, } proc := &ACPProcess{ - cmd: cmd, - lastActive: time.Now(), - ctx: procCtx, - cancel: cancel, - exited: make(chan struct{}), - workDir: pp.workDir, + cmd: cmd, + lastActive: time.Now(), + ctx: procCtx, + cancel: cancel, + exited: make(chan struct{}), + workDir: pp.workDir, + mcpServersFn: pp.getMcpServersFn(), + promptTimeout: pp.getPromptTimeout(), } // Notification handler: log all notifications and dispatch session/update to callers notifyHandler := func(method string, params json.RawMessage) { slog.Info("acp: notification received", "method", method) - slog.Debug("acp: notification params", "method", method, "params", string(params)) + slog.Info("acp: notification params", "method", method, "params", string(params)) if method == "session/update" { var update SessionUpdate if err := json.Unmarshal(params, &update); err != nil { @@ -260,8 +304,8 @@ func (pp *ProcessPool) reapLoop() { proc.mu.Unlock() if idle { slog.Info("acp: reaping idle process", "pool_key", key) + pp.processes.Delete(key) // delete before cancel so a concurrent GetOrSpawn sees no stale entry 
proc.cancel() - pp.processes.Delete(key) } return true }) diff --git a/internal/providers/acp/session.go b/internal/providers/acp/session.go index 016d4303a7..75fb0bc826 100644 --- a/internal/providers/acp/session.go +++ b/internal/providers/acp/session.go @@ -5,16 +5,22 @@ import ( "fmt" "log/slog" "path/filepath" + "sync/atomic" "time" ) +// promptInactivityTimeout is the maximum time Prompt() will wait without +// receiving any session/update notification before cancelling the prompt. +// Exposed as a package var so tests can shorten it. +var promptInactivityTimeout = 10 * time.Minute + // Initialize sends the ACP initialize request to establish capabilities. func (p *ACPProcess) Initialize(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, 60*time.Second) defer cancel() req := InitializeRequest{ ProtocolVersion: 1, - ClientInfo: ClientInfo{Name: "GoClaw", Version: "1.0"}, + ClientInfo: ClientInfo{Name: "", Version: "1.0"}, Capabilities: ClientCaps{}, } var resp InitializeResponse @@ -26,51 +32,85 @@ func (p *ACPProcess) Initialize(ctx context.Context) error { return nil } +// resolveCwd returns the provided override if non-empty, otherwise the +// process pool's default work directory (falling back to CWD as last resort). +func (p *ACPProcess) resolveCwd(override string) string { + if override != "" { + return override + } + if p.workDir != "" { + return p.workDir + } + cwd, _ := filepath.Abs(".") + return cwd +} + // NewSession creates a new ACP session and returns its session ID. -func (p *ACPProcess) NewSession(ctx context.Context) (string, error) { +// If cwd is non-empty it is used as the session working directory; otherwise +// the process pool's workDir is used. Gemini CLI 0.36.x honors the per-session +// cwd even when it differs from the subprocess spawn directory, enabling +// per-goclaw-session workspace isolation. 
+func (p *ACPProcess) NewSession(ctx context.Context, cwd string) (string, error) { ctx, cancel := context.WithTimeout(ctx, 60*time.Second) defer cancel() + sessionCwd := p.resolveCwd(cwd) - cwd := p.workDir - if cwd == "" { - cwd, _ = filepath.Abs(".") + var servers []McpServer + if p.mcpServersFn != nil { + servers = p.mcpServersFn(ctx) } - - req := NewSessionRequest{ - Cwd: cwd, - McpServers: []string{}, + if servers == nil { + servers = []McpServer{} } + req := NewSessionRequest{Cwd: sessionCwd, McpServers: servers} var resp NewSessionResponse if err := p.conn.Call(ctx, "session/new", req, &resp); err != nil { return "", fmt.Errorf("acp session/new: %w", err) } - slog.Info("acp: session/new", "sid", resp.SessionID, "cwd", cwd) + slog.Info("acp: session/new", "sid", resp.SessionID, "cwd", sessionCwd, "mcpServers", len(servers)) + for _, s := range servers { + switch sv := s.(type) { + case McpServerHTTP: + slog.Info("acp: mcp server (http)", "name", sv.Name, "url", sv.URL, "headers", len(sv.Headers)) + case McpServerStdio: + slog.Info("acp: mcp server (stdio)", "name", sv.Name, "command", sv.Command, "args", sv.Args) + } + } return resp.SessionID, nil } // LoadSession restores a previous ACP session by ID (used after process restart). // Returns the session ID to use going forward (may equal the requested ID). // Only call if AgentCaps().LoadSession is true. -func (p *ACPProcess) LoadSession(ctx context.Context, sessionID string) (string, error) { +// cwd has the same semantics as NewSession — pass the per-goclaw-session +// directory so tool calls resolve paths against it. 
+func (p *ACPProcess) LoadSession(ctx context.Context, sessionID, cwd string) (string, error) { ctx, cancel := context.WithTimeout(ctx, 60*time.Second) defer cancel() + sessionCwd := p.resolveCwd(cwd) - cwd := p.workDir - if cwd == "" { - cwd, _ = filepath.Abs(".") + var servers []McpServer + if p.mcpServersFn != nil { + servers = p.mcpServersFn(ctx) } - - req := LoadSessionRequest{SessionID: sessionID, Cwd: cwd} + if servers == nil { + servers = []McpServer{} + } + req := LoadSessionRequest{SessionID: sessionID, Cwd: sessionCwd, McpServers: servers} var resp LoadSessionResponse if err := p.conn.Call(ctx, "session/load", req, &resp); err != nil { return "", fmt.Errorf("acp session/load: %w", err) } - slog.Info("acp: session/load", "sid", resp.SessionID) + slog.Info("acp: session/load", "sid", resp.SessionID, "cwd", sessionCwd) return resp.SessionID, nil } // Prompt sends user content to sessionID and blocks until the agent completes, // invoking onUpdate for each session/update notification received. +// +// An inactivity watchdog cancels the prompt if no session/update arrives within +// promptInactivityTimeout. This guards against silent hangs where the ACP agent +// stops responding without closing the connection. func (p *ACPProcess) Prompt(ctx context.Context, sessionID string, content []ContentBlock, onUpdate func(SessionUpdate)) (*PromptResponse, error) { p.inUse.Add(1) defer p.inUse.Add(-1) @@ -79,11 +119,61 @@ func (p *ACPProcess) Prompt(ctx context.Context, sessionID string, content []Con p.lastActive = time.Now() p.mu.Unlock() - p.registerUpdateFn(sessionID, onUpdate) + timeout := p.promptTimeout + if timeout <= 0 { + timeout = promptInactivityTimeout + } + + // lastActivity is refreshed by every session/update; watchdog fires when stale. 
+ var lastActivity atomic.Int64 + lastActivity.Store(time.Now().UnixNano()) + + watchdogDone := make(chan struct{}) + go func() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if time.Since(time.Unix(0, lastActivity.Load())) > timeout { + slog.Warn("acp: prompt inactivity timeout, cancelling", + "sid", sessionID, "timeout", timeout) + _ = p.conn.Notify("session/cancel", CancelNotification{SessionID: sessionID}) + return + } + case <-watchdogDone: + return + case <-ctx.Done(): + return + } + } + }() + + // Wrap onUpdate to refresh lastActivity on every notification. + p.registerUpdateFn(sessionID, func(update SessionUpdate) { + lastActivity.Store(time.Now().UnixNano()) + if update.ToolCall != nil { + slog.Info("acp: tool call update", "sid", sessionID, "tool", update.ToolCall.Name, "status", update.ToolCall.Status, "id", update.ToolCall.ID) + } else if update.Kind != "" { + slog.Info("acp: session update", "sid", sessionID, "kind", update.Kind, "sessionUpdate", update.Update.SessionUpdate, "status", update.Update.Status) + } + if onUpdate != nil { + onUpdate(update) + } + }) defer p.unregisterUpdateFn(sessionID) + defer close(watchdogDone) goclawSession := goclawSessionFromCtx(ctx) - slog.Info("acp: session/prompt", "session", goclawSession, "sid", sessionID) + var contentPreview string + if len(content) > 0 && content[0].Type == "text" { + if len(content[0].Text) > 200 { + contentPreview = content[0].Text[:200] + "..." 
+ } else { + contentPreview = content[0].Text + } + } + slog.Info("acp: session/prompt", "session", goclawSession, "sid", sessionID, "blocks", len(content), "preview", contentPreview) req := PromptRequest{ SessionID: sessionID, Prompt: content, diff --git a/internal/providers/acp/session_test.go b/internal/providers/acp/session_test.go index 3f048b5e94..9a70f10968 100644 --- a/internal/providers/acp/session_test.go +++ b/internal/providers/acp/session_test.go @@ -213,7 +213,7 @@ func TestACPProcess_NewSession_Success(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - sid, err := proc.NewSession(ctx) + sid, err := proc.NewSession(ctx, "") if err != nil { t.Fatalf("NewSession error: %v", err) } @@ -233,7 +233,7 @@ func TestACPProcess_NewSession_Error(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - _, err := proc.NewSession(ctx) + _, err := proc.NewSession(ctx, "") if err == nil { t.Fatal("expected error from NewSession") } diff --git a/internal/providers/acp/tool_bridge.go b/internal/providers/acp/tool_bridge.go index ba36aa58d0..df6955d0b7 100644 --- a/internal/providers/acp/tool_bridge.go +++ b/internal/providers/acp/tool_bridge.go @@ -57,34 +57,50 @@ func NewToolBridge(workspace string, opts ...ToolBridgeOption) *ToolBridge { // Handle dispatches agent→client requests by method name. // Implements the RequestHandler signature for Conn. 
func (tb *ToolBridge) Handle(ctx context.Context, method string, params json.RawMessage) (any, error) { + session := goclawSessionFromCtx(ctx) switch method { case "fs/readTextFile": if tb.permMode == "deny-all" { + slog.Warn("security.tool_denied", "session", session, "tool", method, "reason", "deny-all") return nil, fmt.Errorf("read denied by permission mode: %s", tb.permMode) } var req ReadTextFileRequest if err := json.Unmarshal(params, &req); err != nil { return nil, fmt.Errorf("invalid params: %w", err) } - return tb.readFile(req) + result, err := tb.readFile(req) + if err == nil { + slog.Info("security.tool_granted", "session", session, "tool", method, "path", req.Path) + } + return result, err case "fs/writeTextFile": if tb.permMode == "deny-all" || tb.permMode == "approve-reads" { + slog.Warn("security.tool_denied", "session", session, "tool", method, "reason", tb.permMode) return nil, fmt.Errorf("write denied by permission mode: %s", tb.permMode) } var req WriteTextFileRequest if err := json.Unmarshal(params, &req); err != nil { return nil, fmt.Errorf("invalid params: %w", err) } - return tb.writeFile(req) + result, err := tb.writeFile(req) + if err == nil { + slog.Info("security.tool_granted", "session", session, "tool", method, "path", req.Path) + } + return result, err case "terminal/create": if tb.permMode == "deny-all" || tb.permMode == "approve-reads" { + slog.Warn("security.tool_denied", "session", session, "tool", method, "reason", tb.permMode) return nil, fmt.Errorf("terminal denied by permission mode: %s", tb.permMode) } var req CreateTerminalRequest if err := json.Unmarshal(params, &req); err != nil { return nil, fmt.Errorf("invalid params: %w", err) } - return tb.createTerminal(req) + result, err := tb.createTerminal(req) + if err == nil { + slog.Info("security.tool_granted", "session", session, "tool", method, "command", req.Command) + } + return result, err case "terminal/output": var req TerminalOutputRequest if err := 
json.Unmarshal(params, &req); err != nil { @@ -105,6 +121,7 @@ func (tb *ToolBridge) Handle(ctx context.Context, method string, params json.Raw return tb.waitForExit(ctx, req) case "terminal/kill": if tb.permMode == "deny-all" { + slog.Warn("security.tool_denied", "session", session, "tool", method, "reason", "deny-all") return nil, fmt.Errorf("terminal kill denied by permission mode: %s", tb.permMode) } var req KillTerminalRequest @@ -117,7 +134,13 @@ func (tb *ToolBridge) Handle(ctx context.Context, method string, params json.Raw if err := json.Unmarshal(params, &req); err != nil { return nil, fmt.Errorf("invalid params: %w", err) } - return tb.handlePermission(req) + return tb.handlePermission(ctx, req) + case "session/request_permission": + var req SessionRequestPermissionRequest + if err := json.Unmarshal(params, &req); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + return tb.handleSessionPermission(ctx, req) default: return nil, fmt.Errorf("unknown method: %s", method) } @@ -151,21 +174,73 @@ func (tb *ToolBridge) writeFile(req WriteTextFileRequest) (*WriteTextFileRespons return &WriteTextFileResponse{}, nil } +// handleSessionPermission handles Gemini CLI's "session/request_permission" ACP method. +// Gemini CLI expects a nested outcome object that differs from the generic "permission/request" format. +// Responding with "proceed_always_server" adds the entire goclaw-bridge server to Gemini's +// allowlist so all subsequent tool calls in the session skip the confirmation step. 
+func (tb *ToolBridge) handleSessionPermission(ctx context.Context, req SessionRequestPermissionRequest) (*SessionRequestPermissionResponse, error) { + session := goclawSessionFromCtx(ctx) + + available := make(map[string]bool, len(req.Options)) + for _, opt := range req.Options { + available[opt.OptionID] = true + } + + switch tb.permMode { + case "deny-all": + slog.Warn("security.tool_denied", "session", session, "tool", req.ToolCall.Title, "reason", "deny-all") + return &SessionRequestPermissionResponse{ + Outcome: SessionPermOutcome{Outcome: "cancelled"}, + }, nil + case "approve-reads": + lower := strings.ToLower(req.ToolCall.Title) + if strings.Contains(lower, "read") || strings.Contains(lower, "glob") || + strings.Contains(lower, "grep") || strings.Contains(lower, "search") || + strings.Contains(lower, "list") || strings.Contains(lower, "view") { + slog.Info("security.tool_granted", "session", session, "tool", req.ToolCall.Title, "mode", "approve-reads") + return &SessionRequestPermissionResponse{ + Outcome: SessionPermOutcome{Outcome: "selected", OptionID: "proceed_once"}, + }, nil + } + slog.Warn("security.tool_denied", "session", session, "tool", req.ToolCall.Title, "reason", "approve-reads:write-blocked") + return &SessionRequestPermissionResponse{ + Outcome: SessionPermOutcome{Outcome: "cancelled"}, + }, nil + default: // "approve-all" + // Prefer server-wide approval so all subsequent goclaw-bridge tool calls skip confirmation. + optionID := "proceed_once" + for _, pref := range []string{"proceed_always_server", "proceed_always_tool", "proceed_once"} { + if available[pref] { + optionID = pref + break + } + } + slog.Info("security.tool_granted", "session", session, "tool", req.ToolCall.Title, "mode", "approve-all", "optionId", optionID) + return &SessionRequestPermissionResponse{ + Outcome: SessionPermOutcome{Outcome: "selected", OptionID: optionID}, + }, nil + } +} + // handlePermission responds to permission requests based on configured mode. 
-func (tb *ToolBridge) handlePermission(req RequestPermissionRequest) (*RequestPermissionResponse, error) { +func (tb *ToolBridge) handlePermission(ctx context.Context, req RequestPermissionRequest) (*RequestPermissionResponse, error) { + session := goclawSessionFromCtx(ctx) switch tb.permMode { case "deny-all": + slog.Warn("security.tool_denied", "session", session, "tool", req.ToolName, "reason", "deny-all") return &RequestPermissionResponse{Outcome: "denied"}, nil case "approve-reads": - // Approve read-only tools, deny write/exec tools lower := strings.ToLower(req.ToolName) if strings.Contains(lower, "read") || strings.Contains(lower, "glob") || strings.Contains(lower, "grep") || strings.Contains(lower, "search") || strings.Contains(lower, "list") || strings.Contains(lower, "view") { + slog.Info("security.tool_granted", "session", session, "tool", req.ToolName, "mode", "approve-reads") return &RequestPermissionResponse{Outcome: "approved"}, nil } + slog.Warn("security.tool_denied", "session", session, "tool", req.ToolName, "reason", "approve-reads:write-blocked") return &RequestPermissionResponse{Outcome: "denied"}, nil default: // "approve-all" or unknown → approve + slog.Info("security.tool_granted", "session", session, "tool", req.ToolName, "mode", "approve-all") return &RequestPermissionResponse{Outcome: "approved"}, nil } } diff --git a/internal/providers/acp/tool_bridge_test.go b/internal/providers/acp/tool_bridge_test.go index 7f54f74b71..fb9394092b 100644 --- a/internal/providers/acp/tool_bridge_test.go +++ b/internal/providers/acp/tool_bridge_test.go @@ -139,7 +139,7 @@ func TestResolvePath_NonExistentFile_AllowedForWrites(t *testing.T) { func TestHandlePermission_ApproveAll(t *testing.T) { tb, _ := newTestBridge(t, WithPermMode("approve-all")) - resp, err := tb.handlePermission(RequestPermissionRequest{ToolName: "bash", Description: "run"}) + resp, err := tb.handlePermission(context.Background(), RequestPermissionRequest{ToolName: "bash", Description: 
"run"}) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -150,7 +150,7 @@ func TestHandlePermission_ApproveAll(t *testing.T) { func TestHandlePermission_DenyAll(t *testing.T) { tb, _ := newTestBridge(t, WithPermMode("deny-all")) - resp, err := tb.handlePermission(RequestPermissionRequest{ToolName: "any_tool"}) + resp, err := tb.handlePermission(context.Background(), RequestPermissionRequest{ToolName: "any_tool"}) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -164,7 +164,7 @@ func TestHandlePermission_ApproveReads_ReadTool(t *testing.T) { cases := []string{"readFile", "glob_files", "search_code", "list_dir", "grep_search", "view_file"} for _, name := range cases { t.Run(name, func(t *testing.T) { - resp, err := tb.handlePermission(RequestPermissionRequest{ToolName: name}) + resp, err := tb.handlePermission(context.Background(), RequestPermissionRequest{ToolName: name}) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -177,7 +177,7 @@ func TestHandlePermission_ApproveReads_ReadTool(t *testing.T) { func TestHandlePermission_ApproveReads_WriteTool(t *testing.T) { tb, _ := newTestBridge(t, WithPermMode("approve-reads")) - resp, err := tb.handlePermission(RequestPermissionRequest{ToolName: "write_file"}) + resp, err := tb.handlePermission(context.Background(), RequestPermissionRequest{ToolName: "write_file"}) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -189,7 +189,7 @@ func TestHandlePermission_ApproveReads_WriteTool(t *testing.T) { func TestHandlePermission_DefaultMode_Approves(t *testing.T) { // permMode = "" defaults to "approve-all" behaviour (unknown → approve) tb := &ToolBridge{permMode: "unknown-mode"} - resp, err := tb.handlePermission(RequestPermissionRequest{ToolName: "anything"}) + resp, err := tb.handlePermission(context.Background(), RequestPermissionRequest{ToolName: "anything"}) if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/internal/providers/acp/types.go 
b/internal/providers/acp/types.go index 9177e50507..bfb571fb8a 100644 --- a/internal/providers/acp/types.go +++ b/internal/providers/acp/types.go @@ -61,9 +61,57 @@ type MCPCaps struct { // --- Session Methods --- +// McpServer is a discriminated-union transport descriptor for MCP servers. +// Concrete types: McpServerHTTP, McpServerStdio (SSE unimplemented). +// Per ACP spec (zed-industries/agent-client-protocol), the wire format is a +// JSON object tagged by `type`; Go's encoding/json handles this via concrete +// values held in the interface. +type McpServer interface{ mcpServerKind() } + +// McpServerHTTP carries HTTP transport MCP config. +// Headers is a {name,value} array — Gemini CLI 0.36.x rejects object-shaped +// headers with schema error "expected array, received object", so we diverge +// from the zed-industries ACP schema (which specifies object) to match the +// implementation that actually consumes the payload. +type McpServerHTTP struct { + Type string `json:"type"` // always "http" + Name string `json:"name"` + URL string `json:"url"` + Headers []McpServerKV `json:"headers"` +} + +func (McpServerHTTP) mcpServerKind() {} + +// McpServerStdio carries stdio transport MCP config. +type McpServerStdio struct { + Type string `json:"type"` // always "stdio" + Name string `json:"name"` + Command string `json:"command"` + Args []string `json:"args"` + Env []McpServerKV `json:"env"` +} + +func (McpServerStdio) mcpServerKind() {} + +// McpServerKV is a {name, value} pair used for both HTTP headers and stdio env. +type McpServerKV struct { + Name string `json:"name"` + Value string `json:"value"` +} + +// Alias retained for backward compatibility with any caller that constructed +// env entries by the older name. New code should use McpServerKV directly. +type McpServerEnv = McpServerKV + +// NewHTTPMcpServer returns an HTTP-transport McpServer with an empty headers +// slice (the field must be present per schema). 
+func NewHTTPMcpServer(name, url string) McpServer { + return McpServerHTTP{Type: "http", Name: name, URL: url, Headers: []McpServerKV{}} +} + type NewSessionRequest struct { - Cwd string `json:"cwd"` - McpServers []string `json:"mcpServers"` + Cwd string `json:"cwd"` + McpServers []McpServer `json:"mcpServers"` } type NewSessionResponse struct { @@ -71,9 +119,9 @@ type NewSessionResponse struct { } type LoadSessionRequest struct { - SessionID string `json:"sessionId"` - Cwd string `json:"cwd,omitempty"` - McpServers []string `json:"mcpServers"` + SessionID string `json:"sessionId"` + Cwd string `json:"cwd,omitempty"` + McpServers []McpServer `json:"mcpServers"` } type LoadSessionResponse struct { @@ -206,3 +254,35 @@ type RequestPermissionRequest struct { type RequestPermissionResponse struct { Outcome string `json:"outcome"` // "proceed_always", "approved", "denied" } + +// SessionRequestPermissionRequest is sent by Gemini CLI (method "session/request_permission") +// to request approval before executing an MCP tool. +type SessionRequestPermissionRequest struct { + SessionID string `json:"sessionId"` + Options []SessionPermOpt `json:"options"` + ToolCall SessionPermTool `json:"toolCall"` +} + +type SessionPermOpt struct { + OptionID string `json:"optionId"` + Name string `json:"name"` + Kind string `json:"kind"` +} + +type SessionPermTool struct { + ToolCallID string `json:"toolCallId"` + Status string `json:"status"` + Title string `json:"title"` + Kind string `json:"kind,omitempty"` +} + +// SessionRequestPermissionResponse matches Gemini CLI's RequestPermissionResponseSchema. 
+// Wire format: {"outcome":{"outcome":"cancelled"}} or {"outcome":{"outcome":"selected","optionId":"..."}} +type SessionRequestPermissionResponse struct { + Outcome SessionPermOutcome `json:"outcome"` +} + +type SessionPermOutcome struct { + Outcome string `json:"outcome"` // "cancelled" or "selected" + OptionID string `json:"optionId,omitempty"` // required when outcome="selected" +} diff --git a/internal/providers/acp_provider.go b/internal/providers/acp_provider.go index 17380d5d5e..35c0b422b4 100644 --- a/internal/providers/acp_provider.go +++ b/internal/providers/acp_provider.go @@ -5,14 +5,89 @@ import ( "errors" "fmt" "log/slog" + "os" + "path/filepath" "regexp" "strings" "sync" "time" + "github.com/nextlevelbuilder/goclaw/internal/config" "github.com/nextlevelbuilder/goclaw/internal/providers/acp" ) +// ACPSettings is the unified configuration shape for ACP-based providers. +// Both config-based (config.json `providers.acp`) and DB-based (llm_providers.settings JSONB) +// registration paths populate this struct; all ACP `With*` options consume it as a +// common argument and pick the field they configure. Fields left zero / empty are +// treated as "use built-in default" inside each option, so callers only need to set +// values they want to override. +// +// Duration fields (IdleTTL, SessionTTL, PromptTimeout) are stored as strings in the +// duration syntax accepted by time.ParseDuration ("5m", "30s", etc.) so the same +// struct shape works for JSON unmarshal (DB JSONB) without custom decoding logic. 
+type ACPSettings struct { + Name string `json:"name,omitempty"` // provider display name + Binary string `json:"-"` // resolved binary path (DB: api_base column; config: cfg.Binary) + Args []string `json:"args,omitempty"` // extra CLI args (excluding goclaw-injected --include-directories) + Model string `json:"model,omitempty"` // default model/agent name + PermMode string `json:"perm_mode,omitempty"` // tool bridge permission mode + IdleTTL string `json:"idle_ttl,omitempty"` // duration string; pool/session reaper idle timeout + SessionTTL string `json:"session_ttl,omitempty"` // duration string; session reaper override (else falls back to IdleTTL) + PromptTimeout string `json:"prompt_timeout,omitempty"` // duration string; per-Prompt() inactivity watchdog + WorkDir string `json:"work_dir,omitempty"` // process pool base cwd + IncludeDirs []string `json:"include_directories,omitempty"` + MCPData *MCPConfigData `json:"-"` // MCP bridge config; never in JSONB +} + +// IdleTTLOrDefault parses IdleTTL with a fallback when unset / invalid. +func (s *ACPSettings) IdleTTLOrDefault(fallback time.Duration) time.Duration { + if s == nil || s.IdleTTL == "" { + return fallback + } + if d, err := time.ParseDuration(s.IdleTTL); err == nil && d > 0 { + return d + } + return fallback +} + +// WorkDirOrDefault returns s.WorkDir or the package default ACP workspace root. +func (s *ACPSettings) WorkDirOrDefault() string { + if s != nil && s.WorkDir != "" { + return s.WorkDir + } + return defaultACPWorkDir() +} + +// defaultACPWorkDir returns the standard ACP process workspace root used when +// callers don't override via ACPSettings.WorkDir. Located under the resolved +// data dir so it survives across deployments without leaking outside goclaw. 
+func defaultACPWorkDir() string { + return filepath.Join(config.ResolvedDataDirFromEnv(), "acp-workspaces") +} + +// defaultGoclawSkillDirs returns the canonical filesystem-backed skill source +// directories that gemini ACP should expose via --include-directories when no +// explicit IncludeDirs are configured. Mirrors three of the loader's runtime +// slots — workspace-relative slots are intentionally omitted because the ACP +// session cwd lives under acp-workspaces, not the gateway workspace. +// +// Sources covered: +// - /skills-store (managedSkillsDir) +// - /skills (globalSkills) +// - ~/.agents/skills (personalAgentSkills) +func defaultGoclawSkillDirs() []string { + dataDir := config.ResolvedDataDirFromEnv() + dirs := []string{ + filepath.Join(dataDir, "skills-store"), + filepath.Join(dataDir, "skills"), + } + if home, err := os.UserHomeDir(); err == nil && home != "" { + dirs = append(dirs, filepath.Join(home, ".agents", "skills")) + } + return dirs +} + // acpSessionEntry tracks a live ACP session for one goclaw conversation. type acpSessionEntry struct { id string // ACP session ID returned by session/new or session/load @@ -20,15 +95,35 @@ type acpSessionEntry struct { lastUsed time.Time } +// acpRoutingKey is the private context key for per-call routing values. +type acpRoutingKey struct{} + +// acpRoutingValues holds values extracted from ChatRequest.Options for MCP bridge headers. +type acpRoutingValues struct { + agentID string + userID string + channel string + chatID string + peerKind string + workspace string + tenantID string + localKey string + sessionKey string +} + // ACPProvider implements Provider by orchestrating ACP-compatible agent subprocesses. // One shared Gemini process is used; each goclaw conversation gets its own ACP session. 
type ACPProvider struct { - name string - pool *acp.ProcessPool - bridge *acp.ToolBridge - defaultModel string - permMode string - poolKey string // key for the shared process in the pool (binary + args) + name string + pool *acp.ProcessPool + bridge *acp.ToolBridge + defaultModel string + permMode string + poolKey string // key for the shared process in the pool (binary + args) + mcpConfigData *MCPConfigData // MCP bridge config (gateway addr, token, lookup) + sessionIdleTTL time.Duration // idle TTL for ACP session reaper + promptTimeout time.Duration // inactivity timeout for Prompt() watchdog + includeDirs []string // candidate dirs appended as --include-directories for gemini acpSessions sync.Map // goclawSessionKey → *acpSessionEntry sessionMu sync.Map // goclawSessionKey → *sync.Mutex (prevents concurrent session creation) @@ -40,51 +135,144 @@ type ACPProvider struct { // ACPOption configures an ACPProvider. type ACPOption func(*ACPProvider) -// WithACPName overrides the provider name (default: "acp"). -func WithACPName(name string) ACPOption { +// All ACP With* options below take a *ACPSettings as a common argument and read +// only the field they configure. Empty / zero values are treated as "no override" +// so callers can build one settings struct and pass it to every option without +// worrying about clobbering defaults set elsewhere. + +// WithACPName overrides the provider name (default: "acp"). Reads s.Name. +func WithACPName(s *ACPSettings) ACPOption { + return func(p *ACPProvider) { + if s == nil || s.Name == "" { + return + } + p.name = s.Name + } +} + +// WithACPModel sets the default model/agent name. Reads s.Model. +func WithACPModel(s *ACPSettings) ACPOption { return func(p *ACPProvider) { - if name != "" { - p.name = name + if s == nil || s.Model == "" { + return } + p.defaultModel = s.Model } } -// WithACPModel sets the default model/agent name. 
-func WithACPModel(model string) ACPOption { +// WithACPPermMode sets the permission mode for the tool bridge. Reads s.PermMode. +func WithACPPermMode(s *ACPSettings) ACPOption { return func(p *ACPProvider) { - if model != "" { - p.defaultModel = model + if s == nil || s.PermMode == "" { + return } + p.permMode = s.PermMode } } -// WithACPPermMode sets the permission mode for the tool bridge. -func WithACPPermMode(mode string) ACPOption { +// WithACPSessionTTL overrides the idle TTL used by the session reaper. +// Reads s.SessionTTL (duration string). When unset/invalid, NewACPProvider +// falls back to the process pool's idleTTL. +func WithACPSessionTTL(s *ACPSettings) ACPOption { return func(p *ACPProvider) { - if mode != "" { - p.permMode = mode + if s == nil || s.SessionTTL == "" { + return + } + if d, err := time.ParseDuration(s.SessionTTL); err == nil && d > 0 { + p.sessionIdleTTL = d } } } -// NewACPProvider creates a provider that orchestrates ACP agents as subprocesses. -func NewACPProvider(binary string, args []string, workDir string, idleTTL time.Duration, denyPatterns []*regexp.Regexp, opts ...ACPOption) *ACPProvider { - // Pool key identifies the shared process: binary + args combination - poolKey := binary - if len(args) > 0 { - poolKey += "|" + strings.Join(args, " ") +// WithACPPromptTimeout sets the inactivity timeout for Prompt() watchdogs. +// Reads s.PromptTimeout (duration string). When unset/invalid, the +// package-level promptInactivityTimeout default (10 min) applies. +func WithACPPromptTimeout(s *ACPSettings) ACPOption { + return func(p *ACPProvider) { + if s == nil || s.PromptTimeout == "" { + return + } + if d, err := time.ParseDuration(s.PromptTimeout); err == nil && d > 0 { + p.promptTimeout = d + } + } +} + +// WithIncludeDirectories registers candidate directories that should be exposed +// to the agent's filesystem sandbox. 
The actual binary gating happens in +// NewACPProvider, which only emits `--include-directories ` pairs for +// gemini and stat-filters non-existent entries. Storing the list on the +// provider for non-gemini binaries is harmless (never consumed downstream). +// +// When s.IncludeDirs is empty, falls back to the canonical goclaw skill source +// dirs (skills-store, global skills, personal agent skills) so the typical +// deployment "just works" without admin needing to enumerate paths. +func WithIncludeDirectories(s *ACPSettings) ACPOption { + return func(p *ACPProvider) { + if s == nil { + return + } + dirs := s.IncludeDirs + if len(dirs) == 0 { + dirs = defaultGoclawSkillDirs() + } + p.includeDirs = dirs + } +} + +// WithACPMCPConfigData registers MCP bridge config (gateway address, token, server lookup). +// Reads s.MCPData. Mirrors the Claude CLI pattern: provider builds the MCP server +// list per session using routing values from ChatRequest.Options. +func WithACPMCPConfigData(s *ACPSettings) ACPOption { + return func(p *ACPProvider) { + if s == nil || s.MCPData == nil { + return + } + p.mcpConfigData = s.MCPData } +} +// NewACPProvider creates a provider that orchestrates ACP agents as subprocesses. +func NewACPProvider(binary string, args []string, workDir string, idleTTL time.Duration, denyPatterns []*regexp.Regexp, opts ...ACPOption) *ACPProvider { p := &ACPProvider{ name: "acp", defaultModel: "claude", - poolKey: poolKey, done: make(chan struct{}), } for _, opt := range opts { opt(p) } + // Gemini sandbox needs --include-directories to read goclaw skill paths + // outside the cwd. Non-gemini binaries (claude, codex) handle filesystem + // access differently, so includeDirs is a no-op for them. 
+ if filepath.Base(binary) == "gemini" && len(p.includeDirs) > 0 { + for _, d := range p.includeDirs { + if d == "" { + continue + } + if info, err := os.Stat(d); err == nil && info.IsDir() { + args = append(args, "--include-directories", d) + } + } + } + + // poolKey uniquely identifies a subprocess configuration so that providers + // differing in any of the five dimensions always spawn separate processes. + // permMode is included explicitly; it is no longer injected into CLI args + // because ACP permission/request RPCs are handled entirely by ToolBridge. + p.poolKey = fmt.Sprintf("%s|%s|%s|%s|%s", + binary, + strings.Join(args, " "), + workDir, + idleTTL, + p.permMode, + ) + + if p.sessionIdleTTL == 0 { + p.sessionIdleTTL = idleTTL + } + var bridgeOpts []acp.ToolBridgeOption if len(denyPatterns) > 0 { bridgeOpts = append(bridgeOpts, acp.WithDenyPatterns(denyPatterns)) @@ -96,15 +284,24 @@ func NewACPProvider(binary string, args []string, workDir string, idleTTL time.D p.pool = acp.NewProcessPool(binary, args, workDir, idleTTL) p.pool.SetToolHandler(p.bridge.Handle) + if p.mcpConfigData != nil { + cd := p.mcpConfigData + p.pool.SetMcpServersFunc(func(ctx context.Context) []acp.McpServer { + rv, _ := ctx.Value(acpRoutingKey{}).(acpRoutingValues) + return p.buildACPServers(ctx, cd, rv) + }) + } + if p.promptTimeout > 0 { + p.pool.SetPromptTimeout(p.promptTimeout) + } go p.sessionReaper() return p } -// sessionReaper removes ACP sessions idle for more than 30 minutes. +// sessionReaper removes ACP sessions idle for more than sessionIdleTTL. // Sends session/cancel to release resources on the agent side before purging locally. 
func (p *ACPProvider) sessionReaper() { - const sessionIdleTTL = 30 * time.Minute ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() for { @@ -112,12 +309,13 @@ func (p *ACPProvider) sessionReaper() { case <-ticker.C: p.acpSessions.Range(func(key, value any) bool { entry := value.(*acpSessionEntry) - if time.Since(entry.lastUsed) > sessionIdleTTL { - slog.Info("acp: expiring idle session", "goclaw_session", key, "sid", entry.id) + if time.Since(entry.lastUsed) > p.sessionIdleTTL { + slog.Info("acp: expiring idle session", "goclaw_session", key, "sid", entry.id, "ttl", p.sessionIdleTTL) if entry.proc != nil { _ = entry.proc.Cancel(entry.id) } p.acpSessions.Delete(key) + p.sessionMu.Delete(key) } return true }) @@ -127,10 +325,59 @@ func (p *ACPProvider) sessionReaper() { } } +// ensureSessionDir creates and returns a per-goclaw-session workspace under +// the process pool's base work directory. Mirrors the claude_cli provider's +// ensureWorkDir pattern so acp-workspaces layout matches cli-workspaces: +// +// /agent--ws-direct-/ +// +// Falls back to the pool's workDir (shared) if the base is unset or MkdirAll +// fails — safer than /tmp since the caller passes Authorization-protected +// paths to the ACP agent. +func (p *ACPProvider) ensureSessionDir(proc *acp.ACPProcess, goclawKey string) string { + base := proc.WorkDir() + if base == "" { + return "" + } + safe := sanitizePathSegment(goclawKey) + if safe == "" { + return base + } + dir := filepath.Join(base, safe) + if err := os.MkdirAll(dir, 0o755); err != nil { + slog.Warn("acp: failed to create per-session workspace, using pool default", + "goclaw_session", goclawKey, "dir", dir, "error", err) + return base + } + return dir +} + +// writeGeminiMD writes the system prompt to GEMINI.md in the session workspace. +// Gemini CLI reads this file automatically from the session cwd (mirrors writeClaudeMD). +// Skips write if content is unchanged. 
Returns true if the file was rewritten, +// signalling the caller to invalidate the live ACP session so the next request +// starts a fresh session with the updated instructions. +func (p *ACPProvider) writeGeminiMD(sessionDir, systemPrompt string) bool { + if sessionDir == "" || systemPrompt == "" { + return false + } + path := filepath.Join(sessionDir, "GEMINI.md") + if existing, err := os.ReadFile(path); err == nil && string(existing) == systemPrompt { + return false + } + if err := os.WriteFile(path, []byte(systemPrompt), 0600); err != nil { + slog.Warn("acp: failed to write GEMINI.md", "path", path, "error", err) + return false + } + return true +} + // resolveSession returns the ACP session ID for a goclaw session key. -// It creates a new session if none exists, or reloads it after a process respawn. +// sessionDir is the pre-computed per-session workspace (caller must ensure it exists). +// Returns isNew=true only when a brand-new session is created via session/new — +// callers use this to inject full conversation history into the first prompt. // A per-key mutex prevents concurrent creation races for the same session. 
-func (p *ACPProvider) resolveSession(ctx context.Context, proc *acp.ACPProcess, goclawKey string) (string, error) { +func (p *ACPProvider) resolveSession(ctx context.Context, proc *acp.ACPProcess, sessionDir, goclawKey string) (sid string, isNew bool, err error) { actual, _ := p.sessionMu.LoadOrStore(goclawKey, &sync.Mutex{}) mu := actual.(*sync.Mutex) mu.Lock() @@ -141,29 +388,29 @@ func (p *ACPProvider) resolveSession(ctx context.Context, proc *acp.ACPProcess, if entry.proc == proc { // Same process instance: session is still live, just update last-used entry.lastUsed = time.Now() - return entry.id, nil + return entry.id, false, nil } // Process was respawned — try to restore the session slog.Info("acp: process respawned, attempting session restore", "goclaw_session", goclawKey, "old_sid", entry.id) if proc.AgentCaps().LoadSession { - sid, err := proc.LoadSession(ctx, entry.id) + sid, err := proc.LoadSession(ctx, entry.id, sessionDir) if err == nil { p.acpSessions.Store(goclawKey, &acpSessionEntry{id: sid, proc: proc, lastUsed: time.Now()}) - return sid, nil + return sid, false, nil } slog.Warn("acp: session/load failed, creating new session", "old_sid", entry.id, "error", err) } // session/load not supported or failed — fall through to create new } - slog.Info("acp: creating new session", "goclaw_session", goclawKey, "pool_key", p.poolKey) - sid, err := proc.NewSession(ctx) + slog.Info("acp: creating new session", "goclaw_session", goclawKey, "pool_key", p.poolKey, "cwd", sessionDir) + sid, err = proc.NewSession(ctx, sessionDir) if err != nil { - return "", err + return "", false, err } p.acpSessions.Store(goclawKey, &acpSessionEntry{id: sid, proc: proc, lastUsed: time.Now()}) - return sid, nil + return sid, true, nil } func (p *ACPProvider) Name() string { return p.name } @@ -183,8 +430,118 @@ func (p *ACPProvider) Capabilities() ProviderCapabilities { } } +// injectRoutingFromOpts stores all MCP bridge routing values from ChatRequest.Options +// into ctx. 
Mirrors Claude CLI's bridgeContextFromOpts pattern: the pipeline sets +// all Opt* values in loop_pipeline_callbacks.go so they are always available here. +func injectRoutingFromOpts(ctx context.Context, opts map[string]any) context.Context { + return context.WithValue(ctx, acpRoutingKey{}, acpRoutingValues{ + agentID: extractStringOpt(opts, OptAgentID), + userID: extractStringOpt(opts, OptUserID), + channel: extractStringOpt(opts, OptChannel), + chatID: extractStringOpt(opts, OptChatID), + peerKind: extractStringOpt(opts, OptPeerKind), + workspace: extractStringOpt(opts, OptWorkspace), + tenantID: extractStringOpt(opts, OptTenantID), + localKey: extractStringOpt(opts, OptLocalKey), + sessionKey: extractStringOpt(opts, OptSessionKey), + }) +} + +// buildACPServers constructs the []acp.McpServer list for session/new. +// Mirrors buildACPMcpServersFunc but lives inside the provider so it has +// access to all routing values from ChatRequest.Options via context. +func (p *ACPProvider) buildACPServers(ctx context.Context, cd *MCPConfigData, rv acpRoutingValues) []acp.McpServer { + if cd == nil || cd.GatewayAddr == "" { + return nil + } + safe := func(v string) bool { return !strings.ContainsAny(v, "\r\n\x00") } + bridgeURL := fmt.Sprintf("http://%s/mcp/bridge", cd.GatewayAddr) + + headers := []acp.McpServerKV{} + if cd.GatewayToken != "" { + headers = append(headers, acp.McpServerKV{Name: "Authorization", Value: "Bearer " + cd.GatewayToken}) + } + if rv.agentID != "" && safe(rv.agentID) { + headers = append(headers, acp.McpServerKV{Name: "X-Agent-ID", Value: rv.agentID}) + } + if rv.userID != "" && safe(rv.userID) { + headers = append(headers, acp.McpServerKV{Name: "X-User-ID", Value: rv.userID}) + } + if rv.channel != "" && safe(rv.channel) { + headers = append(headers, acp.McpServerKV{Name: "X-Channel", Value: rv.channel}) + } + if rv.chatID != "" && safe(rv.chatID) { + headers = append(headers, acp.McpServerKV{Name: "X-Chat-ID", Value: rv.chatID}) + } + if 
rv.peerKind != "" && safe(rv.peerKind) { + headers = append(headers, acp.McpServerKV{Name: "X-Peer-Kind", Value: rv.peerKind}) + } + if rv.workspace != "" && safe(rv.workspace) { + headers = append(headers, acp.McpServerKV{Name: "X-Workspace", Value: rv.workspace}) + } + if rv.tenantID != "" && safe(rv.tenantID) { + headers = append(headers, acp.McpServerKV{Name: "X-Tenant-ID", Value: rv.tenantID}) + } + if rv.localKey != "" && safe(rv.localKey) { + headers = append(headers, acp.McpServerKV{Name: "X-Local-Key", Value: rv.localKey}) + } + if rv.sessionKey != "" && safe(rv.sessionKey) { + headers = append(headers, acp.McpServerKV{Name: "X-Session-Key", Value: rv.sessionKey}) + } + if cd.GatewayToken != "" && (rv.agentID != "" || rv.userID != "") { + sig := SignBridgeContext(cd.GatewayToken, rv.agentID, rv.userID, rv.channel, rv.chatID, rv.peerKind, rv.workspace, rv.tenantID, rv.localKey, rv.sessionKey) + headers = append(headers, acp.McpServerKV{Name: "X-Bridge-Sig", Value: sig}) + } + + servers := []acp.McpServer{acp.McpServerHTTP{ + Type: "http", + Name: "goclaw-bridge", + URL: bridgeURL, + Headers: headers, + }} + + if cd.AgentMCPLookup != nil && rv.agentID != "" { + for _, entry := range cd.AgentMCPLookup(ctx, rv.agentID) { + servers = append(servers, acpServerEntryToMCP(entry)) + } + } + return servers +} + +// acpServerEntryToMCP converts an MCPServerEntry to the ACP schema. 
+func acpServerEntryToMCP(e MCPServerEntry) acp.McpServer { + if e.Transport == "stdio" { + env := make([]acp.McpServerKV, 0, len(e.Env)) + for k, v := range e.Env { + env = append(env, acp.McpServerKV{Name: k, Value: v}) + } + args := e.Args + if args == nil { + args = []string{} + } + return acp.McpServerStdio{ + Type: "stdio", + Name: e.Name, + Command: e.Command, + Args: args, + Env: env, + } + } + headers := make([]acp.McpServerKV, 0, len(e.Headers)) + for k, v := range e.Headers { + headers = append(headers, acp.McpServerKV{Name: k, Value: v}) + } + return acp.McpServerHTTP{ + Type: "http", + Name: e.Name, + URL: e.URL, + Headers: headers, + } +} + // Chat sends a prompt and returns the complete response (non-streaming). func (p *ACPProvider) Chat(ctx context.Context, req ChatRequest) (*ChatResponse, error) { + ctx = injectRoutingFromOpts(ctx, req.Options) sessionKey := extractStringOpt(req.Options, OptSessionKey) if sessionKey == "" { sessionKey = fmt.Sprintf("temp-%d", time.Now().UnixNano()) @@ -195,7 +552,15 @@ func (p *ACPProvider) Chat(ctx context.Context, req ChatRequest) (*ChatResponse, return nil, fmt.Errorf("acp: spawn failed: %w", err) } - acpSessionID, err := p.resolveSession(ctx, proc, sessionKey) + sessionDir := p.ensureSessionDir(proc, sessionKey) + systemPrompt, _, _ := extractFromMessages(req.Messages) + if p.writeGeminiMD(sessionDir, systemPrompt) { + // System prompt changed — invalidate live session so next resolveSession + // creates a fresh one that loads the updated GEMINI.md. 
+ p.acpSessions.Delete(sessionKey) + } + + acpSessionID, isNew, err := p.resolveSession(ctx, proc, sessionDir, sessionKey) if err != nil { return nil, err } @@ -203,7 +568,7 @@ func (p *ACPProvider) Chat(ctx context.Context, req ChatRequest) (*ChatResponse, defer p.purgeSession(sessionKey) } - content := extractACPContent(req) + content := extractACPContent(req, isNew) if len(content) == 0 { return nil, fmt.Errorf("acp: no user message in request") } @@ -212,7 +577,10 @@ func (p *ACPProvider) Chat(ctx context.Context, req ChatRequest) (*ChatResponse, var buf strings.Builder var updateCount int - promptResp, err := proc.Prompt(ctx, acpSessionID, content, func(update acp.SessionUpdate) { + cb := func(update acp.SessionUpdate) { + if update.ToolCall != nil { + slog.Info("acp: tool call (chat)", "name", update.ToolCall.Name, "status", update.ToolCall.Status, "id", update.ToolCall.ID) + } if update.Message != nil { for _, block := range update.Message.Content { if block.Type == "text" { @@ -221,7 +589,20 @@ func (p *ACPProvider) Chat(ctx context.Context, req ChatRequest) (*ChatResponse, } } } - }) + } + + const maxACPRetry = 2 + var promptResp *acp.PromptResponse + for attempt := range maxACPRetry + 1 { + buf.Reset() + updateCount = 0 + promptResp, err = proc.Prompt(ctx, acpSessionID, content, cb) + if err == nil || !isMalformedFunctionCall(err) { + break + } + slog.Warn("acp: malformed function call, retrying", "attempt", attempt+1, "session", sessionKey, "sid", acpSessionID) + } + if err != nil { slog.Error("acp: chat error", "session", sessionKey, "sid", acpSessionID, "error", err) return &ChatResponse{ @@ -230,17 +611,32 @@ func (p *ACPProvider) Chat(ctx context.Context, req ChatRequest) (*ChatResponse, }, err } + if promptResp != nil && promptResp.StopReason == "cancelled" { + slog.Warn("acp: chat cancelled", "session", sessionKey, "sid", acpSessionID, "updates", updateCount) + errMsg := "[요청 취소] 응답 대기 중 타임아웃으로 취소됨" + if buf.Len() > 0 { + errMsg = buf.String() + 
"\n\n" + errMsg + } + return &ChatResponse{Content: errMsg, FinishReason: "stop"}, nil + } + + outputText := buf.String() slog.Info("acp: chat completed", "session", sessionKey, "sid", acpSessionID, - "stopReason", mapStopReason(promptResp), "updates", updateCount, "contentLen", buf.Len()) + "stopReason", mapStopReason(promptResp), "updates", updateCount, "contentLen", len(outputText)) return &ChatResponse{ - Content: buf.String(), + Content: outputText, FinishReason: mapStopReason(promptResp), - Usage: &Usage{}, + Usage: &Usage{ + PromptTokens: acpInputTokens(req.Messages), + CompletionTokens: acpEstimateTokens(outputText), + TotalTokens: acpInputTokens(req.Messages) + acpEstimateTokens(outputText), + }, }, nil } // ChatStream sends a prompt and streams response chunks via onChunk callback. func (p *ACPProvider) ChatStream(ctx context.Context, req ChatRequest, onChunk func(StreamChunk)) (*ChatResponse, error) { + ctx = injectRoutingFromOpts(ctx, req.Options) sessionKey := extractStringOpt(req.Options, OptSessionKey) if sessionKey == "" { sessionKey = fmt.Sprintf("temp-%d", time.Now().UnixNano()) @@ -251,7 +647,13 @@ func (p *ACPProvider) ChatStream(ctx context.Context, req ChatRequest, onChunk f return nil, fmt.Errorf("acp: spawn failed: %w", err) } - acpSessionID, err := p.resolveSession(ctx, proc, sessionKey) + sessionDir := p.ensureSessionDir(proc, sessionKey) + systemPrompt, _, _ := extractFromMessages(req.Messages) + if p.writeGeminiMD(sessionDir, systemPrompt) { + p.acpSessions.Delete(sessionKey) + } + + acpSessionID, isNew, err := p.resolveSession(ctx, proc, sessionDir, sessionKey) if err != nil { return nil, err } @@ -259,7 +661,7 @@ func (p *ACPProvider) ChatStream(ctx context.Context, req ChatRequest, onChunk f defer p.purgeSession(sessionKey) } - content := extractACPContent(req) + content := extractACPContent(req, isNew) if len(content) == 0 { return nil, fmt.Errorf("acp: no user message in request") } @@ -282,7 +684,7 @@ func (p *ACPProvider) 
ChatStream(ctx context.Context, req ChatRequest, onChunk f var buf strings.Builder var updateCount int - promptResp, err := proc.Prompt(ctx, acpSessionID, content, func(update acp.SessionUpdate) { + streamCb := func(update acp.SessionUpdate) { if update.Message != nil { for _, block := range update.Message.Content { if block.Type == "text" { @@ -292,10 +694,21 @@ func (p *ACPProvider) ChatStream(ctx context.Context, req ChatRequest, onChunk f } } } - if update.ToolCall != nil && update.ToolCall.Status == "running" { - slog.Debug("acp: tool call", "name", update.ToolCall.Name) + if update.ToolCall != nil { + slog.Info("acp: tool call (stream)", "name", update.ToolCall.Name, "status", update.ToolCall.Status, "id", update.ToolCall.ID) } - }) + } + + const maxACPRetry = 2 + var promptResp *acp.PromptResponse + for attempt := range maxACPRetry + 1 { + promptResp, err = proc.Prompt(ctx, acpSessionID, content, streamCb) + if err == nil || !isMalformedFunctionCall(err) { + break + } + slog.Warn("acp: malformed function call, retrying", "attempt", attempt+1, "session", sessionKey, "sid", acpSessionID) + } + if err != nil { slog.Error("acp: chat error", "session", sessionKey, "sid", acpSessionID, "error", err) return &ChatResponse{ @@ -304,14 +717,31 @@ func (p *ACPProvider) ChatStream(ctx context.Context, req ChatRequest, onChunk f }, err } + if promptResp != nil && promptResp.StopReason == "cancelled" { + slog.Warn("acp: chat stream cancelled", "session", sessionKey, "sid", acpSessionID, "updates", updateCount) + errMsg := "[요청 취소] 응답 대기 중 타임아웃으로 취소됨" + prefix := "\n\n" + if buf.Len() == 0 { + prefix = "" + } + onChunk(StreamChunk{Content: prefix + errMsg}) + onChunk(StreamChunk{Done: true}) + return &ChatResponse{Content: buf.String() + prefix + errMsg, FinishReason: "stop"}, nil + } + onChunk(StreamChunk{Done: true}) + outputText := buf.String() slog.Info("acp: chat stream completed", "session", sessionKey, "sid", acpSessionID, - "stopReason", mapStopReason(promptResp), 
"updates", updateCount, "contentLen", buf.Len()) + "stopReason", mapStopReason(promptResp), "updates", updateCount, "contentLen", len(outputText)) return &ChatResponse{ - Content: buf.String(), + Content: outputText, FinishReason: mapStopReason(promptResp), - Usage: &Usage{}, + Usage: &Usage{ + PromptTokens: acpInputTokens(req.Messages), + CompletionTokens: acpEstimateTokens(outputText), + TotalTokens: acpInputTokens(req.Messages) + acpEstimateTokens(outputText), + }, }, nil } @@ -339,31 +769,85 @@ func (p *ACPProvider) Close() error { return p.pool.Close() } -// extractACPContent extracts user message + images from ChatRequest into ACP ContentBlocks. -func extractACPContent(req ChatRequest) []acp.ContentBlock { - systemPrompt, userMsg, images := extractFromMessages(req.Messages) - if userMsg == "" { - return nil - } +// acpAllowedMIME is the set of image MIME types accepted by ACP providers. +var acpAllowedMIME = map[string]bool{ + "image/jpeg": true, + "image/png": true, + "image/webp": true, + "image/gif": true, +} - var blocks []acp.ContentBlock +// acpMaxImageBytes is the maximum decoded image size accepted (5 MB). +const acpMaxImageBytes = 5 * 1024 * 1024 - // Prepend system prompt to user message (ACP agents have no separate system prompt API) - text := userMsg - if systemPrompt != "" { - text = systemPrompt + "\n\n" + userMsg +// appendACPImages appends validated image ContentBlocks to blocks. 
+func appendACPImages(blocks []acp.ContentBlock, images []ImageContent) []acp.ContentBlock { + for _, img := range images { + if !acpAllowedMIME[img.MimeType] { + slog.Warn("acp: unsupported image MIME type, skipping", "mime", img.MimeType) + continue + } + if len(img.Data)*3/4 > acpMaxImageBytes { + slog.Warn("acp: image too large, skipping", "estimatedBytes", len(img.Data)*3/4, "limit", acpMaxImageBytes) + continue + } + blocks = append(blocks, acp.ContentBlock{Type: "image", Data: img.Data, MimeType: img.MimeType}) } - blocks = append(blocks, acp.ContentBlock{Type: "text", Text: text}) + return blocks +} - for _, img := range images { - blocks = append(blocks, acp.ContentBlock{ - Type: "image", - Data: img.Data, - MimeType: img.MimeType, - }) +// extractACPContent builds ACP ContentBlocks from a ChatRequest. +// +// isNew=false (normal turn): GEMINI.md in the session workspace already provides +// the system prompt, so only the current user message is sent. This avoids +// repeating the (often large) system prompt on every turn. +// +// isNew=true (fresh or reset session): the session has no prior context. +// All non-system messages from req.Messages are serialised as a conversation +// transcript so that compacted summaries and recent history are preserved. +// The system prompt is omitted here because writeGeminiMD wrote it to GEMINI.md +// before the session was created. +func extractACPContent(req ChatRequest, isNew bool) []acp.ContentBlock { + msgs := req.Messages + + if !isNew { + // Normal turn: send only the current user message. + _, userMsg, images := extractFromMessages(msgs) + if userMsg == "" { + return nil + } + blocks := []acp.ContentBlock{{Type: "text", Text: userMsg}} + return appendACPImages(blocks, images) } - return blocks + // New session: serialise full conversation context (summary + history + current). + // System prompt is excluded — GEMINI.md handles it. 
+ var sb strings.Builder + var images []ImageContent + for i, m := range msgs { + switch m.Role { + case "system": + continue + case "user": + if i == len(msgs)-1 { + images = m.Images // collect images from last (current) user message + } + sb.WriteString("[User]\n") + sb.WriteString(m.Content) + sb.WriteString("\n\n") + case "assistant": + sb.WriteString("[Assistant]\n") + sb.WriteString(m.Content) + sb.WriteString("\n\n") + } + } + + text := strings.TrimRight(sb.String(), "\n") + if text == "" { + return nil + } + blocks := []acp.ContentBlock{{Type: "text", Text: text}} + return appendACPImages(blocks, images) } // mapStopReason converts ACP stopReason to GoClaw finish reason. @@ -374,9 +858,35 @@ func mapStopReason(resp *acp.PromptResponse) string { switch resp.StopReason { case "max_tokens", "maxContextLength": return "length" - case "cancelled": - return "stop" - default: + case "tool_use": + return "tool_calls" + case "error": + return "error" + default: // end_turn, stop_sequence, cancelled, "" return "stop" } } + +// isMalformedFunctionCall returns true when err indicates Gemini produced an +// invalid tool call JSON — a transient model glitch worth retrying. +func isMalformedFunctionCall(err error) bool { + return err != nil && strings.Contains(err.Error(), "malformed function call") +} + +// acpEstimateTokens returns a rough token count from character count (chars/4). +func acpEstimateTokens(s string) int { + n := len(s) / 4 + if n < 1 && len(s) > 0 { + return 1 + } + return n +} + +// acpInputTokens estimates input token count from all messages. 
+func acpInputTokens(msgs []Message) int { + var total int + for _, m := range msgs { + total += acpEstimateTokens(m.Content) + } + return total +} diff --git a/internal/providers/acp_provider_test.go b/internal/providers/acp_provider_test.go new file mode 100644 index 0000000000..5507ff7a96 --- /dev/null +++ b/internal/providers/acp_provider_test.go @@ -0,0 +1,161 @@ +package providers + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +// TestExtractACPContent_NormalTurn verifies that isNew=false sends only the +// current user message without system prompt prepend. +func TestExtractACPContent_NormalTurn(t *testing.T) { + req := ChatRequest{ + Messages: []Message{ + {Role: "system", Content: "You are Ender."}, + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "hi there"}, + {Role: "user", Content: "current question"}, + }, + } + blocks := extractACPContent(req, false) + if len(blocks) != 1 { + t.Fatalf("want 1 block, got %d", len(blocks)) + } + if blocks[0].Text != "current question" { + t.Errorf("want only current user message, got: %q", blocks[0].Text) + } + if strings.Contains(blocks[0].Text, "You are Ender") { + t.Error("system prompt must not appear in normal-turn content") + } +} + +// TestExtractACPContent_NewSession_WithHistory verifies that isNew=true serialises +// the full conversation (summary + history + current) excluding the system prompt. +func TestExtractACPContent_NewSession_WithHistory(t *testing.T) { + req := ChatRequest{ + Messages: []Message{ + {Role: "system", Content: "You are Ender."}, + {Role: "user", Content: "[Previous conversation summary]\nDiscussed KIS API setup."}, + {Role: "assistant", Content: "I understand the context from our previous conversation. 
How can I help you?"}, + {Role: "user", Content: "turn1 user"}, + {Role: "assistant", Content: "turn1 asst"}, + {Role: "user", Content: "current question"}, + }, + } + blocks := extractACPContent(req, true) + if len(blocks) != 1 { + t.Fatalf("want 1 block, got %d", len(blocks)) + } + text := blocks[0].Text + + // system must be excluded + if strings.Contains(text, "You are Ender") { + t.Error("system prompt must not appear in new-session transcript") + } + // summary must be present + if !strings.Contains(text, "Previous conversation summary") { + t.Error("episodic summary must be included in new-session transcript") + } + // history must be present + if !strings.Contains(text, "turn1 user") || !strings.Contains(text, "turn1 asst") { + t.Error("conversation history must be included in new-session transcript") + } + // current message must be present + if !strings.Contains(text, "current question") { + t.Error("current user message must be included in new-session transcript") + } + // role markers + if !strings.Contains(text, "[User]") || !strings.Contains(text, "[Assistant]") { + t.Error("role markers [User]/[Assistant] must be present") + } +} + +// TestExtractACPContent_NewSession_FirstEver verifies isNew=true with no prior +// history (very first message) behaves correctly and still includes current message. 
+func TestExtractACPContent_NewSession_FirstEver(t *testing.T) { + req := ChatRequest{ + Messages: []Message{ + {Role: "system", Content: "You are Ender."}, + {Role: "user", Content: "first ever message"}, + }, + } + blocks := extractACPContent(req, true) + if len(blocks) != 1 { + t.Fatalf("want 1 block, got %d", len(blocks)) + } + if !strings.Contains(blocks[0].Text, "first ever message") { + t.Errorf("current message must be present, got: %q", blocks[0].Text) + } + if strings.Contains(blocks[0].Text, "You are Ender") { + t.Error("system prompt must not appear even on first-ever message") + } +} + +// TestExtractACPContent_NoUserMessage verifies that an empty request returns nil. +func TestExtractACPContent_NoUserMessage(t *testing.T) { + req := ChatRequest{ + Messages: []Message{ + {Role: "system", Content: "You are Ender."}, + }, + } + if got := extractACPContent(req, false); got != nil { + t.Errorf("want nil for missing user message, got %v", got) + } + if got := extractACPContent(req, true); got != nil { + t.Errorf("want nil for missing user message (isNew), got %v", got) + } +} + +// TestWriteGeminiMD_WritesFile verifies the file is written and true is returned. +func TestWriteGeminiMD_WritesFile(t *testing.T) { + dir := t.TempDir() + p := &ACPProvider{} + + changed := p.writeGeminiMD(dir, "system prompt content") + if !changed { + t.Fatal("want changed=true for new file") + } + data, err := os.ReadFile(filepath.Join(dir, "GEMINI.md")) + if err != nil { + t.Fatalf("GEMINI.md not created: %v", err) + } + if string(data) != "system prompt content" { + t.Errorf("unexpected content: %q", string(data)) + } +} + +// TestWriteGeminiMD_NoopIfUnchanged verifies no write and false return when unchanged. 
+func TestWriteGeminiMD_NoopIfUnchanged(t *testing.T) { + dir := t.TempDir() + p := &ACPProvider{} + + p.writeGeminiMD(dir, "same content") + path := filepath.Join(dir, "GEMINI.md") + info1, _ := os.Stat(path) + + changed := p.writeGeminiMD(dir, "same content") + if changed { + t.Fatal("want changed=false when content is identical") + } + info2, _ := os.Stat(path) + if info1.ModTime() != info2.ModTime() { + t.Error("file must not be rewritten when content is unchanged") + } +} + +// TestWriteGeminiMD_UpdatesOnChange verifies file is rewritten and true returned when content changes. +func TestWriteGeminiMD_UpdatesOnChange(t *testing.T) { + dir := t.TempDir() + p := &ACPProvider{} + + p.writeGeminiMD(dir, "old system prompt") + changed := p.writeGeminiMD(dir, "new system prompt") + if !changed { + t.Fatal("want changed=true when content differs") + } + data, _ := os.ReadFile(filepath.Join(dir, "GEMINI.md")) + if string(data) != "new system prompt" { + t.Errorf("expected updated content, got: %q", string(data)) + } +}