diff --git a/cmd/gateway.go b/cmd/gateway.go index 0ebb2a899c..26702feef5 100644 --- a/cmd/gateway.go +++ b/cmd/gateway.go @@ -462,7 +462,7 @@ func runGateway() { instanceLoader.RegisterFactory(channels.TypeFeishu, feishu.FactoryWithPendingStoreAndAudio(pgStores.PendingMessages, audioMgr)) instanceLoader.RegisterFactory(channels.TypeZaloOA, zalo.Factory) instanceLoader.RegisterFactory(channels.TypeZaloPersonal, zalopersonal.FactoryWithPendingStore(pgStores.PendingMessages)) - instanceLoader.RegisterFactory(channels.TypeWhatsApp, whatsapp.FactoryWithDBAudio(pgStores.DB, pgStores.PendingMessages, "pgx", audioMgr, pgStores.BuiltinTools)) + instanceLoader.RegisterFactory(channels.TypeWhatsApp, whatsapp.FactoryWithDBAudio(pgStores.DB, pgStores.PendingMessages, "pgx", audioMgr, pgStores.BuiltinTools, pgStores.ChannelInstances)) instanceLoader.RegisterFactory(channels.TypeSlack, slackchannel.FactoryWithPendingStore(pgStores.PendingMessages)) instanceLoader.RegisterFactory(channels.TypeFacebook, facebook.Factory) instanceLoader.RegisterFactory(channels.TypePancake, pancake.Factory) diff --git a/cmd/gateway_channels_setup.go b/cmd/gateway_channels_setup.go index df2840a3f3..b89699262a 100644 --- a/cmd/gateway_channels_setup.go +++ b/cmd/gateway_channels_setup.go @@ -74,7 +74,11 @@ func registerConfigChannels(cfg *config.Config, channelMgr *channels.Manager, ms if strings.Contains(fmt.Sprintf("%T", pgStores.DB.Driver()), "sqlite") { waDialect = "sqlite3" } - wa, err := whatsapp.New(cfg.Channels.WhatsApp, msgBus, pgStores.Pairing, pgStores.DB, pgStores.PendingMessages, waDialect, audioMgr, pgStores.BuiltinTools) + // Config-only WhatsApp (single instance, no DB-backed channel_instances row); + // no instance store, no configJID — falls back to GetFirstDevice via NewDevice + // adoption when the legacy single-device store already exists. 
+ wa, err := whatsapp.New(cfg.Channels.WhatsApp, msgBus, pgStores.Pairing, pgStores.DB, + pgStores.PendingMessages, waDialect, audioMgr, pgStores.BuiltinTools, nil, "") if err != nil { channelMgr.RecordFailure(channels.TypeWhatsApp, "", err) slog.Error("failed to initialize whatsapp channel", "error", err) diff --git a/cmd/gateway_cron.go b/cmd/gateway_cron.go index 8bb9759b94..da83da5dca 100644 --- a/cmd/gateway_cron.go +++ b/cmd/gateway_cron.go @@ -88,8 +88,11 @@ func makeCronJobHandler(sched *scheduler.Scheduler, msgBus *bus.MessageBus, cfg // Reset session before each cron run to prevent tool errors from previous // runs from polluting the context and blocking future executions (#294). // Save() persists the empty session to DB so stale data won't reload after restart. - // Stateless jobs skip this — they intentionally carry no session history. - if !job.Stateless { + // Always reset cron sessions to prevent message accumulation across runs. + // Stateless jobs especially need this — the agent loop persists messages + // to the session regardless of the stateless flag, so without a reset + // the session grows indefinitely. + { sessionMgr.Reset(cronCtx, sessionKey) sessionMgr.Save(cronCtx, sessionKey) } diff --git a/internal/channels/instance_loader.go b/internal/channels/instance_loader.go index df6d677f30..5ac521266a 100644 --- a/internal/channels/instance_loader.go +++ b/internal/channels/instance_loader.go @@ -270,6 +270,12 @@ func (l *InstanceLoader) loadInstance(ctx context.Context, inst store.ChannelIns if base, ok := ch.(interface{ SetTenantID(uuid.UUID) }); ok { base.SetTenantID(inst.TenantID) } + // Propagate instance_id so channels that maintain per-instance external state + // (e.g. WhatsApp's whatsmeow_device row scoped to this channel) can persist it + // back to channel_instances.config. 
+ if base, ok := ch.(interface{ SetInstanceID(uuid.UUID) }); ok { + base.SetInstanceID(inst.ID) + } // Propagate tenant_id to pending history for compaction/sweep DB operations. // Factory creates PendingHistory before SetTenantID is called, so tenantID is uuid.Nil at construction. if ph, ok := ch.(interface{ SetPendingHistoryTenantID(uuid.UUID) }); ok { diff --git a/internal/channels/whatsapp/auth.go b/internal/channels/whatsapp/auth.go index 76deb6f3eb..ba15a81134 100644 --- a/internal/channels/whatsapp/auth.go +++ b/internal/channels/whatsapp/auth.go @@ -23,7 +23,7 @@ func (c *Channel) StartQRFlow(ctx context.Context) (<-chan whatsmeow.QRChannelIt if c.ctx == nil { c.ctx, c.cancel = context.WithCancel(context.Background()) } - deviceStore, err := c.container.GetFirstDevice(ctx) + deviceStore, err := c.resolveDevice(ctx) if err != nil { c.mu.Unlock() return nil, fmt.Errorf("whatsapp get device: %w", err) @@ -90,11 +90,12 @@ func (c *Channel) Reauth() error { } c.ctx, c.cancel = context.WithCancel(parent) - // Re-create client with fresh device store. - deviceStore, err := c.container.GetFirstDevice(context.Background()) - if err != nil { - return fmt.Errorf("whatsapp: get fresh device: %w", err) - } + // Re-create client with a fresh device. Reauth always forces a new pairing, + // so we bypass resolveDevice (which would try to adopt an existing device). + // configJID is also cleared so the next persistJID on PairSuccess writes the + // new JID into channel_instances.config without short-circuiting on equality. 
+ c.configJID = "" + deviceStore := c.container.NewDevice() c.client = whatsmeow.NewClient(deviceStore, nil) c.client.AddEventHandler(c.handleEvent) diff --git a/internal/channels/whatsapp/factory.go b/internal/channels/whatsapp/factory.go index 78316aa4ac..e2c60d452e 100644 --- a/internal/channels/whatsapp/factory.go +++ b/internal/channels/whatsapp/factory.go @@ -20,18 +20,25 @@ type whatsappInstanceConfig struct { HistoryLimit int `json:"history_limit,omitempty"` AllowFrom []string `json:"allow_from,omitempty"` BlockReply *bool `json:"block_reply,omitempty"` + // JID is the whatsmeow device JID this instance was paired with on a prior boot. + // Set automatically on PairSuccess and on adoption of an existing single-device + // store. Empty for fresh instances; the channel will NewDevice + go through QR. + JID string `json:"jid,omitempty"` } // FactoryWithDB returns a ChannelFactory with DB access for whatsmeow auth state. // dialect must be "pgx" (PostgreSQL) or "sqlite3" (SQLite/desktop). func FactoryWithDB(db *sql.DB, pendingStore store.PendingMessageStore, dialect string) channels.ChannelFactory { - return FactoryWithDBAudio(db, pendingStore, dialect, nil, nil) + return FactoryWithDBAudio(db, pendingStore, dialect, nil, nil, nil) } // FactoryWithDBAudio returns a ChannelFactory with DB access, STT support, and builtin-tools store -// for reading stt.whatsapp_enabled opt-in setting per message. +// for reading stt.whatsapp_enabled opt-in setting per message. instanceStore is optional but +// required for multi-instance device scoping (passed from cmd/gateway.go); nil falls back to +// legacy single-instance GetFirstDevice behavior. 
func FactoryWithDBAudio(db *sql.DB, pendingStore store.PendingMessageStore, dialect string, - audioMgr *audio.Manager, builtinToolStore store.BuiltinToolStore) channels.ChannelFactory { + audioMgr *audio.Manager, builtinToolStore store.BuiltinToolStore, + instanceStore store.ChannelInstanceStore) channels.ChannelFactory { return func(name string, creds json.RawMessage, cfg json.RawMessage, msgBus *bus.MessageBus, pairingSvc store.PairingStore) (channels.Channel, error) { @@ -72,7 +79,8 @@ func FactoryWithDBAudio(db *sql.DB, pendingStore store.PendingMessageStore, dial waCfg.GroupPolicy = "pairing" } - ch, err := New(waCfg, msgBus, pairingSvc, db, pendingStore, dialect, audioMgr, builtinToolStore) + ch, err := New(waCfg, msgBus, pairingSvc, db, pendingStore, dialect, audioMgr, builtinToolStore, + instanceStore, ic.JID) if err != nil { return nil, err } diff --git a/internal/channels/whatsapp/whatsapp.go b/internal/channels/whatsapp/whatsapp.go index cbd8afd9ed..7309fde279 100644 --- a/internal/channels/whatsapp/whatsapp.go +++ b/internal/channels/whatsapp/whatsapp.go @@ -3,11 +3,13 @@ package whatsapp import ( "context" "database/sql" + "encoding/json" "fmt" "log/slog" "sync" "time" + "github.com/google/uuid" "go.mau.fi/whatsmeow" wastore "go.mau.fi/whatsmeow/store" "go.mau.fi/whatsmeow/store/sqlstore" @@ -58,6 +60,16 @@ type Channel struct { // reauthMu serializes Reauth() and StartQRFlow() to prevent race when user clicks reauth rapidly. reauthMu sync.Mutex // pairingService, pairingDebounce, approvedGroups, groupHistory are inherited from channels.BaseChannel. + + // instanceID + instanceStore scope this channel to a specific channel_instances row, + // so multiple WhatsApp channels in one deploy each bind to their own whatsmeow_device row. + // Without this, every channel reused the first device returned by GetFirstDevice and ended + // up logged in as the same WhatsApp account regardless of name. 
+ instanceID uuid.UUID + instanceStore store.ChannelInstanceStore + // configJID is the device JID this channel adopted on a previous run, mirrored from the + // instance's config jsonb (key "jid"). Empty when the channel has never paired. + configJID string } // GetLastQRB64 returns the most recent QR PNG (base64). @@ -85,10 +97,15 @@ func (c *Channel) cacheQR(pngB64 string) { // dialect must be "pgx" (PostgreSQL) or "sqlite3" (SQLite/desktop). // audioMgr is optional (nil = STT disabled). // builtinToolStore is optional (nil = STT permanently opt-out regardless of admin toggle). +// instanceStore is optional but required for multi-instance device scoping; without it, +// the channel falls back to GetFirstDevice (legacy single-instance behavior). +// configJID is the JID adopted on a prior run (from instance config "jid"); empty for +// fresh instances that should NewDevice + QR. func New(cfg config.WhatsAppConfig, msgBus *bus.MessageBus, pairingSvc store.PairingStore, db *sql.DB, pendingStore store.PendingMessageStore, dialect string, audioMgr *audio.Manager, - builtinToolStore store.BuiltinToolStore) (*Channel, error) { + builtinToolStore store.BuiltinToolStore, + instanceStore store.ChannelInstanceStore, configJID string) (*Channel, error) { base := channels.NewBaseChannel(channels.TypeWhatsApp, msgBus, cfg.AllowFrom) base.ValidatePolicy(cfg.DMPolicy, cfg.GroupPolicy) @@ -104,12 +121,169 @@ func New(cfg config.WhatsAppConfig, msgBus *bus.MessageBus, container: container, audioMgr: audioMgr, builtinToolStore: builtinToolStore, + instanceStore: instanceStore, + configJID: configJID, } ch.SetPairingService(pairingSvc) ch.SetGroupHistory(channels.MakeHistory("whatsapp", pendingStore, base.TenantID())) return ch, nil } +// SetInstanceID associates this channel with its channel_instances row. +// Called by InstanceLoader after construction so we can persist the paired JID +// back to the row's config jsonb on PairSuccess. 
+func (c *Channel) SetInstanceID(id uuid.UUID) { c.instanceID = id }
+
+// resolveDevice picks the *store.Device for this channel, in order: (1) the device
+// stored under configJID; (2) with no instanceStore (config-only / FactoryWithDB),
+// the legacy GetFirstDevice, so a single-instance deploy keeps its paired session
+// across restarts; (3) the sole legacy device, adopted via adoptOrphanDevice;
+// (4) NewDevice(), a fresh in-memory device that must go through QR pairing.
+func (c *Channel) resolveDevice(ctx context.Context) (*wastore.Device, error) {
+	if c.configJID != "" {
+		jid, err := types.ParseJID(c.configJID)
+		if err == nil {
+			dev, err := c.container.GetDevice(ctx, jid)
+			if err != nil {
+				return nil, fmt.Errorf("whatsapp get device by jid %s: %w", jid, err)
+			}
+			if dev != nil {
+				return dev, nil
+			}
+			slog.Warn("whatsapp: stored JID not found in device store, falling back to fresh pairing",
+				"channel", c.Name(), "jid", c.configJID)
+		} else {
+			slog.Warn("whatsapp: stored JID is malformed, falling back to fresh pairing",
+				"channel", c.Name(), "jid", c.configJID, "error", err)
+		}
+	}
+	if c.instanceStore == nil {
+		return c.container.GetFirstDevice(ctx)
+	}
+	if dev, ok := c.adoptOrphanDevice(ctx); ok {
+		slog.Info("whatsapp: adopted existing device for instance",
+			"channel", c.Name(), "jid", dev.ID)
+		// Persist the adopted JID so subsequent boots take the configJID path and never
+		// re-adopt a device already claimed by a channel that happened to start later.
+		if dev.ID != nil {
+			c.persistJID(ctx, *dev.ID)
+		}
+		return dev, nil
+	}
+	return c.container.NewDevice(), nil
+}
+
+// adoptOrphanDevice handles the upgrade case where a deploy with a single
+// pre-existing whatsmeow_device row gains a second WhatsApp channel_instance.
+// To avoid stealing the legacy device from the wrong instance, we only adopt +// when ALL of the following hold: +// - exactly one whatsmeow_device row exists in the store (so there is no +// ambiguity about which device is "the legacy one"), AND +// - exactly one WhatsApp channel_instance exists in the database (so the +// legacy device unambiguously belongs to that instance), AND +// - this channel IS that single instance. +// +// In every other configuration (multi-instance deploys, fresh installs, etc.) +// adoption is skipped and the channel goes through QR pairing. +func (c *Channel) adoptOrphanDevice(ctx context.Context) (*wastore.Device, bool) { + if c.instanceStore == nil || c.instanceID == uuid.Nil { + return nil, false + } + devs, err := c.container.GetAllDevices(ctx) + if err != nil || len(devs) != 1 { + return nil, false + } + dev := devs[0] + if dev == nil || dev.ID == nil { + return nil, false + } + listCtx := store.WithCrossTenant(ctx) + instances, err := c.instanceStore.ListAllInstances(listCtx) + if err != nil { + slog.Warn("whatsapp: list instances for adoption failed", "error", err) + return nil, false + } + var ( + whatsappCount int + soleID uuid.UUID + soleJID string + ) + for _, inst := range instances { + if inst.ChannelType != channels.TypeWhatsApp { + continue + } + whatsappCount++ + if whatsappCount > 1 { + return nil, false + } + soleID = inst.ID + var ic struct { + JID string `json:"jid"` + } + if len(inst.Config) > 0 { + _ = json.Unmarshal(inst.Config, &ic) + } + soleJID = ic.JID + } + if whatsappCount != 1 || soleID != c.instanceID { + return nil, false + } + if soleJID != "" && soleJID != dev.ID.String() { + // Sole instance already claims a different JID — refuse to adopt. + return nil, false + } + return dev, true +} + +// persistJID writes the device JID back to channel_instances.config so the next +// channel start binds to the same device without going through QR. 
Best-effort:
+// failures are logged but don't fail the boot — the channel is already connected.
+func (c *Channel) persistJID(ctx context.Context, jid types.JID) {
+	if c.instanceStore == nil || c.instanceID == uuid.Nil {
+		return
+	}
+	jidStr := jid.String()
+	if jidStr == c.configJID {
+		return
+	}
+	var scopeCtx context.Context
+	if tenantID := c.TenantID(); tenantID != uuid.Nil {
+		scopeCtx = store.WithTenantID(ctx, tenantID)
+	} else {
+		scopeCtx = store.WithCrossTenant(ctx)
+	}
+	inst, err := c.instanceStore.Get(scopeCtx, c.instanceID)
+	if err != nil {
+		slog.Warn("whatsapp: persist JID — instance lookup failed",
+			"channel", c.Name(), "instance_id", c.instanceID, "error", err)
+		return
+	}
+	cfgMap := map[string]any{}
+	if len(inst.Config) > 0 {
+		if err := json.Unmarshal(inst.Config, &cfgMap); err != nil {
+			// Abort: writing a fresh map back would erase every other config key.
+			slog.Warn("whatsapp: persist JID — config unmarshal failed", "channel", c.Name(), "error", err)
+			return
+		}
+	}
+	if cfgMap == nil { // a JSON null config unmarshals to a nil map; assigning into it would panic
+		cfgMap = map[string]any{}
+	}
+	cfgMap["jid"] = jidStr
+	cfgBytes, err := json.Marshal(cfgMap)
+	if err != nil {
+		slog.Warn("whatsapp: persist JID — config marshal failed", "error", err)
+		return
+	}
+	if err := c.instanceStore.Update(scopeCtx, c.instanceID, map[string]any{"config": cfgBytes}); err != nil {
+		slog.Warn("whatsapp: persist JID — update failed",
+			"channel", c.Name(), "instance_id", c.instanceID, "error", err)
+		return
+	}
+	c.configJID = jidStr
+	slog.Info("whatsapp: persisted device JID to channel instance", "channel", c.Name(), "jid", jidStr)
+}
+
 // Start initializes the whatsmeow client and connects to WhatsApp.
func (c *Channel) Start(ctx context.Context) error { slog.Info("starting whatsapp channel (whatsmeow)") @@ -118,7 +292,7 @@ func (c *Channel) Start(ctx context.Context) error { c.parentCtx = ctx c.ctx, c.cancel = context.WithCancel(ctx) - deviceStore, err := c.container.GetFirstDevice(ctx) + deviceStore, err := c.resolveDevice(ctx) if err != nil { return fmt.Errorf("whatsapp get device: %w", err) } @@ -183,7 +357,15 @@ func (c *Channel) handleEvent(evt any) { case *events.LoggedOut: c.handleLoggedOut(v) case *events.PairSuccess: - slog.Info("whatsapp: pair success", "channel", c.Name()) + slog.Info("whatsapp: pair success", "channel", c.Name(), "jid", v.ID.String()) + // Bind this freshly-paired device to our channel_instances row so the next + // boot reuses the same device instead of going back through QR (or worse, + // adopting a sibling channel's device). + if c.parentCtx != nil { + c.persistJID(c.parentCtx, v.ID) + } else { + c.persistJID(context.Background(), v.ID) + } } } diff --git a/internal/config/config.go b/internal/config/config.go index 80ca1722be..fc07571032 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -248,6 +248,7 @@ type SandboxConfig struct { // Enhanced security User string `json:"user,omitempty"` // container user (e.g. "1000:1000", "nobody") TmpfsSizeMB int `json:"tmpfs_size_mb,omitempty"` // default tmpfs size in MB (0 = Docker default) + AllowTmpExec bool `json:"allow_tmp_exec,omitempty"` // drop `noexec` from tmpfs mounts (still keeps nosuid+nodev). Required by some bundled-binary CLIs that extract+exec from /tmp at runtime. 
MaxOutputBytes int `json:"max_output_bytes,omitempty"` // limit exec output capture (default 1MB) // Pruning (matching TS SandboxPruneSettings) @@ -319,6 +320,7 @@ func (sc *SandboxConfig) ToSandboxConfig() sandbox.Config { if sc.TmpfsSizeMB > 0 { cfg.TmpfsSizeMB = sc.TmpfsSizeMB } + cfg.AllowTmpExec = sc.AllowTmpExec if sc.MaxOutputBytes > 0 { cfg.MaxOutputBytes = sc.MaxOutputBytes } diff --git a/internal/config/config_load.go b/internal/config/config_load.go index a844e1aeaf..40b0b6f093 100644 --- a/internal/config/config_load.go +++ b/internal/config/config_load.go @@ -269,6 +269,10 @@ func (c *Config) applyEnvOverrides() { ensureSandbox() c.Agents.Defaults.Sandbox.NetworkEnabled = v == "true" || v == "1" } + if v := os.Getenv("GOCLAW_SANDBOX_TMP_EXEC"); v != "" { + ensureSandbox() + c.Agents.Defaults.Sandbox.AllowTmpExec = v == "true" || v == "1" + } // Browser (for Docker-compose browser sidecar overlay) envStr("GOCLAW_BROWSER_REMOTE_URL", &c.Tools.Browser.RemoteURL) diff --git a/internal/hooks/dispatcher.go b/internal/hooks/dispatcher.go index c608d0cea1..c608cea5e0 100644 --- a/internal/hooks/dispatcher.go +++ b/internal/hooks/dispatcher.go @@ -241,7 +241,10 @@ func (d *stdDispatcher) runSync(ctx context.Context, ev Event, chain []HookConfi switch dec { case DecisionBlock: d.cb.record(ctx, cfg.ID, d.now(), d.store) - return FireResult{Decision: DecisionBlock}, nil + // Forward the script reason so callers can surface a self- + // documenting message to the agent. Reason stays empty for + // non-script handlers and for scripts that did not set one. 
+ return FireResult{Decision: DecisionBlock, Reason: scriptRes.Reason}, nil case DecisionTimeout: d.cb.record(ctx, cfg.ID, d.now(), d.store) if cfg.OnTimeout == DecisionBlock { diff --git a/internal/hooks/dispatcher_test.go b/internal/hooks/dispatcher_test.go index 8d2eff42a5..2aee35bcc4 100644 --- a/internal/hooks/dispatcher_test.go +++ b/internal/hooks/dispatcher_test.go @@ -86,11 +86,16 @@ func (f *fakeStore) snapshotUpdates() []fakeUpdate { } // fakeHandler returns a scripted decision + optional sleep (for timeout tests). +// +// reason mirrors the script-handler contract: when set, it is written into the +// per-execution ScriptResult (carried via ctx) so the dispatcher can forward +// it on the FireResult. Used by the block-reason propagation test. type fakeHandler struct { decision hooks.Decision sleep time.Duration err error calls int32 + reason string } func (h *fakeHandler) Execute(ctx context.Context, _ hooks.HookConfig, _ hooks.Event) (hooks.Decision, error) { @@ -103,6 +108,11 @@ func (h *fakeHandler) Execute(ctx context.Context, _ hooks.HookConfig, _ hooks.E return hooks.DecisionTimeout, ctx.Err() } } + if h.reason != "" { + if r := hooks.ScriptResultFrom(ctx); r != nil { + r.Reason = h.reason + } + } return h.decision, h.err } @@ -627,3 +637,62 @@ func allowlistID(name string) uuid.UUID { // Stable deterministic UUID per label so cfg.ID matches the lookup. return uuid.NewSHA1(uuid.NameSpaceDNS, []byte("test.allowlist/"+name)) } + +// ── Block-reason propagation ──────────────────────────────────────────────── + +// TestDispatcher_BlockReason_PropagatesToFireResult verifies that when a +// script-handler hook blocks with a non-empty `reason`, the dispatcher copies +// the reason onto FireResult.Reason so callers (pipeline tool stage / context +// stage) can surface a self-documenting block message. 
+func TestDispatcher_BlockReason_PropagatesToFireResult(t *testing.T) { + cfg := newBaseHook(hooks.HandlerScript, hooks.EventPreToolUse) + fs := &fakeStore{hooks: []hooks.HookConfig{cfg}} + blocker := &fakeHandler{decision: hooks.DecisionBlock, reason: "use rtk prefix"} + + d := hooks.NewStdDispatcher(hooks.StdDispatcherOpts{ + Store: fs, + Audit: hooks.NewAuditWriter(fs, ""), + Handlers: map[hooks.HandlerType]hooks.Handler{hooks.HandlerScript: blocker}, + }) + r, err := d.Fire(context.Background(), hooks.Event{ + EventID: "e-reason", + HookEvent: hooks.EventPreToolUse, + }) + if err != nil { + t.Fatalf("Fire: %v", err) + } + if r.Decision != hooks.DecisionBlock { + t.Fatalf("decision=%q, want block", r.Decision) + } + if r.Reason != "use rtk prefix" { + t.Errorf("reason=%q, want %q", r.Reason, "use rtk prefix") + } +} + +// TestDispatcher_Block_NoReason_LeavesReasonEmpty verifies the no-reason path +// stays backward-compatible: blocking handlers that do not populate +// ScriptResult.Reason yield FireResult.Reason == "". 
+func TestDispatcher_Block_NoReason_LeavesReasonEmpty(t *testing.T) { + cfg := newBaseHook(hooks.HandlerHTTP, hooks.EventPreToolUse) + fs := &fakeStore{hooks: []hooks.HookConfig{cfg}} + blocker := &fakeHandler{decision: hooks.DecisionBlock} + + d := hooks.NewStdDispatcher(hooks.StdDispatcherOpts{ + Store: fs, + Audit: hooks.NewAuditWriter(fs, ""), + Handlers: map[hooks.HandlerType]hooks.Handler{hooks.HandlerHTTP: blocker}, + }) + r, err := d.Fire(context.Background(), hooks.Event{ + EventID: "e-noreason", + HookEvent: hooks.EventPreToolUse, + }) + if err != nil { + t.Fatalf("Fire: %v", err) + } + if r.Decision != hooks.DecisionBlock { + t.Fatalf("decision=%q, want block", r.Decision) + } + if r.Reason != "" { + t.Errorf("reason=%q, want empty for handler that did not set one", r.Reason) + } +} diff --git a/internal/hooks/types.go b/internal/hooks/types.go index 0c53c3ad3b..8e1d91133e 100644 --- a/internal/hooks/types.go +++ b/internal/hooks/types.go @@ -147,10 +147,17 @@ func (d Decision) IsBlock() bool { // For non-builtin scripts returning updatedInput the dispatcher strips the // mutation + logs a WARN; Updated* stay nil (defense-in-depth against a // tenant-authored script escalating its capability tier). +// +// Reason is populated only on DecisionBlock paths and only when a script- +// handler hook returned a non-empty `reason` field. Callers (e.g. the +// pipeline tool stage) surface it to the agent as the synthetic tool message +// so the hook can self-document why the operation was blocked. Empty when no +// script reason is available — callers fall back to a generic message. 
type FireResult struct { Decision Decision UpdatedToolInput map[string]any UpdatedRawInput *string + Reason string } // ─── Config & execution structs ────────────────────────────────────────────── diff --git a/internal/http/secure_cli.go b/internal/http/secure_cli.go index 0eac1b7f8c..ca03f81d21 100644 --- a/internal/http/secure_cli.go +++ b/internal/http/secure_cli.go @@ -161,6 +161,7 @@ type secureCLICreateRequest struct { TimeoutSeconds int `json:"timeout_seconds,omitempty"` Tips string `json:"tips,omitempty"` IsGlobal *bool `json:"is_global,omitempty"` + AllowChainExec *bool `json:"allow_chain_exec,omitempty"` Enabled bool `json:"enabled"` } @@ -225,6 +226,7 @@ func (h *SecureCLIHandler) handleCreate(w http.ResponseWriter, r *http.Request) TimeoutSeconds: req.TimeoutSeconds, Tips: req.Tips, IsGlobal: req.IsGlobal == nil || *req.IsGlobal, // default true + AllowChainExec: req.AllowChainExec != nil && *req.AllowChainExec, Enabled: req.Enabled, CreatedBy: store.UserIDFromContext(r.Context()), } @@ -282,7 +284,7 @@ func (h *SecureCLIHandler) handleUpdate(w http.ResponseWriter, r *http.Request) allowed := map[string]bool{ "binary_name": true, "binary_path": true, "description": true, "env": true, "deny_args": true, "deny_verbose": true, - "timeout_seconds": true, "tips": true, "is_global": true, "enabled": true, + "timeout_seconds": true, "tips": true, "is_global": true, "allow_chain_exec": true, "enabled": true, } for k := range updates { if !allowed[k] { diff --git a/internal/pipeline/context_stage.go b/internal/pipeline/context_stage.go index 5f4d7050cd..775a0e018b 100644 --- a/internal/pipeline/context_stage.go +++ b/internal/pipeline/context_stage.go @@ -59,6 +59,9 @@ func (s *ContextStage) Execute(ctx context.Context, state *RunState) error { RawInput: state.Input.Message, HookEvent: hooks.EventUserPromptSubmit, }); r.Decision == hooks.DecisionBlock { + if r.Reason != "" { + return fmt.Errorf("hook blocked user_prompt_submit: %s", r.Reason) + } return 
fmt.Errorf("hook blocked user_prompt_submit") } else if r.UpdatedRawInput != nil { state.Input.Message = *r.UpdatedRawInput diff --git a/internal/pipeline/tool_stage.go b/internal/pipeline/tool_stage.go index 8779822b46..719522ba42 100644 --- a/internal/pipeline/tool_stage.go +++ b/internal/pipeline/tool_stage.go @@ -61,9 +61,16 @@ func (s *ToolStage) Execute(ctx context.Context, state *RunState) error { HookEvent: hooks.EventPreToolUse, }); r.Decision == hooks.DecisionBlock { // Inject synthetic blocked tool message and skip actual execution. + // Surface the script reason when present so the hook can self- + // document why the call was blocked (e.g. retry hints). Fall back + // to the generic line for non-script handlers and silent scripts. + content := "Hook blocked: pre_tool_use" + if r.Reason != "" { + content = "Hook blocked: pre_tool_use — " + r.Reason + } state.Messages.AppendPending(providers.Message{ Role: "tool", - Content: "Hook blocked: pre_tool_use", + Content: content, ToolCallID: tc.ID, }) state.Tool.TotalToolCalls++ diff --git a/internal/providers/anthropic_message_cache_test.go b/internal/providers/anthropic_message_cache_test.go new file mode 100644 index 0000000000..a47f76c024 --- /dev/null +++ b/internal/providers/anthropic_message_cache_test.go @@ -0,0 +1,156 @@ +package providers + +import ( + "encoding/json" + "testing" +) + +// TestApplyCacheControlEmptyMessages verifies the helper is a no-op on an +// empty message slice (e.g. tool-result-only first turn before any user msg). +func TestApplyCacheControlEmptyMessages(t *testing.T) { + var messages []map[string]any + applyCacheControlToLastMessage(messages) + if len(messages) != 0 { + t.Fatalf("expected 0 messages, got %d", len(messages)) + } +} + +// TestApplyCacheControlStringContent verifies that a plain string content +// is converted to a single text block carrying the cache_control marker. +// Anthropic accepts both shapes; converting lets us attach the breakpoint. 
+func TestApplyCacheControlStringContent(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": "hello"}, + } + applyCacheControlToLastMessage(messages) + + content, ok := messages[0]["content"].([]map[string]any) + if !ok { + t.Fatalf("expected []map[string]any, got %T", messages[0]["content"]) + } + if len(content) != 1 { + t.Fatalf("expected 1 block, got %d", len(content)) + } + if content[0]["type"] != "text" || content[0]["text"] != "hello" { + t.Errorf("block content mismatch: %+v", content[0]) + } + if content[0]["cache_control"] == nil { + t.Error("missing cache_control on converted text block") + } +} + +// TestApplyCacheControlBlockArrayContent verifies that an existing block +// array (multi-modal user, tool_result, assistant text+tool_use) gets +// cache_control on the last block only — earlier blocks stay untouched. +func TestApplyCacheControlBlockArrayContent(t *testing.T) { + messages := []map[string]any{ + { + "role": "user", + "content": []map[string]any{ + {"type": "image", "source": map[string]any{"type": "base64"}}, + {"type": "text", "text": "describe this"}, + }, + }, + } + applyCacheControlToLastMessage(messages) + + content := messages[0]["content"].([]map[string]any) + if content[0]["cache_control"] != nil { + t.Error("first block should not have cache_control") + } + if content[1]["cache_control"] == nil { + t.Error("last block missing cache_control") + } +} + +// TestApplyCacheControlToolResultContent verifies tool_result messages +// (sent as user role with a tool_result block) get the breakpoint. Tool +// results are deterministic and stable across replays — safe to cache. 
+func TestApplyCacheControlToolResultContent(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": "first turn"}, + { + "role": "user", + "content": []map[string]any{ + { + "type": "tool_result", + "tool_use_id": "tool_123", + "content": "result data", + }, + }, + }, + } + applyCacheControlToLastMessage(messages) + + first := messages[0]["content"] + if _, isString := first.(string); !isString { + t.Errorf("first message should remain a plain string, got %T", first) + } + last := messages[1]["content"].([]map[string]any) + if last[0]["cache_control"] == nil { + t.Error("tool_result block missing cache_control") + } +} + +// TestApplyCacheControlRawAssistantBlocks verifies the json.RawMessage path +// used for assistant turns that preserve thinking signatures. The last +// raw block must round-trip through JSON with cache_control attached. +func TestApplyCacheControlRawAssistantBlocks(t *testing.T) { + thinking, _ := json.Marshal(map[string]any{ + "type": "thinking", + "thinking": "let me consider", + "signature": "abc123", + }) + text, _ := json.Marshal(map[string]any{ + "type": "text", + "text": "here is the answer", + }) + + messages := []map[string]any{ + { + "role": "assistant", + "content": []json.RawMessage{thinking, text}, + }, + } + applyCacheControlToLastMessage(messages) + + raw := messages[0]["content"].([]json.RawMessage) + + var firstBlock map[string]any + if err := json.Unmarshal(raw[0], &firstBlock); err != nil { + t.Fatalf("first block unmarshal: %v", err) + } + if firstBlock["cache_control"] != nil { + t.Error("first raw block should not have cache_control") + } + + var lastBlock map[string]any + if err := json.Unmarshal(raw[1], &lastBlock); err != nil { + t.Fatalf("last block unmarshal: %v", err) + } + if lastBlock["cache_control"] == nil { + t.Error("last raw block missing cache_control") + } + // Preserved fields stay intact. 
+ if lastBlock["text"] != "here is the answer" { + t.Errorf("last block text mutated: %v", lastBlock["text"]) + } +} + +// TestApplyCacheControlRawAssistantInvalidJSON verifies graceful handling +// when the last raw block is not valid JSON: helper silently skips rather +// than corrupting the request body. +func TestApplyCacheControlRawAssistantInvalidJSON(t *testing.T) { + bad := json.RawMessage("not valid json") + messages := []map[string]any{ + { + "role": "assistant", + "content": []json.RawMessage{bad}, + }, + } + applyCacheControlToLastMessage(messages) + raw := messages[0]["content"].([]json.RawMessage) + if string(raw[0]) != "not valid json" { + t.Errorf("invalid block was mutated: %s", raw[0]) + } +} diff --git a/internal/providers/anthropic_request.go b/internal/providers/anthropic_request.go index 58dfb3a08b..315995bba6 100644 --- a/internal/providers/anthropic_request.go +++ b/internal/providers/anthropic_request.go @@ -32,6 +32,58 @@ func splitSystemPromptForCache(content string) []map[string]any { return blocks } +// applyCacheControlToLastMessage tags the last message with an ephemeral +// cache_control breakpoint so Anthropic prompt caching rolls forward across +// turns. On turn N+1, content up to this breakpoint becomes a cache hit +// (10% of input cost) and only the new tail writes to cache. +// +// Without this, the entire conversation history is sent uncached every turn, +// which dominates cost on long agent sessions (observed: 36% effective cache +// hit on a 187-message Slack thread, vs. ~80% expected for Claude agents). +// +// Anthropic allows up to 4 cache breakpoints per request. The system prompt +// and last tool definition use 2; this leaves 2 free, and we use one here. +// String content is converted to a single text block to attach the marker. 
// applyCacheControlToLastMessage tags the final content block of the last
// message with an ephemeral cache_control breakpoint so Anthropic prompt
// caching rolls forward across turns: on the next turn everything up to the
// breakpoint can be read from cache and only the new tail is written.
//
// Anthropic allows up to 4 cache breakpoints per request; this helper adds at
// most one. Handling depends on the dynamic type of the message's "content":
//
//   - string: replaced by a single text block that carries the marker.
//   - []map[string]any: the last block is tagged in place.
//   - []json.RawMessage: the last block is decoded, tagged and re-encoded; the
//     result is stored in a fresh slice so the caller's raw blocks (which may
//     be reused on later turns) are never rewritten.
//
// The message is left untouched when there is nothing safe to tag: no
// messages, an empty block list, a nil block, a raw block that is not a JSON
// object (invalid JSON or a literal null), or a last block of type "thinking"
// or "redacted_thinking" — Anthropic's prompt-caching documentation states
// that thinking blocks cannot be marked with cache_control directly. Content
// of any other Go type is ignored.
func applyCacheControlToLastMessage(messages []map[string]any) {
	if len(messages) == 0 {
		return
	}
	last := messages[len(messages)-1]
	ephemeral := map[string]any{"type": "ephemeral"}

	// taggable reports whether a decoded block may carry cache_control.
	// A nil map must be rejected: writing to it would panic.
	taggable := func(block map[string]any) bool {
		if block == nil {
			return false
		}
		switch block["type"] {
		case "thinking", "redacted_thinking":
			return false
		}
		return true
	}

	switch c := last["content"].(type) {
	case string:
		// Plain string content (typical for text-only user messages) has no
		// block to attach the marker to, so wrap it in a one-element block
		// array. Anthropic accepts both shapes.
		last["content"] = []map[string]any{
			{"type": "text", "text": c, "cache_control": ephemeral},
		}
	case []map[string]any:
		// Block array (multi-modal user, assistant text+tool_use, tool_result).
		if len(c) == 0 {
			return
		}
		if block := c[len(c)-1]; taggable(block) {
			block["cache_control"] = ephemeral
		}
	case []json.RawMessage:
		// Assistant raw blocks preserve thinking signatures for tool-use
		// passback. json.Unmarshal accepts a literal `null` without error and
		// leaves the map nil, so the nil check in taggable is required here.
		if len(c) == 0 {
			return
		}
		var block map[string]any
		if err := json.Unmarshal(c[len(c)-1], &block); err != nil || !taggable(block) {
			return
		}
		block["cache_control"] = ephemeral
		marshaled, err := json.Marshal(block)
		if err != nil {
			return
		}
		// Copy-on-write: only this request sees the tagged block.
		updated := make([]json.RawMessage, len(c))
		copy(updated, c)
		updated[len(updated)-1] = marshaled
		last["content"] = updated
	}
}
+ applyCacheControlToLastMessage(messages) + body := map[string]any{ "model": model, "max_tokens": 4096, diff --git a/internal/sandbox/docker.go b/internal/sandbox/docker.go index 4347053d88..127ce8798e 100644 --- a/internal/sandbox/docker.go +++ b/internal/sandbox/docker.go @@ -48,14 +48,39 @@ func newDockerSandbox(ctx context.Context, name string, cfg Config, workspace st if cfg.ReadOnlyRoot { args = append(args, "--read-only") } + // Default tmpfs flags. `noexec` blocks running binaries extracted to + // /tmp at runtime (e.g. staticx-wrapped CLIs). Operators who need + // that capability can opt out via Config.AllowTmpExec; we then mount + // with an explicit `exec` flag because Docker silently re-adds + // `noexec` to tmpfs mounts whose flag list doesn't override it. + // `nosuid` + `nodev` stay either way. + baseTmpfsOpts := "noexec,nosuid,nodev" + if cfg.AllowTmpExec { + baseTmpfsOpts = "exec,nosuid,nodev" + } for _, t := range cfg.Tmpfs { if !strings.Contains(t, ":") { // Always add security flags; optionally add size limit - opts := "noexec,nosuid,nodev" + opts := baseTmpfsOpts if cfg.TmpfsSizeMB > 0 { opts = fmt.Sprintf("size=%dm,%s", cfg.TmpfsSizeMB, opts) } t = fmt.Sprintf("%s:%s", t, opts) + } else if cfg.AllowTmpExec { + // User-specified options + opt-out: strip any `noexec` + // the user passed and inject explicit `exec` so Docker + // doesn't re-add the default. Also enforce the security + // floor (nosuid, nodev). 
+ t = stripTmpfsOpt(t, "noexec") + if !strings.Contains(t, "exec") { + t += ",exec" + } + if !strings.Contains(t, "nosuid") { + t += ",nosuid" + } + if !strings.Contains(t, "nodev") { + t += ",nodev" + } } else if !strings.Contains(t, "noexec") { // User-specified options but missing noexec — append security flags t += ",noexec,nosuid,nodev" @@ -98,6 +123,9 @@ func newDockerSandbox(ctx context.Context, name string, cfg Config, workspace st hostPath := resolveHostWorkspacePath(ctx, workspace) args = append(args, "-v", fmt.Sprintf("%s:%s:%s", hostPath, containerWorkdir, mountOpt)) } + + // Mount data volume read-only so sandbox can access skills, config, etc. + args = append(args, "-v", "app_goclaw-data:/app/data:ro") args = append(args, "-w", containerWorkdir) // Environment variables @@ -432,6 +460,7 @@ func sanitizeKey(key string) string { "/", "-", " ", "-", ".", "-", + "@", "-", ).Replace(key) if len(safe) > 50 { @@ -468,3 +497,28 @@ func (lb *limitedBuffer) Write(p []byte) (int, error) { func (lb *limitedBuffer) String() string { return lb.buf.String() } + +// stripTmpfsOpt removes a single comma-separated option (e.g. "noexec") +// from a tmpfs spec like "/tmp:size=64m,noexec,nosuid,nodev". Leaves +// the path prefix untouched. Used by Config.AllowTmpExec to drop +// noexec from operator-specified tmpfs entries without disturbing +// other flags. 
+func stripTmpfsOpt(spec, opt string) string { + idx := strings.Index(spec, ":") + if idx < 0 { + return spec + } + path, opts := spec[:idx], spec[idx+1:] + parts := strings.Split(opts, ",") + kept := parts[:0] + for _, p := range parts { + if p == opt { + continue + } + kept = append(kept, p) + } + if len(kept) == 0 { + return path + } + return path + ":" + strings.Join(kept, ",") +} diff --git a/internal/sandbox/docker_test.go b/internal/sandbox/docker_test.go index c402e22b6c..d9bfa88b10 100644 --- a/internal/sandbox/docker_test.go +++ b/internal/sandbox/docker_test.go @@ -97,6 +97,7 @@ func TestSanitizeKey(t *testing.T) { {"simple", "simple"}, {"has/slash", "has-slash"}, {"has space", "has-space"}, + {"agent:chloe:whatsapp:551152861098:5@s.whatsapp.net", "agent-chloe-whatsapp-551152861098-5-s-whatsapp-net"}, {strings.Repeat("x", 100), strings.Repeat("x", 50)}, } for _, tc := range tests { diff --git a/internal/sandbox/sandbox.go b/internal/sandbox/sandbox.go index 523f5f51b5..fce10323ec 100644 --- a/internal/sandbox/sandbox.go +++ b/internal/sandbox/sandbox.go @@ -67,8 +67,9 @@ type Config struct { // Security hardening (matching TS buildSandboxCreateArgs) ReadOnlyRoot bool `json:"read_only_root"` CapDrop []string `json:"cap_drop,omitempty"` - Tmpfs []string `json:"tmpfs,omitempty"` // e.g. "/tmp", "/tmp:size=64m" - TmpfsSizeMB int `json:"tmpfs_size_mb,omitempty"` // default size for tmpfs mounts without explicit :size= (0 = Docker default) + Tmpfs []string `json:"tmpfs,omitempty"` // e.g. "/tmp", "/tmp:size=64m" + TmpfsSizeMB int `json:"tmpfs_size_mb,omitempty"` // default size for tmpfs mounts without explicit :size= (0 = Docker default) + AllowTmpExec bool `json:"allow_tmp_exec,omitempty"` // drop `noexec` from tmpfs mounts (still keeps nosuid+nodev). Required by some bundled-binary CLIs (e.g. staticx-wrapped gam) that extract+exec from /tmp at runtime. PidsLimit int `json:"pids_limit,omitempty"` User string `json:"user,omitempty"` // container user (e.g. 
"1000:1000", "nobody") MaxOutputBytes int `json:"max_output_bytes,omitempty"` // limit exec stdout+stderr capture (default 1MB, 0 = unlimited) diff --git a/internal/store/pg/secure_cli.go b/internal/store/pg/secure_cli.go index ec4a481cdb..79472b55e6 100644 --- a/internal/store/pg/secure_cli.go +++ b/internal/store/pg/secure_cli.go @@ -27,12 +27,12 @@ func NewPGSecureCLIStore(db *sql.DB, encryptionKey string) *PGSecureCLIStore { } const secureCLISelectCols = `id, binary_name, binary_path, description, encrypted_env, - deny_args, deny_verbose, timeout_seconds, tips, is_global, enabled, created_by, created_at, updated_at` + deny_args, deny_verbose, timeout_seconds, tips, is_global, allow_chain_exec, enabled, created_by, created_at, updated_at` // secureCLISelectColsAliased is prefixed with table alias "b." // Required for LookupByBinary which uses LEFT JOIN (ambiguous column names without prefix). const secureCLISelectColsAliased = `b.id, b.binary_name, b.binary_path, b.description, b.encrypted_env, - b.deny_args, b.deny_verbose, b.timeout_seconds, b.tips, b.is_global, b.enabled, b.created_by, b.created_at, b.updated_at` + b.deny_args, b.deny_verbose, b.timeout_seconds, b.tips, b.is_global, b.allow_chain_exec, b.enabled, b.created_by, b.created_at, b.updated_at` func (s *PGSecureCLIStore) Create(ctx context.Context, b *store.SecureCLIBinary) error { if err := store.ValidateUserID(b.CreatedBy); err != nil { @@ -69,13 +69,13 @@ func (s *PGSecureCLIStore) Create(ctx context.Context, b *store.SecureCLIBinary) _, err := s.db.ExecContext(ctx, `INSERT INTO secure_cli_binaries (id, binary_name, binary_path, description, encrypted_env, - deny_args, deny_verbose, timeout_seconds, tips, is_global, enabled, created_by, created_at, updated_at, tenant_id) - VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15)`, + deny_args, deny_verbose, timeout_seconds, tips, is_global, allow_chain_exec, enabled, created_by, created_at, updated_at, tenant_id) + VALUES 
($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16)`, b.ID, b.BinaryName, nilStr(derefStr(b.BinaryPath)), b.Description, envBytes, jsonOrEmptyArray(b.DenyArgs), jsonOrEmptyArray(b.DenyVerbose), b.TimeoutSeconds, b.Tips, - b.IsGlobal, b.Enabled, + b.IsGlobal, b.AllowChainExec, b.Enabled, b.CreatedBy, now, now, tenantID, ) return err @@ -105,7 +105,7 @@ func (s *PGSecureCLIStore) scanRow(row *sql.Row) (*store.SecureCLIBinary, error) err := row.Scan( &b.ID, &b.BinaryName, &binaryPath, &b.Description, &env, &denyArgs, &denyVerbose, - &b.TimeoutSeconds, &b.Tips, &b.IsGlobal, + &b.TimeoutSeconds, &b.Tips, &b.IsGlobal, &b.AllowChainExec, &b.Enabled, &b.CreatedBy, &b.CreatedAt, &b.UpdatedAt, ) if err != nil { @@ -147,7 +147,7 @@ func (s *PGSecureCLIStore) scanRows(rows *sql.Rows) ([]store.SecureCLIBinary, er if err := rows.Scan( &b.ID, &b.BinaryName, &binaryPath, &b.Description, &env, &denyArgs, &denyVerbose, - &b.TimeoutSeconds, &b.Tips, &b.IsGlobal, + &b.TimeoutSeconds, &b.Tips, &b.IsGlobal, &b.AllowChainExec, &b.Enabled, &b.CreatedBy, &b.CreatedAt, &b.UpdatedAt, ); err != nil { continue @@ -177,7 +177,7 @@ func (s *PGSecureCLIStore) scanRows(rows *sql.Rows) ([]store.SecureCLIBinary, er var secureCLIAllowedFields = map[string]bool{ "binary_name": true, "binary_path": true, "description": true, "encrypted_env": true, "deny_args": true, "deny_verbose": true, - "timeout_seconds": true, "tips": true, "is_global": true, "enabled": true, + "timeout_seconds": true, "tips": true, "is_global": true, "allow_chain_exec": true, "enabled": true, "updated_at": true, } @@ -344,7 +344,7 @@ func (s *PGSecureCLIStore) scanRowWithGrantAndUserEnv(row *sql.Row) (*store.Secu err := row.Scan( &b.ID, &b.BinaryName, &binaryPath, &b.Description, &env, &denyArgs, &denyVerbose, - &b.TimeoutSeconds, &b.Tips, &b.IsGlobal, + &b.TimeoutSeconds, &b.Tips, &b.IsGlobal, &b.AllowChainExec, &b.Enabled, &b.CreatedBy, &b.CreatedAt, &b.UpdatedAt, // Grant columns &grantDenyArgs, &grantDenyVerbose, 
&grantTimeout, &grantTips, &grantEnabled, &grantID, @@ -498,7 +498,7 @@ func (s *PGSecureCLIStore) ListForAgent(ctx context.Context, agentID uuid.UUID) if err := rows.Scan( &b.ID, &b.BinaryName, &binaryPath, &b.Description, &env, &denyArgs, &denyVerbose, - &b.TimeoutSeconds, &b.Tips, &b.IsGlobal, + &b.TimeoutSeconds, &b.Tips, &b.IsGlobal, &b.AllowChainExec, &b.Enabled, &b.CreatedBy, &b.CreatedAt, &b.UpdatedAt, &grantDenyArgs, &grantDenyVerbose, &grantTimeout, &grantTips, &grantID, ); err != nil { diff --git a/internal/store/secure_cli_store.go b/internal/store/secure_cli_store.go index aa846f2f62..10a909bf69 100644 --- a/internal/store/secure_cli_store.go +++ b/internal/store/secure_cli_store.go @@ -21,6 +21,7 @@ type SecureCLIBinary struct { TimeoutSeconds int `json:"timeout_seconds" db:"timeout_seconds"` Tips string `json:"tips" db:"tips"` // hint injected into TOOLS.md context IsGlobal bool `json:"is_global" db:"is_global"` + AllowChainExec bool `json:"allow_chain_exec" db:"allow_chain_exec"` Enabled bool `json:"enabled" db:"enabled"` CreatedBy string `json:"created_by" db:"created_by"` UserEnv []byte `json:"-" db:"-"` // per-user encrypted env (populated by LookupByBinary LEFT JOIN) diff --git a/internal/store/sqlitestore/secure-cli.go b/internal/store/sqlitestore/secure-cli.go index ac2ce1996e..d0895cb944 100644 --- a/internal/store/sqlitestore/secure-cli.go +++ b/internal/store/sqlitestore/secure-cli.go @@ -31,7 +31,7 @@ func NewSQLiteSecureCLIStore(db *sql.DB, encKey string) *SQLiteSecureCLIStore { } const secureCLISelectCols = `id, binary_name, binary_path, description, encrypted_env, - deny_args, deny_verbose, timeout_seconds, tips, is_global, enabled, created_by, created_at, updated_at` + deny_args, deny_verbose, timeout_seconds, tips, is_global, allow_chain_exec, enabled, created_by, created_at, updated_at` const secureCLISelectColsAliased = `b.id, b.binary_name, b.binary_path, b.description, b.encrypted_env, b.deny_args, b.deny_verbose, 
b.timeout_seconds, b.tips, b.is_global, b.enabled, b.created_by, b.created_at, b.updated_at` @@ -71,13 +71,13 @@ func (s *SQLiteSecureCLIStore) Create(ctx context.Context, b *store.SecureCLIBin _, err := s.db.ExecContext(ctx, `INSERT INTO secure_cli_binaries (id, binary_name, binary_path, description, encrypted_env, - deny_args, deny_verbose, timeout_seconds, tips, is_global, enabled, created_by, created_at, updated_at, tenant_id) - VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)`, + deny_args, deny_verbose, timeout_seconds, tips, is_global, allow_chain_exec, enabled, created_by, created_at, updated_at, tenant_id) + VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)`, b.ID, b.BinaryName, nilStr(derefStr(b.BinaryPath)), b.Description, envBytes, jsonOrEmptyArray(b.DenyArgs), jsonOrEmptyArray(b.DenyVerbose), b.TimeoutSeconds, b.Tips, - b.IsGlobal, b.Enabled, + b.IsGlobal, b.AllowChainExec, b.Enabled, b.CreatedBy, nowStr, nowStr, tenantID, ) return err @@ -108,7 +108,7 @@ func (s *SQLiteSecureCLIStore) scanRow(row *sql.Row) (*store.SecureCLIBinary, er err := row.Scan( &b.ID, &b.BinaryName, &binaryPath, &b.Description, &env, &denyArgs, &denyVerbose, - &b.TimeoutSeconds, &b.Tips, &b.IsGlobal, + &b.TimeoutSeconds, &b.Tips, &b.IsGlobal, &b.AllowChainExec, &b.Enabled, &b.CreatedBy, &createdAt, &updatedAt, ) if err != nil { @@ -153,7 +153,7 @@ func (s *SQLiteSecureCLIStore) scanRows(rows *sql.Rows) ([]store.SecureCLIBinary if err := rows.Scan( &b.ID, &b.BinaryName, &binaryPath, &b.Description, &env, &denyArgs, &denyVerbose, - &b.TimeoutSeconds, &b.Tips, &b.IsGlobal, + &b.TimeoutSeconds, &b.Tips, &b.IsGlobal, &b.AllowChainExec, &b.Enabled, &b.CreatedBy, &createdAt, &updatedAt, ); err != nil { return nil, fmt.Errorf("scan secure_cli_binaries row: %w", err) @@ -185,7 +185,7 @@ func (s *SQLiteSecureCLIStore) scanRows(rows *sql.Rows) ([]store.SecureCLIBinary var secureCLIAllowedFields = map[string]bool{ "binary_name": true, "binary_path": true, "description": true, "encrypted_env": true, 
"deny_args": true, "deny_verbose": true, - "timeout_seconds": true, "tips": true, "is_global": true, "enabled": true, + "timeout_seconds": true, "tips": true, "is_global": true, "allow_chain_exec": true, "enabled": true, "updated_at": true, } @@ -345,7 +345,7 @@ func (s *SQLiteSecureCLIStore) scanRowWithGrantAndUserEnv(row *sql.Row) (*store. err := row.Scan( &b.ID, &b.BinaryName, &binaryPath, &b.Description, &env, &denyArgs, &denyVerbose, - &b.TimeoutSeconds, &b.Tips, &b.IsGlobal, + &b.TimeoutSeconds, &b.Tips, &b.IsGlobal, &b.AllowChainExec, &b.Enabled, &b.CreatedBy, &createdAt, &updatedAt, &grantDenyArgs, &grantDenyVerbose, &grantTimeout, &grantTips, &grantEnabled, &grantID, &userEnv, @@ -500,7 +500,7 @@ func (s *SQLiteSecureCLIStore) ListForAgent(ctx context.Context, agentID uuid.UU if err := rows.Scan( &b.ID, &b.BinaryName, &binaryPath, &b.Description, &env, &denyArgs, &denyVerbose, - &b.TimeoutSeconds, &b.Tips, &b.IsGlobal, + &b.TimeoutSeconds, &b.Tips, &b.IsGlobal, &b.AllowChainExec, &b.Enabled, &b.CreatedBy, &createdAt, &updatedAt, &grantDenyArgs, &grantDenyVerbose, &grantTimeout, &grantTips, &grantID, ); err != nil { diff --git a/internal/tools/credential_context.go b/internal/tools/credential_context.go index 0d5d3b48da..8797292ef1 100644 --- a/internal/tools/credential_context.go +++ b/internal/tools/credential_context.go @@ -32,7 +32,11 @@ func GenerateCredentialContext(creds []store.SecureCLIBinary) string { b.WriteString("### Available CLIs:\n\n") for _, c := range creds { - b.WriteString(fmt.Sprintf("**%s** — %s\n", c.BinaryName, c.Description)) + if c.AllowChainExec { + b.WriteString(fmt.Sprintf("**%s** — %s (chain exec: credentials injected even in shell chains)\n", c.BinaryName, c.Description)) + } else { + b.WriteString(fmt.Sprintf("**%s** — %s\n", c.BinaryName, c.Description)) + } if blocked := summarizeDenyPatterns(c.DenyArgs); blocked != "" { b.WriteString(fmt.Sprintf(" Blocked: %s\n", blocked)) } diff --git a/internal/tools/credentialed_exec.go 
b/internal/tools/credentialed_exec.go index 55fe295cc4..b126f8afc0 100644 --- a/internal/tools/credentialed_exec.go +++ b/internal/tools/credentialed_exec.go @@ -665,3 +665,250 @@ func credentialedExecFailError(binary string, args []string, exitCode int, outpu IsError: true, } } + +// detectCredentialedBinaryInChain scans a command that contains shell operators +// for any token that matches a registered credentialed binary. Returns the +// binary name if found, empty string otherwise. This catches cases where the +// LLM wraps a credentialed CLI in a shell chain (e.g. "which gh && gh pr list") +// — the first binary ("which") is not credentialed so lookupCredentialedBinary +// misses it, but "gh" deeper in the chain IS credentialed and would run without +// token injection if allowed to fall through to regular exec. +// +// Uses extractUnquotedSegments (quote-aware) so that operators inside quoted +// arguments (e.g. --jq '.[0] | .name') are not mistaken for command chains. +func (t *ExecTool) detectCredentialedBinaryInChain(ctx context.Context, command string) string { + if t.secureCLIStore == nil { + return "" + } + // Only check commands that have shell operators outside of quotes. + // detectUnquotedShellOperators already uses extractUnquotedSegments + // internally, so quoted pipes/semicolons are ignored. + if ops := detectUnquotedShellOperators(command); len(ops) == 0 { + return "" + } + // Extract only the unquoted portions of the command. This collapses + // quoted strings so that operators inside quotes disappear entirely, + // preventing false splits on e.g. '.[0] | .name'. + unquoted := extractUnquotedSegments(command) + // Split the unquoted text on shell operator characters. This is safe + // because extractUnquotedSegments already stripped all quoted content. 
+ segments := shellOperatorPattern.Split(unquoted, -1) + for _, seg := range segments { + seg = strings.TrimSpace(seg) + if seg == "" { + continue + } + // Use go-shellwords to parse the segment's first token, handling + // edge cases like leading whitespace or escaped characters. + parser := shellwords.NewParser() + parser.ParseBacktick = false + parser.ParseEnv = false + words, err := parser.Parse(seg) + if err != nil || len(words) == 0 { + // Fallback to simple field split if shellwords fails + fields := strings.Fields(seg) + if len(fields) == 0 { + continue + } + words = fields[:1] + } + binary := normalizeBinaryName(words[0]) + gctx, cancel := context.WithTimeout(ctx, 2*time.Second) + registered, rerr := t.secureCLIStore.IsRegisteredBinary(gctx, binary) + cancel() + if rerr == nil && registered { + return binary + } + } + return "" +} + +// handleCredentialedChain handles commands where a credentialed binary appears +// inside a shell operator chain (e.g. "which gh && gh pr list"). Two modes: +// +// - allow_chain_exec=false (default): returns an error telling the LLM to +// call the CLI directly without shell operators. +// - allow_chain_exec=true: injects all matching credential env vars into +// the full command and executes via shell. Less secure (tokens visible +// to all commands in the chain) but works with LLMs that habitually +// use shell operators. +// +// Returns nil if no credentialed binary is detected in the chain. 
+func (t *ExecTool) handleCredentialedChain(ctx context.Context, normalizedCmd, rawCmd string, args map[string]any) *Result { + if t.secureCLIStore == nil { + return nil + } + if ops := detectUnquotedShellOperators(normalizedCmd); len(ops) == 0 { + return nil + } + + // Scan all segments for credentialed binaries + unquoted := extractUnquotedSegments(normalizedCmd) + segments := shellOperatorPattern.Split(unquoted, -1) + + type chainMatch struct { + binary string + cred *store.SecureCLIBinary + } + var matches []chainMatch + anyAllowChain := false + + agentID := store.AgentIDFromContext(ctx) + var agentIDPtr *uuid.UUID + if agentID != uuid.Nil { + agentIDPtr = &agentID + } + userID := store.CredentialUserIDFromContext(ctx) + + for _, seg := range segments { + seg = strings.TrimSpace(seg) + if seg == "" { + continue + } + parser := shellwords.NewParser() + parser.ParseBacktick = false + parser.ParseEnv = false + words, err := parser.Parse(seg) + if err != nil || len(words) == 0 { + fields := strings.Fields(seg) + if len(fields) == 0 { + continue + } + words = fields[:1] + } + binary := normalizeBinaryName(words[0]) + gctx, cancel := context.WithTimeout(ctx, 2*time.Second) + cred, lookupErr := t.secureCLIStore.LookupByBinary(gctx, binary, agentIDPtr, userID) + cancel() + if lookupErr != nil || cred == nil { + continue + } + matches = append(matches, chainMatch{binary: binary, cred: cred}) + if cred.AllowChainExec { + anyAllowChain = true + } + } + + if len(matches) == 0 { + return nil + } + + // Default mode: return error telling LLM to call directly + if !anyAllowChain { + first := matches[0].binary + return &Result{ + ForLLM: fmt.Sprintf("[CREDENTIALED CLI] Command contains credentialed binary %q but uses shell operators.\n"+ + "Shell operators (; && || |) prevent credential injection.\n"+ + "Call the CLI directly as the ONLY command: exec(\"%s ...\")\n"+ + "Do NOT combine with other commands, pipes, or redirects.", first, first), + ForUser: fmt.Sprintf("Command 
contains %q with shell operators — call it directly.", first), + IsError: true, + } + } + + // Chain injection mode: merge all matched credential env vars and execute + // the full command via shell with credentials injected. + slog.Info("security.credentialed_chain_exec", + "binaries", len(matches), + "command_prefix", truncateCmd(normalizedCmd, 80), + "agent_id", agentID) + + envMap := make(map[string]string) + for _, m := range matches { + if len(m.cred.EncryptedEnv) > 0 { + var credEnv map[string]string + if err := json.Unmarshal(m.cred.EncryptedEnv, &credEnv); err == nil { + for k, v := range credEnv { + envMap[k] = v + } + } + } + // Merge per-user env overrides + if len(m.cred.UserEnv) > 0 { + var userEnvMap map[string]string + if err := json.Unmarshal(m.cred.UserEnv, &userEnvMap); err == nil { + for k, v := range userEnvMap { + envMap[k] = v + } + } + } + // Register for output scrubbing + for _, v := range envMap { + AddCredentialScrubValues(v) + } + } + + // Use the longest timeout from matched credentials + timeout := 30 * time.Second + for _, m := range matches { + if d := time.Duration(m.cred.TimeoutSeconds) * time.Second; d > timeout { + timeout = d + } + } + + // Resolve working directory + cwd := ToolWorkspaceFromCtx(ctx) + if cwd == "" { + cwd = t.workspace + } + if wd, _ := args["working_dir"].(string); wd != "" { + cwd = wd + } + + // Execute via sandbox or host — using shell mode (sh -c) since the command + // contains intentional shell operators. 
+ sandboxKey := ToolSandboxKeyFromCtx(ctx) + if t.sandboxMgr != nil && sandboxKey != "" { + sb, err := t.sandboxMgr.Get(ctx, sandboxKey, t.workspace, SandboxConfigFromCtx(ctx)) + if err != nil { + return ErrorResult("credentialed chain exec requires sandbox but sandbox is unavailable: " + err.Error()) + } + result, err := sb.Exec(ctx, []string{"sh", "-c", rawCmd}, cwd, sandbox.WithEnv(envMap)) + if err != nil { + return ErrorResult(fmt.Sprintf("credentialed chain exec: %v", err)) + } + output := result.Stdout + if result.Stderr != "" { + if output != "" { + output += "\n" + } + output += "STDERR:\n" + result.Stderr + } + if result.ExitCode != 0 { + return credentialedExecFailError("sh -c ", []string{truncateCmd(rawCmd, 80)}, result.ExitCode, ScrubCredentials(output)) + } + if output == "" { + output = "(command completed with no output)" + } + return SilentResult(capExecOutput(ScrubCredentials(output), execMaxOutputChars)) + } + + // Host execution with shell + ctx2, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + cmd := exec.Command("sh", "-c", rawCmd) + cmd.Dir = cwd + setProcessGroup(cmd) + cmd.Env = buildCredentialedEnv(envMap) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + if err := cmd.Start(); err != nil { + return ErrorResult(fmt.Sprintf("credentialed chain exec: failed to start: %v", err)) + } + done := make(chan error, 1) + go func() { done <- cmd.Wait() }() + select { + case err := <-done: + return formatCredentialedResult("sh -c ", []string{truncateCmd(rawCmd, 80)}, stdout.String(), stderr.String(), err, ctx2, timeout) + case <-ctx2.Done(): + _ = killProcessGroup(cmd, syscallSIGTERM) + select { + case <-done: + case <-time.After(3 * time.Second): + _ = killProcessGroup(cmd, syscallSIGKILL) + <-done + } + return ErrorResult(fmt.Sprintf("[CREDENTIALED CHAIN EXEC] Command timed out after %s.", timeout)) + } +} diff --git a/internal/tools/shell.go b/internal/tools/shell.go index 5611108b64..31f2067e11 100644 
--- a/internal/tools/shell.go +++ b/internal/tools/shell.go @@ -301,6 +301,13 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *Result { return t.executeCredentialed(ctx, cred, binary, cmdArgs, cwd, sandboxKey, command) } + // Chain detection: credentialed binary found deeper in a shell operator chain. + // If allow_chain_exec is enabled for any matched binary, inject credentials + // into the full command. Otherwise return an actionable error. + if chainResult := t.handleCredentialedChain(ctx, normalizedCommand, command, args); chainResult != nil { + return chainResult + } + // Secure CLI gate: registered-but-not-granted binaries MUST NOT fall through // to host exec with parent env. Works on the already-normalized command // (Red Team F6) and unwraps shell wrappers up to depth 3 (Red Team F1). diff --git a/internal/upgrade/version.go b/internal/upgrade/version.go index 83df8b2e0c..fc18492ddf 100644 --- a/internal/upgrade/version.go +++ b/internal/upgrade/version.go @@ -2,4 +2,4 @@ package upgrade // RequiredSchemaVersion is the schema migration version this binary requires. // Bump this whenever adding a new SQL migration file. 
-const RequiredSchemaVersion uint = 56 +const RequiredSchemaVersion uint = 57 diff --git a/migrations/000057_secure_cli_allow_chain_exec.down.sql b/migrations/000057_secure_cli_allow_chain_exec.down.sql new file mode 100644 index 0000000000..4bf551f395 --- /dev/null +++ b/migrations/000057_secure_cli_allow_chain_exec.down.sql @@ -0,0 +1 @@ +ALTER TABLE secure_cli_binaries DROP COLUMN IF EXISTS allow_chain_exec; diff --git a/migrations/000057_secure_cli_allow_chain_exec.up.sql b/migrations/000057_secure_cli_allow_chain_exec.up.sql new file mode 100644 index 0000000000..7f67cd306a --- /dev/null +++ b/migrations/000057_secure_cli_allow_chain_exec.up.sql @@ -0,0 +1 @@ +ALTER TABLE secure_cli_binaries ADD COLUMN IF NOT EXISTS allow_chain_exec BOOLEAN NOT NULL DEFAULT false; diff --git a/ui/web/src/i18n/locales/en/cli-credentials.json b/ui/web/src/i18n/locales/en/cli-credentials.json index 99ac5c724b..92ffd45825 100644 --- a/ui/web/src/i18n/locales/en/cli-credentials.json +++ b/ui/web/src/i18n/locales/en/cli-credentials.json @@ -46,6 +46,8 @@ "binaryNotFound": "Binary not found in server PATH", "checking": "Checking...", "isGlobal": "Available to all agents", + "allowChainExec": "Allow chain execution", + "allowChainExecHint": "Inject credentials into shell command chains (e.g. which gh && gh pr list). 
Less secure but works with LLMs that use shell operators.", "isGlobalHint": "When off, only agents with explicit grants can use this CLI" }, "placeholders": { diff --git a/ui/web/src/pages/cli-credentials/cli-credential-form-dialog.tsx b/ui/web/src/pages/cli-credentials/cli-credential-form-dialog.tsx index e7cfc08e75..2ef2f1ca3d 100644 --- a/ui/web/src/pages/cli-credentials/cli-credential-form-dialog.tsx +++ b/ui/web/src/pages/cli-credentials/cli-credential-form-dialog.tsx @@ -62,6 +62,7 @@ export function CliCredentialFormDialog({ open, onOpenChange, credential, preset timeout: 30, tips: "", isGlobal: true, + allowChainExec: false, enabled: true, }, }); @@ -79,6 +80,7 @@ export function CliCredentialFormDialog({ open, onOpenChange, credential, preset timeout: credential?.timeout_seconds ?? 30, tips: credential?.tips ?? "", isGlobal: credential?.is_global ?? true, + allowChainExec: credential?.allow_chain_exec ?? false, enabled: credential?.enabled ?? true, }); setEnvValues({}); @@ -183,6 +185,7 @@ export function CliCredentialFormDialog({ open, onOpenChange, credential, preset timeout_seconds: values.timeout, tips: values.tips?.trim() ?? "", is_global: values.isGlobal, + allow_chain_exec: values.allowChainExec, enabled: values.enabled, }; if (selectedPreset !== NONE_PRESET) payload.preset = selectedPreset; diff --git a/ui/web/src/pages/cli-credentials/cli-credential-scope-fields.tsx b/ui/web/src/pages/cli-credentials/cli-credential-scope-fields.tsx index aaf8d58a3f..c1900be73b 100644 --- a/ui/web/src/pages/cli-credentials/cli-credential-scope-fields.tsx +++ b/ui/web/src/pages/cli-credentials/cli-credential-scope-fields.tsx @@ -31,6 +31,21 @@ export function CliCredentialScopeFields({ form }: CliCredentialScopeFieldsProps /> + {/* Allow chain exec */} +
+
+ +

{t("form.allowChainExecHint")}

+
+ ( + + )} + /> +
+
; }