Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,16 @@ NEO4J_DATABASE=neo4j
NEO4J_PASSWORD=devpassword # change this to improve security
NEO4J_URI=bolt://neo4j:7687

# === SAGE Persistent Memory System ===
## Set SAGE_ENABLED=true and SAGE_URL=http://sage:9000 to enable SAGE persistent memory
## Run with: docker compose -f docker-compose.yml -f docker-compose-sage.yml up -d
SAGE_ENABLED=false
SAGE_TIMEOUT=30
SAGE_URL=
# Path to persistent Ed25519 identity file (auto-generated if empty)
SAGE_KEY_PATH=
SAGE_BOT_NAME=pentagi

## PentAGI image settings
PENTAGI_IMAGE=

Expand Down
33 changes: 33 additions & 0 deletions backend/cmd/ftester/mocks/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,39 @@ func MockResponse(funcName string, args json.RawMessage) (string, error) {

resultObj = builder.String()

case tools.SageRecallToolName:
var recallArgs tools.SageRecallAction
if err := json.Unmarshal(args, &recallArgs); err != nil {
return "", fmt.Errorf("error unmarshaling sage recall arguments: %w", err)
}

terminal.PrintMock("SAGE Recall:")
terminal.PrintKeyValue("Query", recallArgs.Query)
terminal.PrintKeyValue("Domain", recallArgs.Domain)

var builder strings.Builder
builder.WriteString("# SAGE Cross-Session Memory Results\n\n")
builder.WriteString(fmt.Sprintf("**Query:** %s\n\n", recallArgs.Query))
builder.WriteString("## Memory 1 (confidence: 0.85, type: lesson)\n")
builder.WriteString("**Domain:** general\n")
builder.WriteString("**Stored:** 2025-01-19T14:00:00Z\n\n")
builder.WriteString("Mock cross-session memory: Always verify target scope before scanning.\n")
builder.WriteString("---------------------------\n")
resultObj = builder.String()

case tools.SageRememberToolName:
var rememberArgs tools.SageRememberAction
if err := json.Unmarshal(args, &rememberArgs); err != nil {
return "", fmt.Errorf("error unmarshaling sage remember arguments: %w", err)
}

terminal.PrintMock("SAGE Remember:")
terminal.PrintKeyValue("Domain", rememberArgs.Domain)
terminal.PrintKeyValue("Memory Type", rememberArgs.MemoryType)
terminal.PrintKeyValueFormat("Content length", "%d chars", len(rememberArgs.Content))

resultObj = "# SAGE Memory Stored\n\nKnowledge successfully submitted to SAGE persistent memory.\n**Memory ID:** mock-memory-id-123\n**Status:** Submitted for BFT consensus\n"

default:
terminal.PrintMock("Generic mock response:")
terminal.PrintKeyValue("Function", funcName)
Expand Down
2 changes: 2 additions & 0 deletions backend/cmd/ftester/worker/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,8 @@ func getStructTypeForFunction(funcName string) (reflect.Type, error) {
tools.SearchCodeToolName: &tools.SearchCodeAction{},
tools.StoreCodeToolName: &tools.StoreCodeAction{},
tools.GraphitiSearchToolName: &tools.GraphitiSearchAction{},
tools.SageRecallToolName: &tools.SageRecallAction{},
tools.SageRememberToolName: &tools.SageRememberAction{},
tools.SearchToolName: &tools.ComplexSearch{},
tools.MaintenanceToolName: &tools.MaintenanceAction{},
tools.CoderToolName: &tools.CoderAction{},
Expand Down
12 changes: 12 additions & 0 deletions backend/cmd/ftester/worker/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"pentagi/pkg/graphiti"
"pentagi/pkg/providers"
"pentagi/pkg/providers/embeddings"
"pentagi/pkg/sage"
"pentagi/pkg/terminal"
"pentagi/pkg/tools"

Expand Down Expand Up @@ -46,6 +47,7 @@ type toolExecutor struct {
handlers providers.FlowProviderHandlers
store *pgvector.Store
graphitiClient *graphiti.Client
sageClient *sage.Client
proxies mocks.ProxyProviders
flowID int64
taskID *int64
Expand All @@ -64,6 +66,7 @@ func newToolExecutor(
taskID, subtaskID *int64,
embedder embeddings.Embedder,
graphitiClient *graphiti.Client,
sageClient *sage.Client,
) (*toolExecutor, error) {
var store *pgvector.Store
if embedder.IsAvailable() {
Expand Down Expand Up @@ -101,6 +104,7 @@ func newToolExecutor(
handlers: handlers,
store: store,
graphitiClient: graphitiClient,
sageClient: sageClient,
proxies: proxies,
flowID: flowID,
taskID: taskID,
Expand Down Expand Up @@ -271,6 +275,14 @@ func (te *toolExecutor) GetTool(ctx context.Context, funcName string) (tools.Too
te.graphitiClient,
), nil

case tools.SageRecallToolName, tools.SageRememberToolName:
return tools.NewSageSearchTool(
te.flowID,
te.taskID,
te.subtaskID,
te.sageClient,
), nil

// AI Agent tools
case tools.AdviceToolName:
var handler tools.ExecutorHandler
Expand Down
2 changes: 2 additions & 0 deletions backend/cmd/ftester/worker/tester.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,14 @@ func NewTester(
flowExecutor.SetTermLogProvider(proxies.GetTermLogProvider())
flowExecutor.SetVectorStoreLogProvider(proxies.GetVectorStoreLogProvider())
flowExecutor.SetGraphitiClient(providerController.GraphitiClient())
flowExecutor.SetSageClient(providerController.SageClient())

// Initialize tool executor
toolExecutor, err := newToolExecutor(
flowExecutor, cfg, db, dockerClient, nil, proxies,
flowID, taskID, subtaskID, providerController.Embedder(),
providerController.GraphitiClient(),
providerController.SageClient(),
)
if err != nil {
return nil, fmt.Errorf("failed to create tool executor: %w", err)
Expand Down
8 changes: 8 additions & 0 deletions backend/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,13 @@ type Config struct {

// === Agent Planning Phase Configuration ===
AgentPlanningStepEnabled bool `env:"AGENT_PLANNING_STEP_ENABLED" envDefault:"false"`

// === SAGE Persistent Memory System ===
SAGEEnabled bool `env:"SAGE_ENABLED" envDefault:"false"`
SAGETimeout int `env:"SAGE_TIMEOUT" envDefault:"30"`
SAGEURL string `env:"SAGE_URL"`
SAGEKeyPath string `env:"SAGE_KEY_PATH" envDefault:""`
SAGEBotName string `env:"SAGE_BOT_NAME" envDefault:"pentagi"`
}

func NewConfig() (*Config, error) {
Expand Down Expand Up @@ -322,6 +329,7 @@ func (c *Config) GetSecretPatterns() []patterns.Pattern {
{c.ProxyURL, "Proxy URL"},
{c.LangfusePublicKey, "Langfuse Public Key"},
{c.LangfuseSecretKey, "Langfuse Secret Key"},
// SAGEKeyPath is a file path, not a secret — omitted from patterns.
}

for _, s := range secrets {
Expand Down
53 changes: 53 additions & 0 deletions backend/pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ func clearConfigEnv(t *testing.T) {
"EXECUTION_MONITOR_ENABLED", "EXECUTION_MONITOR_SAME_TOOL_LIMIT", "EXECUTION_MONITOR_TOTAL_TOOL_LIMIT",
"MAX_GENERAL_AGENT_TOOL_CALLS", "MAX_LIMITED_AGENT_TOOL_CALLS",
"AGENT_PLANNING_STEP_ENABLED",
"SAGE_ENABLED", "SAGE_TIMEOUT", "SAGE_URL", "SAGE_KEY_PATH", "SAGE_BOT_NAME",
}
for _, v := range envVars {
t.Setenv(v, "")
Expand Down Expand Up @@ -584,3 +585,55 @@ func TestNewConfig_AgentSupervisionOverride(t *testing.T) {
assert.Equal(t, 30, config.MaxLimitedAgentToolCalls)
assert.Equal(t, true, config.AgentPlanningStepEnabled)
}

// --- SAGE config tests ---

func TestSAGEConfigDefaults(t *testing.T) {
clearConfigEnv(t)
t.Chdir(t.TempDir())

config, err := NewConfig()
require.NoError(t, err)

assert.Equal(t, false, config.SAGEEnabled)
assert.Equal(t, 30, config.SAGETimeout)
assert.Equal(t, "", config.SAGEURL)
assert.Equal(t, "", config.SAGEKeyPath)
assert.Equal(t, "pentagi", config.SAGEBotName)
}

func TestSAGEConfigFromEnv(t *testing.T) {
clearConfigEnv(t)
t.Chdir(t.TempDir())

t.Setenv("SAGE_ENABLED", "true")
t.Setenv("SAGE_URL", "http://sage:8080")
t.Setenv("SAGE_KEY_PATH", "/tmp/test.key")
t.Setenv("SAGE_BOT_NAME", "mybot")
t.Setenv("SAGE_TIMEOUT", "60")

config, err := NewConfig()
require.NoError(t, err)

assert.Equal(t, true, config.SAGEEnabled)
assert.Equal(t, "http://sage:8080", config.SAGEURL)
assert.Equal(t, "/tmp/test.key", config.SAGEKeyPath)
assert.Equal(t, "mybot", config.SAGEBotName)
assert.Equal(t, 60, config.SAGETimeout)
}

func TestSAGEKeyPathNotInSecretPatterns(t *testing.T) {
cfg := &Config{
SAGEKeyPath: "/some/path/to/agent.key",
SAGEBotName: "pentagi",
SAGEURL: "http://sage:8080",
}

patterns := cfg.GetSecretPatterns()

for _, p := range patterns {
if p.Name == "SAGE" || p.Name == "SAGE Key Path" || p.Name == "SAGEKeyPath" {
t.Errorf("SAGEKeyPath should not appear in secret patterns, found pattern named %q", p.Name)
}
}
}
1 change: 1 addition & 0 deletions backend/pkg/controller/assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ func NewAssistantWorker(ctx context.Context, awc newAssistantWorkerCtx) (Assista
executor.SetTermLogProvider(workers.tlw)
executor.SetVectorStoreLogProvider(workers.vslw)
executor.SetGraphitiClient(awc.provs.GraphitiClient())
executor.SetSageClient(awc.provs.SageClient())

ctx, cancel := context.WithCancel(context.Background())
ctx, _ = obs.Observer.NewObservation(ctx, langfuse.WithObservationTraceID(observation.TraceID()))
Expand Down
2 changes: 2 additions & 0 deletions backend/pkg/controller/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ func NewFlowWorker(
executor.SetTermLogProvider(workers.tlw)
executor.SetVectorStoreLogProvider(workers.vslw)
executor.SetGraphitiClient(fwc.provs.GraphitiClient())
executor.SetSageClient(fwc.provs.SageClient())

flowCtx := &FlowContext{
DB: fwc.db,
Expand Down Expand Up @@ -364,6 +365,7 @@ func LoadFlowWorker(ctx context.Context, flow database.Flow, fwc flowWorkerCtx)
executor.SetTermLogProvider(workers.tlw)
executor.SetVectorStoreLogProvider(workers.vslw)
executor.SetGraphitiClient(fwc.provs.GraphitiClient())
executor.SetSageClient(fwc.provs.SageClient())

flowCtx := &FlowContext{
DB: fwc.db,
Expand Down
10 changes: 10 additions & 0 deletions backend/pkg/providers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ func (fp *flowProvider) GetAskAdviceHandler(ctx context.Context, taskID, subtask
"SearchInMemoryToolName": tools.SearchInMemoryToolName,
"GraphitiEnabled": fp.graphitiClient != nil && fp.graphitiClient.IsEnabled(),
"GraphitiSearchToolName": tools.GraphitiSearchToolName,
"SAGEEnabled": fp.sageClient != nil && fp.sageClient.IsEnabled(),
"SageRecallToolName": tools.SageRecallToolName,
"FileToolName": tools.FileToolName,
"TerminalToolName": tools.TerminalToolName,
"BrowserToolName": tools.BrowserToolName,
Expand Down Expand Up @@ -266,6 +268,9 @@ func (fp *flowProvider) GetCoderHandler(ctx context.Context, taskID, subtaskID *
"StoreCodeToolName": tools.StoreCodeToolName,
"GraphitiEnabled": fp.graphitiClient != nil && fp.graphitiClient.IsEnabled(),
"GraphitiSearchToolName": tools.GraphitiSearchToolName,
"SAGEEnabled": fp.sageClient != nil && fp.sageClient.IsEnabled(),
"SageRecallToolName": tools.SageRecallToolName,
"SageRememberToolName": tools.SageRememberToolName,
"SearchToolName": tools.SearchToolName,
"AdviceToolName": tools.AdviceToolName,
"MemoristToolName": tools.MemoristToolName,
Expand Down Expand Up @@ -496,6 +501,8 @@ func (fp *flowProvider) GetMemoristHandler(ctx context.Context, taskID, subtaskI
"MemoristResultToolName": tools.MemoristResultToolName,
"GraphitiEnabled": fp.graphitiClient != nil && fp.graphitiClient.IsEnabled(),
"GraphitiSearchToolName": tools.GraphitiSearchToolName,
"SAGEEnabled": fp.sageClient != nil && fp.sageClient.IsEnabled(),
"SageRecallToolName": tools.SageRecallToolName,
"TerminalToolName": tools.TerminalToolName,
"FileToolName": tools.FileToolName,
"SummarizationToolName": cast.SummarizationToolName,
Expand Down Expand Up @@ -594,6 +601,9 @@ func (fp *flowProvider) GetPentesterHandler(ctx context.Context, taskID, subtask
"StoreGuideToolName": tools.StoreGuideToolName,
"GraphitiEnabled": fp.graphitiClient != nil && fp.graphitiClient.IsEnabled(),
"GraphitiSearchToolName": tools.GraphitiSearchToolName,
"SAGEEnabled": fp.sageClient != nil && fp.sageClient.IsEnabled(),
"SageRecallToolName": tools.SageRecallToolName,
"SageRememberToolName": tools.SageRememberToolName,
"SearchToolName": tools.SearchToolName,
"CoderToolName": tools.CoderToolName,
"AdviceToolName": tools.AdviceToolName,
Expand Down
2 changes: 2 additions & 0 deletions backend/pkg/providers/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"pentagi/pkg/database"
"pentagi/pkg/graphiti"
obs "pentagi/pkg/observability"
"pentagi/pkg/sage"
"pentagi/pkg/observability/langfuse"
"pentagi/pkg/providers/embeddings"
"pentagi/pkg/providers/pconfig"
Expand Down Expand Up @@ -130,6 +131,7 @@ type flowProvider struct {

embedder embeddings.Embedder
graphitiClient *graphiti.Client
sageClient *sage.Client

flowID int64
publicIP string
Expand Down
25 changes: 25 additions & 0 deletions backend/pkg/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"pentagi/pkg/docker"
"pentagi/pkg/graphiti"
obs "pentagi/pkg/observability"
"pentagi/pkg/sage"
"pentagi/pkg/providers/anthropic"
"pentagi/pkg/providers/bedrock"
"pentagi/pkg/providers/custom"
Expand Down Expand Up @@ -85,6 +86,7 @@ type ProviderController interface {

Embedder() embeddings.Embedder
GraphitiClient() *graphiti.Client
SageClient() *sage.Client
DefaultProviders() provider.Providers
DefaultProvidersConfig() provider.ProvidersConfig
GetProvider(
Expand Down Expand Up @@ -139,6 +141,7 @@ type providerController struct {
dockerNetwork string
embedder embeddings.Embedder
graphitiClient *graphiti.Client
sageClient *sage.Client

startCallNumber *atomic.Int64

Expand Down Expand Up @@ -354,6 +357,19 @@ func NewProviderController(
graphitiClient = &graphiti.Client{}
}

var sageClient *sage.Client
sageClient, err = sage.NewClient(
cfg.SAGEURL,
cfg.SAGEKeyPath,
cfg.SAGEBotName,
time.Duration(cfg.SAGETimeout)*time.Second,
cfg.SAGEEnabled && cfg.SAGEURL != "",
)
if err != nil {
logrus.WithError(err).Warn("failed to initialize SAGE client, continuing without persistent memory")
sageClient = nil
}

return &providerController{
db: db,
cfg: cfg,
Expand All @@ -362,6 +378,7 @@ func NewProviderController(
dockerNetwork: cfg.DockerNetwork,
embedder: embedder,
graphitiClient: graphitiClient,
sageClient: sageClient,

startCallNumber: newAtomicInt64(0), // 0 means to make it random

Expand Down Expand Up @@ -447,6 +464,7 @@ func (pc *providerController) NewFlowProvider(
mx: &sync.RWMutex{},
embedder: pc.embedder,
graphitiClient: pc.graphitiClient,
sageClient: pc.sageClient,
flowID: flowID,
publicIP: pc.publicIP,
dockerNetwork: pc.dockerNetwork,
Expand Down Expand Up @@ -497,6 +515,7 @@ func (pc *providerController) LoadFlowProvider(
mx: &sync.RWMutex{},
embedder: pc.embedder,
graphitiClient: pc.graphitiClient,
sageClient: pc.sageClient,
flowID: flowID,
publicIP: pc.publicIP,
dockerNetwork: pc.dockerNetwork,
Expand Down Expand Up @@ -533,6 +552,10 @@ func (pc *providerController) GraphitiClient() *graphiti.Client {
return pc.graphitiClient
}

func (pc *providerController) SageClient() *sage.Client {
return pc.sageClient
}

func (pc *providerController) NewAssistantProvider(
ctx context.Context,
prvname provider.ProviderName,
Expand Down Expand Up @@ -592,6 +615,7 @@ func (pc *providerController) NewAssistantProvider(
mx: &sync.RWMutex{},
embedder: pc.embedder,
graphitiClient: pc.graphitiClient,
sageClient: pc.sageClient,
flowID: flowID,
publicIP: pc.publicIP,
dockerNetwork: pc.dockerNetwork,
Expand Down Expand Up @@ -645,6 +669,7 @@ func (pc *providerController) LoadAssistantProvider(
mx: &sync.RWMutex{},
embedder: pc.embedder,
graphitiClient: pc.graphitiClient,
sageClient: pc.sageClient,
flowID: flowID,
publicIP: pc.publicIP,
dockerNetwork: pc.dockerNetwork,
Expand Down
Loading