diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 795b6e0..8760775 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -43,19 +43,3 @@ jobs: - name: Test run: go test ./... - # OpenCode integration tests - runs on main merges and PRs - integration-opencode: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - uses: actions/setup-go@v5 - with: - go-version: '1.24' - cache: true - - - name: Build binary for integration tests - run: make build - - - name: Run OpenCode integration tests - run: go test -v ./tests/integration/... -run "TestOpenCode" diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 0000000..7b26239 --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,42 @@ +name: CodeQL + +on: + push: + branches: [main] + tags-ignore: + - 'v*' + pull_request: + branches: [main] + schedule: + - cron: '0 6 * * 1' # Weekly Monday 06:00 UTC + +jobs: + analyze: + runs-on: ubuntu-latest + permissions: + security-events: write + contents: read + strategy: + fail-fast: false + matrix: + language: [go] + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-go@v5 + with: + go-version: '1.24' + cache: true + + - name: Initialize CodeQL + uses: github/codeql-action/init@v3 + with: + languages: ${{ matrix.language }} + + - name: Build + run: make build + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v3 + with: + category: '/language:${{ matrix.language }}' diff --git a/.github/workflows/trivy.yml b/.github/workflows/trivy.yml new file mode 100644 index 0000000..7a88312 --- /dev/null +++ b/.github/workflows/trivy.yml @@ -0,0 +1,71 @@ +name: Trivy Security Scan + +on: + push: + branches: [main] + tags-ignore: + - 'v*' + pull_request: + branches: [main] + schedule: + - cron: '0 6 * * 1' # Weekly Monday 06:00 UTC + +jobs: + vulnerability-scan: + runs-on: ubuntu-latest + permissions: + security-events: write + contents: read + steps: + - uses: actions/checkout@v4 + + - name: Run Trivy vulnerability scanner (filesystem) + uses: aquasecurity/trivy-action@0.34.0 + with: + scan-type: fs + scan-ref: . + severity: CRITICAL,HIGH + format: sarif + output: trivy-fs.sarif + + - name: Upload Trivy filesystem results to GitHub Security + uses: github/codeql-action/upload-sarif@v3 + if: always() + with: + sarif_file: trivy-fs.sarif + category: trivy-filesystem + + binary-scan: + runs-on: ubuntu-latest + permissions: + security-events: write + contents: read + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-go@v5 + with: + go-version: '1.24' + cache: true + + - name: Build binary + run: | + make build + mkdir -p scan-target + cp taskwing scan-target/ + + - name: Run Trivy on compiled binary + uses: aquasecurity/trivy-action@0.34.0 + with: + scan-type: rootfs + scan-ref: ./scan-target + severity: CRITICAL,HIGH + format: sarif + output: trivy-binary.sarif + + - name: Upload Trivy binary results to GitHub Security + uses: github/codeql-action/upload-sarif@v3 + if: always() + with: + sarif_file: trivy-binary.sarif + category: trivy-binary diff --git a/.taskwing.example.yaml b/.taskwing.example.yaml index 8e41d94..a0f63a9 100644 --- a/.taskwing.example.yaml +++ b/.taskwing.example.yaml @@ -29,7 +29,7 @@ data: # # Query uses cheap fast models for frequent context lookups # models: # bootstrap: "openai:gpt-5" # Used by: tw bootstrap, tw plan -# query: "gemini:gemini-2.0-flash" # Used by: tw context, recall MCP +# query: "gemini:gemini-2.0-flash" # Used by: tw context, ask MCP # # # API keys can also be set via environment variables: # # OPENAI_API_KEY, ANTHROPIC_API_KEY, GEMINI_API_KEY diff --git a/AGENTS.md b/AGENTS.md index ef96e8b..842d6c2 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -67,20 +67,22 @@ Brand names and logos are trademarks of their respective owners; usage here indi ### Slash Commands -- /tw-brief - Use when you need a compact project brief. -- /tw-next - Use when you are ready to start the next approved task. -- /tw-done - Use when implementation is verified and ready to complete. -- /tw-plan - Use when you need to clarify a goal and build a plan. -- /tw-status - Use when you need current task progress. -- /tw-debug - Use when debugging must start from root-cause evidence. -- /tw-explain - Use when you need a deep symbol explanation. -- /tw-simplify - Use when you want to simplify code without behavior changes. +- /tw-ask - Use when you need to search project knowledge (decisions, patterns, constraints). +- /tw-remember - Use when you want to persist a decision, pattern, or insight to project memory. +- /tw-next - Use when you are ready to start the next approved TaskWing task with full context. +- /tw-done - Use when implementation is verified and you are ready to complete the current task. +- /tw-status - Use when you need current task progress and acceptance criteria status. +- /tw-plan - Use when you need to clarify a goal and build an approved execution plan. +- /tw-debug - Use when an issue requires root-cause-first debugging before proposing fixes. +- /tw-explain - Use when you need a deep explanation of a code symbol and its call graph. +- /tw-simplify - Use when you want to simplify code while preserving behavior. ### Core Commands - taskwing bootstrap - taskwing goal "" +- taskwing ask "" - taskwing task - taskwing plan status - taskwing slash @@ -95,7 +97,7 @@ Brand names and logos are trademarks of their respective owners; usage here indi | Tool | Description | |------|-------------| -| recall | Retrieve project knowledge (decisions, patterns, constraints) | +| ask | Search project knowledge (decisions, patterns, constraints) | | task | Unified task lifecycle (next, current, start, complete) | | plan | Plan management (clarify, decompose, expand, generate, finalize, audit) | | code | Code intelligence (find, search, explain, callers, impact, simplify) | diff --git a/CLAUDE.md b/CLAUDE.md index a67a065..12caa32 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -142,13 +142,13 @@ Uses CloudWeGo Eino for multi-provider support: ### MCP Server -`taskwing mcp` starts a JSON-RPC stdio server exposing `recall`, `task`, `plan`, `code`, `debug`, and `remember` tools. +`taskwing mcp` starts a JSON-RPC stdio server exposing `ask`, `task`, `plan`, `code`, `debug`, and `remember` tools. ### Task Context Binding Tasks receive architectural context via a **hybrid early+late binding** approach: -1. **Early binding** (at task creation): `TaskEnricher` executes recall queries and embeds results in `Task.ContextSummary` +1. **Early binding** (at task creation): `TaskEnricher` executes ask queries and embeds results in `Task.ContextSummary` 2. **Late binding** (at display): `FormatRichContext()` uses early-bound context or fetches fresh context as fallback This ensures tasks always have relevant architecture context while maintaining backward compatibility with older tasks. @@ -323,27 +323,29 @@ Brand names and logos are trademarks of their respective owners; usage here indi ### Slash Commands -- /tw-brief - Use when you need a compact project brief. -- /tw-next - Use when you are ready to start the next approved task. -- /tw-done - Use when implementation is verified and ready to complete. -- /tw-plan - Use when you need to clarify a goal and build a plan. -- /tw-status - Use when you need current task progress. -- /tw-debug - Use when debugging must start from root-cause evidence. -- /tw-explain - Use when you need a deep symbol explanation. -- /tw-simplify - Use when you want to simplify code without behavior changes. +- /tw-ask - Use when you need to search project knowledge (decisions, patterns, constraints). +- /tw-remember - Use when you want to persist a decision, pattern, or insight to project memory. +- /tw-next - Use when you are ready to start the next approved TaskWing task with full context. +- /tw-done - Use when implementation is verified and you are ready to complete the current task. +- /tw-status - Use when you need current task progress and acceptance criteria status. +- /tw-plan - Use when you need to clarify a goal and build an approved execution plan. +- /tw-debug - Use when an issue requires root-cause-first debugging before proposing fixes. +- /tw-explain - Use when you need a deep explanation of a code symbol and its call graph. +- /tw-simplify - Use when you want to simplify code while preserving behavior. ### Core Commands -- taskwing bootstrap -- taskwing goal "" -- taskwing task -- taskwing plan status -- taskwing slash -- taskwing mcp -- taskwing doctor -- taskwing config -- taskwing start +- `taskwing bootstrap` +- `taskwing goal ""` +- `taskwing ask ""` +- `taskwing task` +- `taskwing plan status` +- `taskwing slash` +- `taskwing mcp` +- `taskwing doctor` +- `taskwing config` +- `taskwing start` ### MCP Tools (Canonical Contract) @@ -351,12 +353,12 @@ Brand names and logos are trademarks of their respective owners; usage here indi | Tool | Description | |------|-------------| -| recall | Retrieve project knowledge (decisions, patterns, constraints) | -| task | Unified task lifecycle (next, current, start, complete) | -| plan | Plan management (clarify, decompose, expand, generate, finalize, audit) | -| code | Code intelligence (find, search, explain, callers, impact, simplify) | -| debug | Diagnose issues systematically with AI-powered analysis | -| remember | Store knowledge in project memory | +| `ask` | Search project knowledge (decisions, patterns, constraints) | +| `task` | Unified task lifecycle (`next`, `current`, `start`, `complete`) | +| `plan` | Plan management (`clarify`, `decompose`, `expand`, `generate`, `finalize`, `audit`) | +| `code` | Code intelligence (`find`, `search`, `explain`, `callers`, `impact`, `simplify`) | +| `debug` | Diagnose issues systematically with AI-powered analysis | +| `remember` | Store knowledge in project memory | ### Autonomous Task Execution (Hooks) diff --git a/GEMINI.md b/GEMINI.md index e849a24..ee7cf26 100644 --- a/GEMINI.md +++ b/GEMINI.md @@ -100,7 +100,7 @@ The system is composed of a CLI tool with an embedded MCP server and a web dashb ## MCP Integration -TaskWing exposes a `recall` tool. When working on this feature: +TaskWing exposes an `ask` tool. When working on this feature: * Ensure responses stay within token budgets (500-1000 tokens). * Test with `taskwing mcp` locally or use `make test-mcp`. @@ -211,27 +211,29 @@ Brand names and logos are trademarks of their respective owners; usage here indi ### Slash Commands -- /tw-brief - Use when you need a compact project brief. -- /tw-next - Use when you are ready to start the next approved task. -- /tw-done - Use when implementation is verified and ready to complete. -- /tw-plan - Use when you need to clarify a goal and build a plan. -- /tw-status - Use when you need current task progress. -- /tw-debug - Use when debugging must start from root-cause evidence. -- /tw-explain - Use when you need a deep symbol explanation. -- /tw-simplify - Use when you want to simplify code without behavior changes. +- /tw-ask - Use when you need to search project knowledge (decisions, patterns, constraints). +- /tw-remember - Use when you want to persist a decision, pattern, or insight to project memory. +- /tw-next - Use when you are ready to start the next approved TaskWing task with full context. +- /tw-done - Use when implementation is verified and you are ready to complete the current task. +- /tw-status - Use when you need current task progress and acceptance criteria status. +- /tw-plan - Use when you need to clarify a goal and build an approved execution plan. +- /tw-debug - Use when an issue requires root-cause-first debugging before proposing fixes. +- /tw-explain - Use when you need a deep explanation of a code symbol and its call graph. +- /tw-simplify - Use when you want to simplify code while preserving behavior. ### Core Commands -- taskwing bootstrap -- taskwing goal "" -- taskwing task -- taskwing plan status -- taskwing slash -- taskwing mcp -- taskwing doctor -- taskwing config -- taskwing start +- `taskwing bootstrap` +- `taskwing goal ""` +- `taskwing ask ""` +- `taskwing task` +- `taskwing plan status` +- `taskwing slash` +- `taskwing mcp` +- `taskwing doctor` +- `taskwing config` +- `taskwing start` ### MCP Tools (Canonical Contract) @@ -239,12 +241,12 @@ Brand names and logos are trademarks of their respective owners; usage here indi | Tool | Description | |------|-------------| -| recall | Retrieve project knowledge (decisions, patterns, constraints) | -| task | Unified task lifecycle (next, current, start, complete) | -| plan | Plan management (clarify, decompose, expand, generate, finalize, audit) | -| code | Code intelligence (find, search, explain, callers, impact, simplify) | -| debug | Diagnose issues systematically with AI-powered analysis | -| remember | Store knowledge in project memory | +| `ask` | Search project knowledge (decisions, patterns, constraints) | +| `task` | Unified task lifecycle (`next`, `current`, `start`, `complete`) | +| `plan` | Plan management (`clarify`, `decompose`, `expand`, `generate`, `finalize`, `audit`) | +| `code` | Code intelligence (`find`, `search`, `explain`, `callers`, `impact`, `simplify`) | +| `debug` | Diagnose issues systematically with AI-powered analysis | +| `remember` | Store knowledge in project memory | ### Autonomous Task Execution (Hooks) diff --git a/README.md b/README.md index e5b0765..f821ca3 100644 --- a/README.md +++ b/README.md @@ -8,32 +8,26 @@ ## Supported Models - [![OpenAI](https://img.shields.io/badge/OpenAI-412991?logo=openai&logoColor=white)](https://platform.openai.com/) [![Anthropic](https://img.shields.io/badge/Anthropic-191919?logo=anthropic&logoColor=white)](https://www.anthropic.com/) [![Google Gemini](https://img.shields.io/badge/Google_Gemini-4285F4?logo=google&logoColor=white)](https://ai.google.dev/) [![AWS Bedrock](https://img.shields.io/badge/AWS_Bedrock-OpenAI--Compatible_Beta-FF9900?logo=amazonaws&logoColor=white)](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html) [![Ollama](https://img.shields.io/badge/Ollama-Local-000000?logo=ollama&logoColor=white)](https://ollama.com/) - ## Works With - [![Claude Code](https://img.shields.io/badge/Claude_Code-191919?logo=anthropic&logoColor=white)](https://www.anthropic.com/claude-code) [![OpenAI Codex](https://img.shields.io/badge/OpenAI_Codex-412991?logo=openai&logoColor=white)](https://developers.openai.com/codex) [![Cursor](https://img.shields.io/badge/Cursor-111111?logo=cursor&logoColor=white)](https://cursor.com/) [![GitHub Copilot](https://img.shields.io/badge/GitHub_Copilot-181717?logo=githubcopilot&logoColor=white)](https://github.com/features/copilot) [![Gemini CLI](https://img.shields.io/badge/Gemini_CLI-4285F4?logo=google&logoColor=white)](https://github.com/google-gemini/gemini-cli) [![OpenCode](https://img.shields.io/badge/OpenCode-000000?logo=opencode&logoColor=white)](https://opencode.ai/) - - Brand names and logos are trademarks of their respective owners; usage here indicates compatibility, not endorsement. - ## Focused Workflow @@ -61,9 +55,9 @@ taskwing goal "Add Stripe billing" ## Core Commands - - `taskwing bootstrap` - `taskwing goal ""` +- `taskwing ask ""` - `taskwing task` - `taskwing plan status` - `taskwing slash` @@ -76,16 +70,14 @@ taskwing goal "Add Stripe billing" ## MCP Tools - -| Tool | Description | -| ---------- | ----------------------------------------------------------------------------------- | -| `recall` | Retrieve project knowledge (decisions, patterns, constraints) | -| `task` | Unified task lifecycle (`next`, `current`, `start`, `complete`) | -| `plan` | Plan management (`clarify`, `decompose`, `expand`, `generate`, `finalize`, `audit`) | -| `code` | Code intelligence (`find`, `search`, `explain`, `callers`, `impact`, `simplify`) | -| `debug` | Diagnose issues systematically with AI-powered analysis | -| `remember` | Store knowledge in project memory | - +| Tool | Description | +|------|-------------| +| `ask` | Search project knowledge (decisions, patterns, constraints) | +| `task` | Unified task lifecycle (`next`, `current`, `start`, `complete`) | +| `plan` | Plan management (`clarify`, `decompose`, `expand`, `generate`, `finalize`, `audit`) | +| `code` | Code intelligence (`find`, `search`, `explain`, `callers`, `impact`, `simplify`) | +| `debug` | Diagnose issues systematically with AI-powered analysis | +| `remember` | Store knowledge in project memory | ## AWS Bedrock (OpenAI-Compatible) Setup diff --git a/cmd/ask.go b/cmd/ask.go new file mode 100644 index 0000000..7d95646 --- /dev/null +++ b/cmd/ask.go @@ -0,0 +1,104 @@ +/* +Copyright © 2025 Joseph Goksu josephgoksu@gmail.com +*/ +package cmd + +import ( + "fmt" + "os" + + "github.com/josephgoksu/TaskWing/internal/app" + "github.com/josephgoksu/TaskWing/internal/llm" + "github.com/josephgoksu/TaskWing/internal/ui" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +var askCmd = &cobra.Command{ + Use: "ask ", + Short: "Search project knowledge and code symbols", + SilenceUsage: true, + Long: `Query the project knowledge base from the CLI. + +Searches architectural knowledge (decisions, patterns, constraints) and +code symbols (functions, types, interfaces) using the same pipeline as +the MCP ask tool. + +By default, uses hybrid search (FTS + vector). Use --fts-only to skip +embedding API calls for faster, offline results. + +Examples: + taskwing ask "how does authentication work" + taskwing ask "SQLite schema design" --limit 10 + taskwing ask "how does the MCP server work" --answer + taskwing ask "task state machine" --json + taskwing ask "API endpoints" --fts-only + taskwing ask "auth" --workspace=osprey`, + Args: cobra.ExactArgs(1), + RunE: runAsk, +} + +func init() { + rootCmd.AddCommand(askCmd) + askCmd.Flags().BoolP("answer", "a", false, "Generate a RAG answer (uses LLM, slower)") + askCmd.Flags().StringP("workspace", "w", "", "Filter by workspace (monorepo)") + askCmd.Flags().IntP("limit", "l", 5, "Max knowledge results") + askCmd.Flags().Bool("no-symbols", false, "Skip code symbol search") + askCmd.Flags().Bool("fts-only", false, "Disable vector search (faster, no embedding API call)") +} + +func runAsk(cmd *cobra.Command, args []string) error { + query := args[0] + + repo, err := openRepoOrHandleMissingMemory() + if err != nil { + return err + } + if repo == nil { + return nil + } + defer func() { _ = repo.Close() }() + + cfg, err := getLLMConfigForRole(cmd, llm.RoleQuery) + if err != nil { + return fmt.Errorf("llm config: %w", err) + } + + askApp := app.NewAskApp(app.NewContextWithConfig(repo, cfg)) + + // Build options from flags + limit, _ := cmd.Flags().GetInt("limit") + noSymbols, _ := cmd.Flags().GetBool("no-symbols") + ftsOnly, _ := cmd.Flags().GetBool("fts-only") + generateAnswer, _ := cmd.Flags().GetBool("answer") + workspace, _ := cmd.Flags().GetString("workspace") + + if workspace != "" { + if err := app.ValidateWorkspace(workspace); err != nil { + return err + } + } + + opts := app.DefaultAskOptions() + opts.Limit = limit + opts.IncludeSymbols = !noSymbols + opts.DisableVector = ftsOnly + opts.GenerateAnswer = generateAnswer + opts.Workspace = workspace + + if generateAnswer { + opts.StreamWriter = os.Stdout + } + + result, err := askApp.Query(cmd.Context(), query, opts) + if err != nil { + return fmt.Errorf("query failed: %w", err) + } + + if isJSON() { + return printJSON(result) + } + + ui.RenderAskResult(result, viper.GetBool("verbose")) + return nil +} diff --git a/cmd/bootstrap.go b/cmd/bootstrap.go index fd6a7ca..f627bd3 100644 --- a/cmd/bootstrap.go +++ b/cmd/bootstrap.go @@ -226,7 +226,16 @@ func runBootstrap(cmd *cobra.Command, args []string) error { func executeAction(ctx context.Context, action bootstrap.Action, svc *bootstrap.Service, cwd string, flags bootstrap.Flags, plan *bootstrap.Plan, llmCfg llm.Config) error { switch action { case bootstrap.ActionInitProject: - return executeInitProject(svc, flags, plan) + if err := executeInitProject(svc, flags, plan); err != nil { + return err + } + // Re-detect project context now that local .taskwing/ exists. + // Without this, the cached context still points to ~/.taskwing/ (HOME) + // and all subsequent DB operations write to the wrong database. + if freshCtx, err := project.Detect(cwd); err == nil { + _ = config.SetProjectContext(freshCtx) + } + return nil case bootstrap.ActionGenerateAIConfigs: return executeGenerateAIConfigs(svc, flags, plan) @@ -764,7 +773,7 @@ func checkAgentFailures(agents []*ui.AgentState) error { } // runCodeIndexing runs the code intelligence indexer on the codebase. -// This extracts symbols (functions, types, etc.) for enhanced search and MCP recall. +// This extracts symbols (functions, types, etc.) for enhanced search and MCP ask. func runCodeIndexing(ctx context.Context, basePath string, forceIndex, isQuiet bool) error { // Open repository to get database handle repo, err := openRepo() diff --git a/cmd/bootstrap_test.go b/cmd/bootstrap_test.go deleted file mode 100644 index b3218af..0000000 --- a/cmd/bootstrap_test.go +++ /dev/null @@ -1,112 +0,0 @@ -/* -Copyright © 2025 Joseph Goksu josephgoksu@gmail.com -*/ -package cmd - -import ( - "encoding/json" - "os" - "path/filepath" - "testing" - - "github.com/josephgoksu/TaskWing/internal/bootstrap" -) - -// TestInstallMCPServers_OpenCode tests that installMCPServers correctly installs OpenCode MCP config. -func TestInstallMCPServers_OpenCode(t *testing.T) { - tmpDir := t.TempDir() - - // Mock binPath - in tests we can use any path - binPath := "/usr/local/bin/taskwing" - - // Call installMCPServers with opencode - installMCPServers(tmpDir, []string{"opencode"}) - - // Verify opencode.json was created - configPath := filepath.Join(tmpDir, "opencode.json") - content, err := os.ReadFile(configPath) - if err != nil { - t.Fatalf("Failed to read opencode.json: %v", err) - } - - // Parse and verify structure - var config OpenCodeConfig - if err := json.Unmarshal(content, &config); err != nil { - t.Fatalf("Invalid JSON in opencode.json: %v", err) - } - - // Verify schema - if config.Schema != "https://opencode.ai/config.json" { - t.Errorf("Schema = %q, want %q", config.Schema, "https://opencode.ai/config.json") - } - - // Verify MCP section exists - if config.MCP == nil { - t.Fatal("MCP section is nil") - } - - // Server name must be canonical (strict naming policy) - expectedServerName := "taskwing-mcp" - serverCfg, ok := config.MCP[expectedServerName] - if !ok { - t.Fatalf("Canonical taskwing-mcp server entry missing in MCP section. Got: %v", config.MCP) - } - - // Verify type is "local" - if serverCfg.Type != "local" { - t.Errorf("Type = %q, want %q", serverCfg.Type, "local") - } - - // Verify command is array format - if len(serverCfg.Command) != 2 { - t.Fatalf("Command length = %d, want 2", len(serverCfg.Command)) - } - // Command[0] will use the actual executable path, not our mock binPath - // Just verify the second element is "mcp" - if serverCfg.Command[1] != "mcp" { - t.Errorf("Command[1] = %q, want %q", serverCfg.Command[1], "mcp") - } - - _ = binPath // suppress unused variable warning -} - -// TestInstallMCPServers_AllIncludesOpenCode tests that "all" AIs doesn't break when including opencode. -func TestInstallMCPServers_AllIncludesOpenCode(t *testing.T) { - tmpDir := t.TempDir() - - // Install multiple AIs including opencode - installMCPServers(tmpDir, []string{"claude", "opencode"}) - - // Verify opencode.json was created - configPath := filepath.Join(tmpDir, "opencode.json") - if _, err := os.Stat(configPath); os.IsNotExist(err) { - t.Error("opencode.json was not created when installing multiple AIs including opencode") - } -} - -// TestAIConfigOrder_IncludesOpenCode verifies opencode is in the AI selection list. -func TestAIConfigOrder_IncludesOpenCode(t *testing.T) { - found := false - for _, ai := range aiConfigOrder { - if ai == "opencode" { - found = true - break - } - } - if !found { - t.Error("opencode is not in aiConfigOrder") - } -} - -// TestAIDisplayNames_OpenCodeEntry verifies opencode metadata is exposed by canonical catalog helpers. -func TestAIDisplayNames_OpenCodeEntry(t *testing.T) { - displayNames := bootstrap.AIDisplayNames() - displayName, ok := displayNames["opencode"] - if !ok { - t.Fatal("opencode entry not found in canonical AI display names") - } - - if displayName != "OpenCode" { - t.Errorf("displayName = %q, want %q", displayName, "OpenCode") - } -} diff --git a/cmd/config_cmd.go b/cmd/config_cmd.go index f4c896c..44158d4 100644 --- a/cmd/config_cmd.go +++ b/cmd/config_cmd.go @@ -593,7 +593,7 @@ func configureBootstrapModel() error { func configureQueryModel() error { fmt.Println("\n⚡ Configure Fast Queries Model") - fmt.Println(" Used for: context lookups, recall (cheaper, faster)") + fmt.Println(" Used for: context lookups, ask queries (cheaper, faster)") fmt.Println() selection, err := ui.PromptLLMSelection() diff --git a/cmd/config_helper.go b/cmd/config_helper.go index 0ded180..c9a636e 100644 --- a/cmd/config_helper.go +++ b/cmd/config_helper.go @@ -243,7 +243,7 @@ func promptBedrockRegion() (string, error) { // // Role-specific config keys: // - llm.models.bootstrap: "provider:model" for bootstrap/planning tasks -// - llm.models.query: "provider:model" for context/recall queries +// - llm.models.query: "provider:model" for context/ask queries // // If no role-specific config is set, falls back to getLLMConfig(). func getLLMConfigForRole(cmd *cobra.Command, role llm.ModelRole) (llm.Config, error) { diff --git a/cmd/doctor_test.go b/cmd/doctor_test.go deleted file mode 100644 index d4d93bc..0000000 --- a/cmd/doctor_test.go +++ /dev/null @@ -1,130 +0,0 @@ -package cmd - -import ( - "strings" - "testing" - - "github.com/josephgoksu/TaskWing/internal/bootstrap" -) - -func TestChecksFromIntegrationReports_Healthy(t *testing.T) { - reports := map[string]bootstrap.IntegrationReport{ - "opencode": { - AI: "opencode", - Issues: nil, - }, - } - - checks := checksFromIntegrationReports(reports) - if len(checks) != 1 { - t.Fatalf("expected 1 check, got %d", len(checks)) - } - if checks[0].Name != "Integration (opencode)" { - t.Fatalf("unexpected check name: %s", checks[0].Name) - } - if checks[0].Status != "ok" { - t.Fatalf("expected ok status, got %s", checks[0].Status) - } -} - -func TestChecksFromIntegrationReports_HintSelection(t *testing.T) { - reports := map[string]bootstrap.IntegrationReport{ - "claude": { - AI: "claude", - Issues: []bootstrap.IntegrationIssue{ - { - AI: "claude", - Component: bootstrap.AIComponentMCPGlobal, - Status: bootstrap.ComponentStatusMissing, - Reason: "global taskwing-mcp registration missing", - MutatesGlobal: true, - }, - }, - }, - "codex": { - AI: "codex", - Issues: []bootstrap.IntegrationIssue{ - { - AI: "codex", - Component: bootstrap.AIComponentCommands, - Status: bootstrap.ComponentStatusInvalid, - Reason: "commands invalid", - AdoptRequired: true, - }, - }, - }, - } - - checks := checksFromIntegrationReports(reports) - if len(checks) != 2 { - t.Fatalf("expected 2 checks, got %d", len(checks)) - } - - var sawGlobalHint bool - var sawAdoptHint bool - for _, check := range checks { - if strings.Contains(check.Hint, "--yes --ai claude") { - sawGlobalHint = true - } - if strings.Contains(check.Hint, "--adopt-unmanaged --ai codex") { - sawAdoptHint = true - } - } - if !sawGlobalHint { - t.Fatalf("expected global mutation hint in checks: %+v", checks) - } - if !sawAdoptHint { - t.Fatalf("expected adopt-unmanaged hint in checks: %+v", checks) - } -} - -func TestDoctor_OpenCodeFixReevaluatesState(t *testing.T) { - tmpDir := t.TempDir() - init := bootstrap.NewInitializer(tmpDir) - - if err := init.CreateSlashCommands("opencode", false); err != nil { - t.Fatalf("create opencode commands: %v", err) - } - if err := init.InstallHooksConfig("opencode", false); err != nil { - t.Fatalf("create opencode plugin: %v", err) - } - - reportsBefore := bootstrap.EvaluateIntegrations(tmpDir, map[string]bool{}) - if !hasIntegrationIssue(reportsBefore["opencode"], bootstrap.AIComponentMCPLocal) { - t.Fatalf("expected opencode mcp_local issue before repair, got %+v", reportsBefore["opencode"].Issues) - } - - plan := bootstrap.BuildRepairPlan(reportsBefore, bootstrap.RepairPlanOptions{ - TargetAIs: []string{"opencode"}, - IncludeGlobalMutations: true, - }) - if len(plan.Actions) == 0 { - t.Fatal("expected non-empty repair plan for opencode drift") - } - - applied, skipped, blocked, err := applyRepairPlan(tmpDir, plan, doctorFixOptions{ - Fix: true, - Yes: true, - TargetAIs: []string{"opencode"}, - }) - if err != nil { - t.Fatalf("apply repair plan: %v", err) - } - if len(applied) == 0 { - t.Fatalf("expected at least one applied action; skipped=%d blocked=%d", len(skipped), len(blocked)) - } - - reportsAfter := bootstrap.EvaluateIntegrations(tmpDir, map[string]bool{}) - if hasIntegrationIssue(reportsAfter["opencode"], bootstrap.AIComponentMCPLocal) { - t.Fatalf("expected opencode mcp_local issue to be resolved, got %+v", reportsAfter["opencode"].Issues) - } -} - -func hasIntegrationIssue(report bootstrap.IntegrationReport, component bootstrap.AIComponent) bool { - for _, issue := range report.Issues { - if issue.Component == component { - return true - } - } - return false -} diff --git a/cmd/hook.go b/cmd/hook.go index af8f856..d1c839a 100644 --- a/cmd/hook.go +++ b/cmd/hook.go @@ -341,21 +341,21 @@ func runContinueCheck(maxTasks, maxMinutes int) error { func buildTaskContext(repo *memory.Repository, nextTask *task.Task, plan *task.Plan) string { ctx := context.Background() - // Get knowledge service for recall context + // Get knowledge service for ask context llmCfg, _ := getLLMConfigFromViper() ks := knowledge.NewService(repo, llmCfg) // Create search adapter that wraps knowledge.Service for the task package - searchFn := func(ctx context.Context, query string, limit int) ([]task.RecallResult, error) { + searchFn := func(ctx context.Context, query string, limit int) ([]task.AskResult, error) { searchCtx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() results, err := ks.Search(searchCtx, query, limit) if err != nil { return nil, err } - var adapted []task.RecallResult + var adapted []task.AskResult for _, r := range results { - adapted = append(adapted, task.RecallResult{ + adapted = append(adapted, task.AskResult{ Summary: r.Node.Summary, Type: r.Node.Type, Content: r.Node.Text(), diff --git a/cmd/hook_test.go b/cmd/hook_test.go deleted file mode 100644 index 0062a10..0000000 --- a/cmd/hook_test.go +++ /dev/null @@ -1,354 +0,0 @@ -/* -Copyright © 2025 Joseph Goksu josephgoksu@gmail.com -*/ -package cmd - -import ( - "encoding/json" - "os" - "path/filepath" - "testing" - "time" - - "github.com/josephgoksu/TaskWing/internal/config" -) - -func TestHookSessionJSON(t *testing.T) { - // Test JSON serialization of HookSession with all fields - session := &HookSession{ - SessionID: "test-session-123", - StartedAt: time.Now().UTC(), - TasksCompleted: 3, - TasksStarted: 4, - CurrentTaskID: "task-456", - PlanID: "plan-789", - LastTaskHadCriticalDeviation: true, - LastDeviationSummary: "Modified protected file", - TotalDeviationsDetected: 2, - LastTaskHadPolicyViolation: true, - LastPolicyViolations: []string{"Cannot modify .env files", "Secrets directory protected"}, - TotalPolicyViolations: 2, - } - - // Marshal to JSON - data, err := json.Marshal(session) - if err != nil { - t.Fatalf("Failed to marshal session: %v", err) - } - - // Unmarshal back - var decoded HookSession - if err := json.Unmarshal(data, &decoded); err != nil { - t.Fatalf("Failed to unmarshal session: %v", err) - } - - // Verify fields - if decoded.SessionID != session.SessionID { - t.Errorf("SessionID: got %s, want %s", decoded.SessionID, session.SessionID) - } - if decoded.TasksCompleted != session.TasksCompleted { - t.Errorf("TasksCompleted: got %d, want %d", decoded.TasksCompleted, session.TasksCompleted) - } - if decoded.LastTaskHadPolicyViolation != session.LastTaskHadPolicyViolation { - t.Errorf("LastTaskHadPolicyViolation: got %v, want %v", decoded.LastTaskHadPolicyViolation, session.LastTaskHadPolicyViolation) - } - if len(decoded.LastPolicyViolations) != len(session.LastPolicyViolations) { - t.Errorf("LastPolicyViolations length: got %d, want %d", len(decoded.LastPolicyViolations), len(session.LastPolicyViolations)) - } -} - -func TestHookResponseJSON(t *testing.T) { - tests := []struct { - name string - response HookResponse - wantKey string - wantNoKey string - }{ - { - name: "allow stop (no decision)", - response: HookResponse{Reason: "Plan complete"}, - wantKey: "reason", - wantNoKey: "decision", - }, - { - name: "block stop", - response: func() HookResponse { - block := "block" - return HookResponse{Decision: &block, Reason: "Continue to next task"} - }(), - wantKey: "decision", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - data, err := json.Marshal(tt.response) - if err != nil { - t.Fatalf("Failed to marshal response: %v", err) - } - - var decoded map[string]any - if err := json.Unmarshal(data, &decoded); err != nil { - t.Fatalf("Failed to unmarshal response: %v", err) - } - - if _, ok := decoded[tt.wantKey]; !ok { - t.Errorf("Expected key %q in response, got: %s", tt.wantKey, string(data)) - } - - if tt.wantNoKey != "" { - if _, ok := decoded[tt.wantNoKey]; ok { - t.Errorf("Did not expect key %q in response, got: %s", tt.wantNoKey, string(data)) - } - } - }) - } -} - -func TestHookSessionPersistence(t *testing.T) { - // Create temp directory for test - tmpDir := t.TempDir() - sessionPath := filepath.Join(tmpDir, "hook_session.json") - - // Create test session - session := &HookSession{ - SessionID: "test-session-persist", - StartedAt: time.Now().UTC(), - TasksCompleted: 2, - TasksStarted: 3, - LastTaskHadPolicyViolation: true, - LastPolicyViolations: []string{"Policy violation 1"}, - TotalPolicyViolations: 1, - } - - // Save session manually - data, err := json.MarshalIndent(session, "", " ") - if err != nil { - t.Fatalf("Failed to marshal session: %v", err) - } - - if err := os.WriteFile(sessionPath, data, 0644); err != nil { - t.Fatalf("Failed to write session file: %v", err) - } - - // Read and verify - readData, err := os.ReadFile(sessionPath) - if err != nil { - t.Fatalf("Failed to read session file: %v", err) - } - - var loaded HookSession - if err := json.Unmarshal(readData, &loaded); err != nil { - t.Fatalf("Failed to unmarshal loaded session: %v", err) - } - - if loaded.SessionID != session.SessionID { - t.Errorf("SessionID mismatch: got %s, want %s", loaded.SessionID, session.SessionID) - } - if !loaded.LastTaskHadPolicyViolation { - t.Error("LastTaskHadPolicyViolation should be true") - } - if len(loaded.LastPolicyViolations) != 1 { - t.Errorf("LastPolicyViolations length: got %d, want 1", len(loaded.LastPolicyViolations)) - } -} - -func TestHookSessionDefaults(t *testing.T) { - if DefaultMaxTasksPerSession != 5 { - t.Errorf("DefaultMaxTasksPerSession: got %d, want 5", DefaultMaxTasksPerSession) - } - - if DefaultMaxSessionMinutes != 30 { - t.Errorf("DefaultMaxSessionMinutes: got %d, want 30", DefaultMaxSessionMinutes) - } -} - -func TestHookResponseBlock(t *testing.T) { - block := "block" - resp := HookResponse{ - Decision: &block, - Reason: "Continue to task 2/5: Add authentication", - } - - data, err := json.Marshal(resp) - if err != nil { - t.Fatalf("Failed to marshal: %v", err) - } - - var decoded map[string]any - if err := json.Unmarshal(data, &decoded); err != nil { - t.Fatalf("Failed to unmarshal: %v", err) - } - - if decoded["decision"] != "block" { - t.Errorf("Decision should be 'block', got: %v", decoded["decision"]) - } -} - -func TestHookResponseAllow(t *testing.T) { - resp := HookResponse{ - Reason: "Circuit breaker: Max tasks reached", - } - - data, err := json.Marshal(resp) - if err != nil { - t.Fatalf("Failed to marshal: %v", err) - } - - var decoded map[string]any - if err := json.Unmarshal(data, &decoded); err != nil { - t.Fatalf("Failed to unmarshal: %v", err) - } - - // When allowing stop, decision should be omitted - if _, exists := decoded["decision"]; exists { - t.Error("Decision should be omitted when allowing stop") - } -} - -func TestPolicyCircuitBreakerTracking(t *testing.T) { - // Test that policy violations are properly tracked in session - session := &HookSession{ - SessionID: "test-policy-circuit", - LastTaskHadPolicyViolation: true, - LastPolicyViolations: []string{"Cannot modify .env files", "Secrets directory is protected"}, - TotalPolicyViolations: 2, - } - - if !session.LastTaskHadPolicyViolation { - t.Error("Policy violation flag should be set") - } - - if len(session.LastPolicyViolations) != 2 { - t.Errorf("Expected 2 violations, got %d", len(session.LastPolicyViolations)) - } - - if session.TotalPolicyViolations != 2 { - t.Errorf("TotalPolicyViolations: got %d, want 2", session.TotalPolicyViolations) - } - - // Serialize and verify - data, _ := json.Marshal(session) - var decoded HookSession - _ = json.Unmarshal(data, &decoded) - - if !decoded.LastTaskHadPolicyViolation { - t.Error("Policy violation flag should persist through JSON serialization") - } -} - -func TestDeviationCircuitBreakerTracking(t *testing.T) { - // Test that deviations are properly tracked in session - session := &HookSession{ - SessionID: "test-deviation-circuit", - LastTaskHadCriticalDeviation: true, - LastDeviationSummary: "2 unexpected files, 1 missing file (requires review)", - TotalDeviationsDetected: 3, - } - - if !session.LastTaskHadCriticalDeviation { - t.Error("Critical deviation flag should be set") - } - - if session.LastDeviationSummary == "" { - t.Error("Deviation summary should not be empty") - } - - if session.TotalDeviationsDetected != 3 { - t.Errorf("TotalDeviationsDetected: got %d, want 3", session.TotalDeviationsDetected) - } -} - -func TestHookSessionAllFieldsSerialization(t *testing.T) { - // Verify all session fields survive JSON round-trip - original := &HookSession{ - SessionID: "full-test-session", - StartedAt: time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC), - TasksCompleted: 5, - TasksStarted: 6, - CurrentTaskID: "task-current", - PlanID: "plan-active", - LastTaskHadCriticalDeviation: true, - LastDeviationSummary: "deviation summary", - TotalDeviationsDetected: 10, - LastTaskHadPolicyViolation: true, - LastPolicyViolations: []string{"v1", "v2", "v3"}, - TotalPolicyViolations: 3, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("Marshal failed: %v", err) - } - - var decoded HookSession - if err := json.Unmarshal(data, &decoded); err != nil { - t.Fatalf("Unmarshal failed: %v", err) - } - - // Check all fields - checks := []struct { - name string - got, want any - }{ - {"SessionID", decoded.SessionID, original.SessionID}, - {"TasksCompleted", decoded.TasksCompleted, original.TasksCompleted}, - {"TasksStarted", decoded.TasksStarted, original.TasksStarted}, - {"CurrentTaskID", decoded.CurrentTaskID, original.CurrentTaskID}, - {"PlanID", decoded.PlanID, original.PlanID}, - {"LastTaskHadCriticalDeviation", decoded.LastTaskHadCriticalDeviation, original.LastTaskHadCriticalDeviation}, - {"LastDeviationSummary", decoded.LastDeviationSummary, original.LastDeviationSummary}, - {"TotalDeviationsDetected", decoded.TotalDeviationsDetected, original.TotalDeviationsDetected}, - {"LastTaskHadPolicyViolation", decoded.LastTaskHadPolicyViolation, original.LastTaskHadPolicyViolation}, - {"TotalPolicyViolations", decoded.TotalPolicyViolations, original.TotalPolicyViolations}, - {"LastPolicyViolations length", len(decoded.LastPolicyViolations), len(original.LastPolicyViolations)}, - } - - for _, c := range checks { - if c.got != c.want { - t.Errorf("%s: got %v, want %v", c.name, c.got, c.want) - } - } -} - -func TestResolveHookMemoryPath_UsesClaudeProjectDir(t *testing.T) { - config.ClearProjectContext() - - tmpDir := t.TempDir() - if err := os.MkdirAll(filepath.Join(tmpDir, ".taskwing", "memory"), 0755); err != nil { - t.Fatalf("mkdir project memory: %v", err) - } - t.Setenv("CLAUDE_PROJECT_DIR", tmpDir) - - path, err := resolveHookMemoryPath() - if err != nil { - t.Fatalf("resolveHookMemoryPath failed: %v", err) - } - - want := filepath.Join(tmpDir, ".taskwing", "memory") - if path != want { - t.Fatalf("memory path = %q, want %q", path, want) - } -} - -func TestResolveHookMemoryPath_FallsBackToGlobal(t *testing.T) { - config.ClearProjectContext() - t.Setenv("CLAUDE_PROJECT_DIR", "") - - globalDir := t.TempDir() - origGlobalDir := config.GetGlobalConfigDir - config.GetGlobalConfigDir = func() (string, error) { return globalDir, nil } - t.Cleanup(func() { - config.GetGlobalConfigDir = origGlobalDir - }) - - path, err := resolveHookMemoryPath() - if err != nil { - t.Fatalf("resolveHookMemoryPath failed: %v", err) - } - - want := filepath.Join(globalDir, "memory") - if path != want { - t.Fatalf("memory path = %q, want %q", path, want) - } -} diff --git a/cmd/mcp_install_test.go b/cmd/mcp_install_test.go deleted file mode 100644 index 1ac8636..0000000 --- a/cmd/mcp_install_test.go +++ /dev/null @@ -1,372 +0,0 @@ -/* -Copyright © 2025 Joseph Goksu josephgoksu@gmail.com -*/ -package cmd - -import ( - "encoding/json" - "os" - "path/filepath" - "testing" -) - -// ============================================================================= -// OpenCode MCP Install Tests -// ============================================================================= - -// TestInstallOpenCode_Success tests successful creation of opencode.json -func TestInstallOpenCode_Success(t *testing.T) { - tmpDir := t.TempDir() - binPath := "/usr/local/bin/taskwing" - - err := installOpenCode(binPath, tmpDir) - if err != nil { - t.Fatalf("installOpenCode failed: %v", err) - } - - // Verify file was created at project root - configPath := filepath.Join(tmpDir, "opencode.json") - content, err := os.ReadFile(configPath) - if err != nil { - t.Fatalf("Failed to read opencode.json: %v", err) - } - - // Parse and verify JSON structure - var config OpenCodeConfig - if err := json.Unmarshal(content, &config); err != nil { - t.Fatalf("Invalid JSON in opencode.json: %v", err) - } - - // Verify schema - if config.Schema != "https://opencode.ai/config.json" { - t.Errorf("Schema = %q, want %q", config.Schema, "https://opencode.ai/config.json") - } - - // Verify MCP section exists - if config.MCP == nil { - t.Fatal("MCP section is nil") - } - - // Verify taskwing-mcp entry - serverCfg, ok := config.MCP["taskwing-mcp"] - if !ok { - t.Fatal("taskwing-mcp entry not found in MCP section") - } - - // Verify type is "local" - if serverCfg.Type != "local" { - t.Errorf("Type = %q, want %q", serverCfg.Type, "local") - } - - // Verify command is array format - if len(serverCfg.Command) != 2 { - t.Fatalf("Command length = %d, want 2", len(serverCfg.Command)) - } - if serverCfg.Command[0] != binPath { - t.Errorf("Command[0] = %q, want %q", serverCfg.Command[0], binPath) - } - if serverCfg.Command[1] != "mcp" { - t.Errorf("Command[1] = %q, want %q", serverCfg.Command[1], "mcp") - } - - // Verify timeout is set - if serverCfg.Timeout != 5000 { - t.Errorf("Timeout = %d, want %d", serverCfg.Timeout, 5000) - } -} - -// TestInstallOpenCode_CommandIsArray tests that command is JSON array, not string -func TestInstallOpenCode_CommandIsArray(t *testing.T) { - tmpDir := t.TempDir() - binPath := "/path/to/taskwing" - - err := installOpenCode(binPath, tmpDir) - if err != nil { - t.Fatalf("installOpenCode failed: %v", err) - } - - // Read raw JSON to verify array format - content, err := os.ReadFile(filepath.Join(tmpDir, "opencode.json")) - if err != nil { - t.Fatalf("Failed to read file: %v", err) - } - - // Check raw JSON contains array syntax for command - if !containsSubstr(string(content), `"command": [`) { - t.Error("command must be JSON array format (not string)") - } -} - -// TestInstallOpenCode_PreservesExistingConfig tests that existing config is preserved -func TestInstallOpenCode_PreservesExistingConfig(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "opencode.json") - - // Create existing config with another MCP server - existingConfig := OpenCodeConfig{ - Schema: "https://opencode.ai/config.json", - MCP: map[string]OpenCodeMCPServerConfig{ - "other-mcp": { - Type: "local", - Command: []string{"other", "command"}, - }, - }, - } - existingBytes, _ := json.MarshalIndent(existingConfig, "", " ") - if err := os.WriteFile(configPath, existingBytes, 0644); err != nil { - t.Fatalf("Failed to write existing config: %v", err) - } - - // Install TaskWing - err := installOpenCode("/usr/local/bin/taskwing", tmpDir) - if err != nil { - t.Fatalf("installOpenCode failed: %v", err) - } - - // Read back and verify both servers exist - content, err := os.ReadFile(configPath) - if err != nil { - t.Fatalf("Failed to read config: %v", err) - } - - var config OpenCodeConfig - if err := json.Unmarshal(content, &config); err != nil { - t.Fatalf("Invalid JSON: %v", err) - } - - // Verify existing server preserved - if _, ok := config.MCP["other-mcp"]; !ok { - t.Error("Existing 'other-mcp' server was removed") - } - - // Verify new server added - if _, ok := config.MCP["taskwing-mcp"]; !ok { - t.Error("'taskwing-mcp' server was not added") - } -} - -// TestInstallOpenCode_Idempotent tests that running twice doesn't duplicate -func TestInstallOpenCode_Idempotent(t *testing.T) { - tmpDir := t.TempDir() - binPath := "/usr/local/bin/taskwing" - - // Install twice - if err := installOpenCode(binPath, tmpDir); err != nil { - t.Fatalf("First install failed: %v", err) - } - if err := installOpenCode(binPath, tmpDir); err != nil { - t.Fatalf("Second install failed: %v", err) - } - - // Read and verify only one taskwing-mcp entry - content, err := os.ReadFile(filepath.Join(tmpDir, "opencode.json")) - if err != nil { - t.Fatalf("Failed to read config: %v", err) - } - - var config OpenCodeConfig - if err := json.Unmarshal(content, &config); err != nil { - t.Fatalf("Invalid JSON: %v", err) - } - - // Should have exactly one entry - if len(config.MCP) != 1 { - t.Errorf("Expected 1 MCP entry, got %d", len(config.MCP)) - } -} - -// TestInstallOpenCode_NoSecrets tests that no secrets are written -func TestInstallOpenCode_NoSecrets(t *testing.T) { - tmpDir := t.TempDir() - binPath := "/usr/local/bin/taskwing" - - err := installOpenCode(binPath, tmpDir) - if err != nil { - t.Fatalf("installOpenCode failed: %v", err) - } - - content, err := os.ReadFile(filepath.Join(tmpDir, "opencode.json")) - if err != nil { - t.Fatalf("Failed to read config: %v", err) - } - - contentStr := string(content) - - // Check no common secret patterns - secretPatterns := []string{ - "password", - "secret", - "api_key", - "apikey", - "API_KEY", - "token", - "credential", - } - - for _, pattern := range secretPatterns { - if containsSubstr(contentStr, pattern) { - t.Errorf("Config contains potential secret pattern: %s", pattern) - } - } - - // Verify no .env file was created - envPath := filepath.Join(tmpDir, ".env") - if _, err := os.Stat(envPath); !os.IsNotExist(err) { - t.Error(".env file should NOT be created") - } -} - -// ============================================================================= -// upsertOpenCodeMCPServer Validation Tests -// ============================================================================= - -// TestUpsertOpenCodeMCPServer_EmptyConfigPath tests validation of empty configPath -func TestUpsertOpenCodeMCPServer_EmptyConfigPath(t *testing.T) { - err := upsertOpenCodeMCPServer("", "taskwing-mcp", OpenCodeMCPServerConfig{ - Type: "local", - Command: []string{"taskwing", "mcp"}, - }) - if err == nil { - t.Error("Expected error for empty configPath, got nil") - } -} - -// TestUpsertOpenCodeMCPServer_EmptyServerName tests validation of empty serverName -func TestUpsertOpenCodeMCPServer_EmptyServerName(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "opencode.json") - - err := upsertOpenCodeMCPServer(configPath, "", OpenCodeMCPServerConfig{ - Type: "local", - Command: []string{"taskwing", "mcp"}, - }) - if err == nil { - t.Error("Expected error for empty serverName, got nil") - } -} - -// TestUpsertOpenCodeMCPServer_EmptyCommand tests validation of empty command -func TestUpsertOpenCodeMCPServer_EmptyCommand(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "opencode.json") - - err := upsertOpenCodeMCPServer(configPath, "taskwing-mcp", OpenCodeMCPServerConfig{ - Type: "local", - Command: []string{}, // Empty command array - }) - if err == nil { - t.Error("Expected error for empty command array, got nil") - } -} - -// TestUpsertOpenCodeMCPServer_InvalidType tests validation of invalid type -func TestUpsertOpenCodeMCPServer_InvalidType(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "opencode.json") - - err := upsertOpenCodeMCPServer(configPath, "taskwing-mcp", OpenCodeMCPServerConfig{ - Type: "remote", // Invalid - must be "local" - Command: []string{"taskwing", "mcp"}, - }) - if err == nil { - t.Error("Expected error for invalid type, got nil") - } -} - -// TestUpsertOpenCodeMCPServer_MalformedExistingJSON tests handling of malformed JSON -func TestUpsertOpenCodeMCPServer_MalformedExistingJSON(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "opencode.json") - - // Write malformed JSON - if err := os.WriteFile(configPath, []byte("not valid json{"), 0644); err != nil { - t.Fatalf("Failed to write malformed JSON: %v", err) - } - - // Should succeed by creating fresh config - err := upsertOpenCodeMCPServer(configPath, "taskwing-mcp", OpenCodeMCPServerConfig{ - Type: "local", - Command: []string{"taskwing", "mcp"}, - }) - if err != nil { - t.Fatalf("Should handle malformed JSON gracefully: %v", err) - } - - // Verify valid config was written - content, err := os.ReadFile(configPath) - if err != nil { - t.Fatalf("Failed to read config: %v", err) - } - - var config OpenCodeConfig - if err := json.Unmarshal(content, &config); err != nil { - t.Fatalf("Config should be valid JSON now: %v", err) - } -} - -func TestUpsertOpenCodeMCPServer_RemovesLegacyTaskWingKeys(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "opencode.json") - - existing := OpenCodeConfig{ - Schema: "https://opencode.ai/config.json", - MCP: map[string]OpenCodeMCPServerConfig{ - "taskwing-mcp-my-project": { - Type: "local", - Command: []string{"taskwing", "mcp"}, - }, - "other-mcp": { - Type: "local", - Command: []string{"other", "mcp"}, - }, - }, - } - data, err := json.MarshalIndent(existing, "", " ") - if err != nil { - t.Fatalf("marshal existing config: %v", err) - } - if err := os.WriteFile(configPath, data, 0644); err != nil { - t.Fatalf("write existing config: %v", err) - } - - if err := upsertOpenCodeMCPServer(configPath, "taskwing-mcp", OpenCodeMCPServerConfig{ - Type: "local", - Command: []string{"taskwing", "mcp"}, - }); err != nil { - t.Fatalf("upsertOpenCodeMCPServer failed: %v", err) - } - - content, err := os.ReadFile(configPath) - if err != nil { - t.Fatalf("read config: %v", err) - } - - var config OpenCodeConfig - if err := json.Unmarshal(content, &config); err != nil { - t.Fatalf("unmarshal config: %v", err) - } - - if _, ok := config.MCP["taskwing-mcp"]; !ok { - t.Fatal("canonical taskwing-mcp entry missing after upsert") - } - if _, ok := config.MCP["taskwing-mcp-my-project"]; ok { - t.Fatal("legacy taskwing-mcp-* entry should be removed during canonicalization") - } - if _, ok := config.MCP["other-mcp"]; !ok { - t.Fatal("non-taskwing MCP entries must be preserved") - } -} - -// ============================================================================= -// Helper Functions -// ============================================================================= - -// containsSubstr checks if s contains substr -func containsSubstr(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} diff --git a/cmd/mcp_server.go b/cmd/mcp_server.go index f62dabb..4ccd40f 100644 --- a/cmd/mcp_server.go +++ b/cmd/mcp_server.go @@ -28,7 +28,7 @@ var mcpCmd = &cobra.Command{ Long: `Start a Model Context Protocol (MCP) server to enable AI tools like Claude Code, Cursor, and other AI assistants to interact with TaskWing project memory. -The MCP server provides the recall tool that gives AI tools access to: +The MCP server provides the ask tool that gives AI tools access to: - Knowledge nodes (decisions, features, plans, notes) - Semantic search across project memory - Relationships between components @@ -161,10 +161,10 @@ func runMCPServer(ctx context.Context) error { server := mcpsdk.NewServer(impl, serverOpts) - // Register recall tool - retrieves stored codebase knowledge for AI context + // Register ask tool - retrieves stored codebase knowledge for AI context tool := &mcpsdk.Tool{ - Name: "recall", - Description: "Retrieve codebase architecture knowledge: decisions, patterns, constraints, and features. Returns an AI-synthesized answer and relevant context by default. Use {\"query\":\"search term\"} for semantic search.", + Name: "ask", + Description: "Search project knowledge: decisions, patterns, constraints, and code symbols. Returns an AI-synthesized answer and relevant context by default. Use {\"query\":\"search term\"} for semantic search.", } mcpsdk.AddTool(server, tool, func(ctx context.Context, session *mcpsdk.ServerSession, params *mcpsdk.CallToolParamsFor[mcppresenter.ProjectContextParams]) (*mcpsdk.CallToolResultFor[any], error) { @@ -309,17 +309,17 @@ REQUIRED FIELDS BY ACTION: // handleNodeContext returns context using the knowledge.Service (same as CLI). // This ensures MCP and CLI use identical search logic with zero drift. -// Uses the app.RecallApp for all business logic - single source of truth. +// Uses the app.AskApp for all business logic - single source of truth. func handleNodeContext(ctx context.Context, repo *memory.Repository, params mcppresenter.ProjectContextParams) (*mcpsdk.CallToolResultFor[any], error) { // Create app context with query role - respects llm.models.query config (same as CLI) appCtx := app.NewContextForRole(repo, llm.RoleQuery) - recallApp := app.NewRecallApp(appCtx) + askApp := app.NewAskApp(appCtx) query := strings.TrimSpace(params.Query) // No query = return project summary if query == "" { - summary, err := recallApp.Summary(ctx) + summary, err := askApp.Summary(ctx) if err != nil { return mcpErrorResponse(fmt.Errorf("get summary: %w", err)) } @@ -338,10 +338,10 @@ func handleNodeContext(ctx context.Context, repo *memory.Repository, params mcpp } // Execute query via app layer (ALL business logic delegated) - // Include symbols in MCP recall for enhanced context + // Include symbols in MCP ask for enhanced context // Note: Only generate answer if explicitly requested (params.Answer=true) // to avoid expensive LLM calls on every search - result, err := recallApp.Query(ctx, query, app.RecallOptions{ + result, err := askApp.Query(ctx, query, app.AskOptions{ Limit: 5, SymbolLimit: 5, GenerateAnswer: params.Answer, // Only when explicitly requested @@ -354,7 +354,7 @@ func handleNodeContext(ctx context.Context, repo *memory.Repository, params mcpp } // Return token-efficient Markdown instead of verbose JSON - return mcpMarkdownResponse(mcppresenter.FormatRecall(result)) + return mcpMarkdownResponse(mcppresenter.FormatAsk(result)) } // === Shared Tool Handlers === diff --git a/cmd/memory_missing_test.go b/cmd/memory_missing_test.go deleted file mode 100644 index d8a7ddd..0000000 --- a/cmd/memory_missing_test.go +++ /dev/null @@ -1,125 +0,0 @@ -package cmd - -import ( - "bytes" - "io" - "os" - "strings" - "testing" - - "github.com/josephgoksu/TaskWing/internal/config" - "github.com/spf13/viper" -) - -func TestMissingMemoryGuidance_AcrossKnowledgePlanTask(t *testing.T) { - withTempWorkingDir(t, func() { - knowledgeTypeFlag = "" - knowledgeWorkspaceFlag = "" - knowledgeAllFlag = false - - viper.Set("json", false) - viper.Set("quiet", false) - t.Cleanup(func() { - viper.Set("json", false) - viper.Set("quiet", false) - }) - - commands := []struct { - name string - run func() error - }{ - {name: "knowledge", run: func() error { return runKnowledge(knowledgeCmd, nil) }}, - {name: "plan status", run: func() error { return planStatusCmd.RunE(planStatusCmd, nil) }}, - {name: "task list", run: func() error { return runTaskList(taskCmd, nil) }}, - } - - for _, tc := range commands { - out := captureStdout(t, func() { - if err := tc.run(); err != nil { - t.Fatalf("%s returned error: %v", tc.name, err) - } - }) - - if !strings.Contains(out, "No project memory found for this repository.") { - t.Fatalf("%s output missing memory guidance: %q", tc.name, out) - } - if !strings.Contains(out, "Run 'taskwing bootstrap' first.") { - t.Fatalf("%s output missing bootstrap guidance: %q", tc.name, out) - } - } - }) -} - -func TestMissingMemoryGuidance_JSONShape(t *testing.T) { - withTempWorkingDir(t, func() { - viper.Set("json", true) - t.Cleanup(func() { - viper.Set("json", false) - }) - - out := captureStdout(t, func() { - if err := runTaskList(taskCmd, nil); err != nil { - t.Fatalf("task list returned error: %v", err) - } - }) - - for _, want := range []string{ - `"ok": false`, - `"error": "project memory not initialized"`, - `"command": "taskwing bootstrap"`, - `"next": "run taskwing bootstrap"`, - } { - if !strings.Contains(out, want) { - t.Fatalf("missing JSON field %q in output: %s", want, out) - } - } - }) -} - -func withTempWorkingDir(t *testing.T, fn func()) { - t.Helper() - - originalWD, err := os.Getwd() - if err != nil { - t.Fatalf("getwd: %v", err) - } - - tmpDir := t.TempDir() - if err := os.Chdir(tmpDir); err != nil { - t.Fatalf("chdir temp dir: %v", err) - } - - config.ClearProjectContext() - if _, err := config.DetectAndSetProjectContext(); err != nil { - t.Fatalf("set project context: %v", err) - } - t.Cleanup(func() { - config.ClearProjectContext() - _ = os.Chdir(originalWD) - }) - - fn() -} - -func captureStdout(t *testing.T, fn func()) string { - t.Helper() - - original := os.Stdout - r, w, err := os.Pipe() - if err != nil { - t.Fatalf("create pipe: %v", err) - } - - os.Stdout = w - fn() - _ = w.Close() - os.Stdout = original - - var buf bytes.Buffer - if _, err := io.Copy(&buf, r); err != nil { - t.Fatalf("capture stdout: %v", err) - } - _ = r.Close() - - return buf.String() -} diff --git a/cmd/slash.go b/cmd/slash.go index c0d7ed5..ed9c766 100644 --- a/cmd/slash.go +++ b/cmd/slash.go @@ -17,7 +17,8 @@ var slashContents = map[string]string{ "done": slashDoneContent, "status": slashStatusContent, "plan": slashPlanContent, - "brief": slashBriefContent, + "ask": slashAskContent, + "remember": slashRememberContent, "simplify": slashSimplifyContent, "debug": slashDebugContent, "explain": slashExplainContent, diff --git a/cmd/slash_content.go b/cmd/slash_content.go index b2c5e13..86cdce9 100644 --- a/cmd/slash_content.go +++ b/cmd/slash_content.go @@ -28,12 +28,12 @@ Extract from the response: - scope (e.g., "auth", "vectorsearch", "api") - keywords array - acceptance_criteria -- suggested_recall_queries +- suggested_ask_queries If no task returned, inform user: "No pending tasks. Use 'taskwing plan list' to check plan status." ## Step 2: Fetch Scope-Relevant Context -Call MCP tool ` + "`recall`" + ` with query based on task scope: +Call MCP tool ` + "`ask`" + ` with query based on task scope: ` + "```json" + ` {"query": "[task.scope] patterns constraints decisions"} ` + "```" + ` @@ -46,8 +46,8 @@ Examples: Extract: patterns, constraints, related decisions. ## Step 3: Fetch Task-Specific Context -Call MCP tool ` + "`recall`" + ` with keywords from the task. -Use ` + "`suggested_recall_queries`" + ` if available, otherwise extract keywords from title. +Call MCP tool ` + "`ask`" + ` with keywords from the task. +Use ` + "`suggested_ask_queries`" + ` if available, otherwise extract keywords from title. ` + "```json" + ` {"query": "[keywords from task title/description]"} ` + "```" + ` @@ -81,7 +81,7 @@ Display in this format: ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ ## Relevant Patterns -[Patterns from recall that apply to this task] +[Patterns from ask that apply to this task] ## Constraints [Constraints that must be respected] @@ -107,7 +107,7 @@ If approval is missing or unclear, STOP and respond with: ## Step 7: Begin Implementation (Only After Approval) Proceed with the task, following the patterns and respecting the constraints shown above. -**CRITICAL**: You MUST call all MCP tools (` + "`task(next)`" + `, ` + "`recall`" + ` x2, ` + "`task(start)`" + `) before showing the brief and before requesting implementation approval. +**CRITICAL**: You MUST call all MCP tools (` + "`task(next)`" + `, ` + "`ask`" + ` x2, ` + "`task(start)`" + `) before showing the brief and before requesting implementation approval. ## Fallback (No MCP) ` + "```bash" + ` @@ -243,7 +243,7 @@ Scope: [scope] ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ Commands: /tw-done - Complete this task - /tw-brief - Fetch more context + /tw-ask - Fetch more context ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ ` + "```" + ` @@ -668,7 +668,7 @@ const slashExplainContent = `# Explain Code Symbol **Usage:** ` + "`/tw-explain `" + ` -**Example:** ` + "`/tw-explain NewRecallApp`" + ` +**Example:** ` + "`/tw-explain NewAskApp`" + ` Get a deep-dive explanation of a code symbol including its purpose, usage patterns, and call graph. This is an analysis command and must not be used to bypass planning, verification, or debug gates. @@ -736,12 +736,12 @@ taskwing mcp ` + "```" + ` ` -// slashBriefContent is the prompt content for /tw-brief -const slashBriefContent = `# Project Knowledge Brief +// slashAskContent is the prompt content for /tw-ask +const slashAskContent = `# Project Knowledge Brief This is a context-priming command and must not be used to bypass planning, verification, or debug gates. -Call MCP tool ` + "`recall`" + ` to get a compact project knowledge brief. +Call MCP tool ` + "`ask`" + ` to get a compact project knowledge brief. Use: ` + "```json" + ` @@ -755,3 +755,23 @@ If you need broader coverage, run: Present the returned summary and top results to prime the conversation with project knowledge. ` + +// slashRememberContent is the prompt content for /tw-remember +const slashRememberContent = `# Store Knowledge in Project Memory + +This is a persistence command and must not be used to bypass planning, verification, or debug gates. + +Call MCP tool ` + "`remember`" + ` to persist a decision, pattern, or insight to project memory. + +Use: +` + "```json" + ` +{"content": "[the knowledge to store]"} +` + "```" + ` + +Optionally specify a type (decision, pattern, constraint, note): +` + "```json" + ` +{"content": "[the knowledge to store]", "type": "decision"} +` + "```" + ` + +The content will be classified automatically using AI if no type is provided. +` diff --git a/cmd/slash_contract_test.go b/cmd/slash_contract_test.go deleted file mode 100644 index 8176291..0000000 --- a/cmd/slash_contract_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package cmd - -import ( - "strings" - "testing" -) - -func assertContainsInOrder(t *testing.T, content string, parts ...string) { - t.Helper() - - start := 0 - for _, part := range parts { - idx := strings.Index(content[start:], part) - if idx < 0 { - t.Fatalf("expected content to include %q after position %d", part, start) - } - start += idx + len(part) - } -} - -func TestSlashContracts_CoreCommandsIncludeWorkflowContract(t *testing.T) { - core := map[string]string{ - "next": slashNextContent, - "done": slashDoneContent, - "plan": slashPlanContent, - "debug": slashDebugContent, - } - - for name, content := range core { - if !strings.Contains(content, "TaskWing Workflow Contract v1") { - t.Fatalf("/tw-%s missing workflow contract banner", name) - } - } -} - -func TestSlashContract_NextHasImplementationGate(t *testing.T) { - if !strings.Contains(slashNextContent, "Implementation Start Gate (Hard Gate)") { - t.Fatal("/tw-next missing implementation hard gate") - } - if !strings.Contains(slashNextContent, "REFUSAL: I can't start implementation yet.") { - t.Fatal("/tw-next missing refusal language for checkpoint gate") - } - - assertContainsInOrder(t, slashNextContent, - "## Step 5: Present Unified Task Brief", - "## Step 6: Implementation Start Gate (Hard Gate)", - "## Step 7: Begin Implementation (Only After Approval)", - ) -} - -func TestSlashContract_DoneHasVerificationGate(t *testing.T) { - if !strings.Contains(slashDoneContent, "## Step 2: Collect Fresh Verification Evidence") { - t.Fatal("/tw-done missing verification collection step") - } - if !strings.Contains(slashDoneContent, "REFUSAL: I can't mark this task done yet.") { - t.Fatal("/tw-done missing refusal language for verification gate") - } - - assertContainsInOrder(t, slashDoneContent, - "## Step 2: Collect Fresh Verification Evidence", - "## Step 4: Completion Gate (Hard Gate)", - "## Step 5: Mark Complete", - ) -} - -func TestSlashContract_PlanRequiresClarificationApproval(t *testing.T) { - if !strings.Contains(slashPlanContent, "Hard gate for this command:") { - t.Fatal("/tw-plan missing hard gate definition") - } - if !strings.Contains(slashPlanContent, "REFUSAL: I can't move past planning yet.") { - t.Fatal("/tw-plan missing refusal language for clarification checkpoint") - } - - assertContainsInOrder(t, slashPlanContent, - "## Step 2: Ask Clarifying Questions (Loop)", - "## Step 3: Clarification Checkpoint Approval (Hard Gate)", - "## Step 4: Generate Plan", - ) -} - -func TestSlashContract_DebugRequiresRootCauseEvidence(t *testing.T) { - if !strings.Contains(slashDebugContent, "## Phase 2: Root-Cause Evidence Collection (Hard Gate)") { - t.Fatal("/tw-debug missing root-cause hard gate") - } - if !strings.Contains(slashDebugContent, "REFUSAL: I can't propose a fix yet.") { - t.Fatal("/tw-debug missing refusal language for root-cause gate") - } - - assertContainsInOrder(t, slashDebugContent, - "## Phase 1: Capture Problem Statement", - "## Phase 2: Root-Cause Evidence Collection (Hard Gate)", - "## Phase 3: Present Investigation Plan", - "## Phase 4: Fix Proposal (Only After Evidence Gate Passes)", - ) -} - -func TestSlashContract_LightweightCommandsRemainReadOnly(t *testing.T) { - lightweight := map[string]string{ - "status": slashStatusContent, - "brief": slashBriefContent, - "explain": slashExplainContent, - "simplify": slashSimplifyContent, - } - - for name, content := range lightweight { - if !strings.Contains(content, "must not be used to bypass planning, verification, or debug gates") && - !strings.Contains(content, "must not bypass planning, verification, or debugging gates") && - !strings.Contains(content, "Do not use it to bypass plan, verification, or debug gates") { - t.Fatalf("/tw-%s missing lightweight guardrail language", name) - } - } -} diff --git a/cmd/slash_test.go b/cmd/slash_test.go deleted file mode 100644 index 28410e0..0000000 --- a/cmd/slash_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package cmd - -import ( - "reflect" - "sort" - "strings" - "testing" - - "github.com/josephgoksu/TaskWing/internal/bootstrap" -) - -func TestSlashContentRegistry_CoversCanonicalCatalog(t *testing.T) { - for _, slash := range bootstrap.SlashCommands { - if _, ok := slashContents[slash.SlashCmd]; !ok { - t.Fatalf("missing slash content mapping for %q", slash.SlashCmd) - } - } -} - -func TestSlashRuntimeCommands_MatchCanonicalCatalog(t *testing.T) { - expected := bootstrap.SlashCommandNames() - sort.Strings(expected) - - actual := availableSlashCommands(slashCmd) - if !reflect.DeepEqual(actual, expected) { - t.Fatalf("slash command registry drift\nexpected: %v\nactual: %v", expected, actual) - } -} - -func TestSlashUnknownCommand_UsesRuntimeAvailability(t *testing.T) { - err := slashCmd.RunE(slashCmd, []string{"does-not-exist"}) - if err == nil { - t.Fatal("expected unknown slash command error") - } - - errMsg := err.Error() - expectedList := strings.Join(availableSlashCommands(slashCmd), ", ") - if !strings.Contains(errMsg, expectedList) { - t.Fatalf("unknown command error should list runtime available commands %q, got %q", expectedList, errMsg) - } -} diff --git a/cmd/task.go b/cmd/task.go index 3b2b53f..59af370 100644 --- a/cmd/task.go +++ b/cmd/task.go @@ -134,7 +134,7 @@ func runTaskList(cmd *cobra.Command, args []string) error { Validation []string `json:"validation_steps"` Scope string `json:"scope"` Keywords []string `json:"keywords"` - SuggestedRecallQueries []string `json:"suggestedRecallQueries"` + SuggestedAskQueries []string `json:"suggestedAskQueries"` } var jsonTasks []taskJSON for _, tp := range allTasks { @@ -152,7 +152,7 @@ func runTaskList(cmd *cobra.Command, args []string) error { Validation: t.ValidationSteps, Scope: t.Scope, Keywords: t.Keywords, - SuggestedRecallQueries: t.SuggestedRecallQueries, + SuggestedAskQueries: t.SuggestedAskQueries, }) } return printJSON(jsonTasks) diff --git a/cmd/task_test.go b/cmd/task_test.go deleted file mode 100644 index 041d767..0000000 --- a/cmd/task_test.go +++ /dev/null @@ -1,699 +0,0 @@ -/* -Copyright © 2025 Joseph Goksu josephgoksu@gmail.com -*/ -package cmd - -import ( - "context" - "encoding/json" - "os" - "os/exec" - "strings" - "testing" - - "github.com/josephgoksu/TaskWing/internal/memory" - "github.com/josephgoksu/TaskWing/internal/task" - "github.com/josephgoksu/TaskWing/internal/util" -) - -// TestFormatTaskStatus_AllKnownStatuses verifies all defined TaskStatus values render correctly. -func TestFormatTaskStatus_AllKnownStatuses(t *testing.T) { - tests := []struct { - status task.TaskStatus - contains string // substring that should be in the output - }{ - {task.StatusDraft, "draft"}, - {task.StatusPending, "pending"}, - {task.StatusInProgress, "active"}, - {task.StatusVerifying, "verify"}, - {task.StatusCompleted, "done"}, - {task.StatusFailed, "failed"}, - {task.StatusBlocked, "blocked"}, - {task.StatusReady, "ready"}, - } - - for _, tc := range tests { - t.Run(string(tc.status), func(t *testing.T) { - result := formatTaskStatus(tc.status) - if result == "" { - t.Error("formatTaskStatus returned empty string") - } - // Strip ANSI codes for checking content - stripped := stripANSI(result) - if !strings.Contains(strings.ToLower(stripped), tc.contains) { - t.Errorf("formatTaskStatus(%q) = %q, want string containing %q", tc.status, stripped, tc.contains) - } - }) - } -} - -// TestFormatTaskStatus_DoneAlias verifies "done" as an alias for StatusCompleted. -func TestFormatTaskStatus_DoneAlias(t *testing.T) { - result := formatTaskStatus("done") - stripped := stripANSI(result) - if !strings.Contains(strings.ToLower(stripped), "done") { - t.Errorf("formatTaskStatus(\"done\") = %q, want string containing 'done'", stripped) - } -} - -// TestFormatTaskStatus_UnknownStatus verifies unknown statuses render gracefully. -func TestFormatTaskStatus_UnknownStatus(t *testing.T) { - unknownStatuses := []task.TaskStatus{ - "invalid", - "garbage", - "", - "some_future_status", - "COMPLETED", // Wrong case - } - - for _, status := range unknownStatuses { - t.Run(string(status), func(t *testing.T) { - // Should not panic - result := formatTaskStatus(status) - if result == "" { - t.Error("formatTaskStatus returned empty string for unknown status") - } - stripped := stripANSI(result) - if !strings.Contains(strings.ToLower(stripped), "unknown") { - t.Errorf("formatTaskStatus(%q) = %q, want string containing 'unknown'", status, stripped) - } - }) - } -} - -// TestFormatTaskStatus_FixedWidth verifies all status strings have consistent width. -func TestFormatTaskStatus_FixedWidth(t *testing.T) { - statuses := []task.TaskStatus{ - task.StatusDraft, - task.StatusPending, - task.StatusInProgress, - task.StatusVerifying, - task.StatusCompleted, - task.StatusFailed, - task.StatusBlocked, - task.StatusReady, - "unknown_status", - } - - // Get the width of the first status - firstStripped := stripANSI(formatTaskStatus(statuses[0])) - expectedWidth := len(firstStripped) - - for _, status := range statuses { - t.Run(string(status), func(t *testing.T) { - result := formatTaskStatus(status) - stripped := stripANSI(result) - if len(stripped) != expectedWidth { - t.Errorf("formatTaskStatus(%q) width = %d, want %d (value: %q)", status, len(stripped), expectedWidth, stripped) - } - }) - } -} - -// TestFormatTaskStatus_NoPanic verifies the function never panics. -func TestFormatTaskStatus_NoPanic(t *testing.T) { - // Test with various edge cases that should not cause panic - testCases := []task.TaskStatus{ - "", - "null", - "undefined", - "\x00", // null byte - "status\nstatus", // newline - "status\tstatus", // tab - task.TaskStatus(strings.Repeat("x", 1000)), // very long string - } - - for i, tc := range testCases { - t.Run(string(rune('A'+i)), func(t *testing.T) { - defer func() { - if r := recover(); r != nil { - t.Errorf("formatTaskStatus panicked with input %q: %v", tc, r) - } - }() - _ = formatTaskStatus(tc) - }) - } -} - -// TestFormatTaskStatus_TableDriven comprehensive table-driven test. -func TestFormatTaskStatus_TableDriven(t *testing.T) { - tests := []struct { - name string - status task.TaskStatus - wantLabel string - wantUnknown bool - }{ - {"completed", task.StatusCompleted, "done", false}, - {"done_alias", "done", "done", false}, - {"in_progress", task.StatusInProgress, "active", false}, - {"pending", task.StatusPending, "pending", false}, - {"draft", task.StatusDraft, "draft", false}, - {"blocked", task.StatusBlocked, "blocked", false}, - {"ready", task.StatusReady, "ready", false}, - {"failed", task.StatusFailed, "failed", false}, - {"verifying", task.StatusVerifying, "verify", false}, - {"unknown_garbage", "garbage", "unknown", true}, - {"unknown_empty", "", "unknown", true}, - {"unknown_case_sensitive", "PENDING", "unknown", true}, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result := formatTaskStatus(tc.status) - stripped := stripANSI(result) - - if !strings.Contains(strings.ToLower(stripped), tc.wantLabel) { - t.Errorf("formatTaskStatus(%q) = %q, want string containing %q", tc.status, stripped, tc.wantLabel) - } - - if tc.wantUnknown && !strings.Contains(strings.ToLower(stripped), "unknown") { - t.Errorf("formatTaskStatus(%q) should render as 'unknown'", tc.status) - } - }) - } -} - -// stripANSI removes ANSI escape codes from a string for easier testing. -func stripANSI(s string) string { - // Simple ANSI stripping - removes escape sequences - result := strings.Builder{} - i := 0 - for i < len(s) { - if s[i] == '\x1b' && i+1 < len(s) && s[i+1] == '[' { - // Skip until 'm' (end of ANSI sequence) - for i < len(s) && s[i] != 'm' { - i++ - } - i++ // skip the 'm' - } else { - result.WriteByte(s[i]) - i++ - } - } - return result.String() -} - -// TestStripANSI verifies the helper function works correctly. -func TestStripANSI(t *testing.T) { - tests := []struct { - input string - want string - }{ - {"\x1b[32mgreen\x1b[0m", "green"}, - {"\x1b[1;31mred bold\x1b[0m", "red bold"}, - {"no ansi", "no ansi"}, - {"", ""}, - {"\x1b[42m\x1b[1m[done] \x1b[0m", "[done] "}, - } - - for _, tc := range tests { - t.Run(tc.want, func(t *testing.T) { - got := stripANSI(tc.input) - if got != tc.want { - t.Errorf("stripANSI(%q) = %q, want %q", tc.input, got, tc.want) - } - }) - } -} - -// TestTaskListArchivedFilter tests that the include-archived flag exists. -func TestTaskListArchivedFilter(t *testing.T) { - // Verify the flag is registered on the command - flag := taskListCmd.Flags().Lookup("include-archived") - if flag == nil { - t.Fatal("expected --include-archived flag to be registered on task list command") - } - - if flag.DefValue != "false" { - t.Errorf("expected default value 'false', got %q", flag.DefValue) - } - - // Verify the flag help text - usage := flag.Usage - if usage == "" { - t.Error("expected --include-archived flag to have usage text") - } - if !strings.Contains(strings.ToLower(usage), "archived") { - t.Errorf("flag usage should mention 'archived', got: %q", usage) - } -} - -// TestTaskListCommandHelp verifies help mentions archived filtering. -func TestTaskListCommandHelp(t *testing.T) { - longHelp := taskListCmd.Long - if !strings.Contains(longHelp, "archived") { - t.Error("task list long help should mention archived plans") - } - if !strings.Contains(longHelp, "--include-archived") { - t.Error("task list long help should mention --include-archived flag") - } -} - -// TestTaskListErrorPropagation verifies that errors are properly returned, not swallowed. -func TestTaskListErrorPropagation(t *testing.T) { - // Verify the RunE function is set (meaning errors will be returned to Cobra) - if taskListCmd.RunE == nil { - t.Fatal("taskListCmd.RunE should be set to return errors") - } - - // Verify Run is not set (which would swallow errors) - if taskListCmd.Run != nil { - t.Error("taskListCmd.Run should not be set; use RunE for error propagation") - } -} - -// TestTaskListExitOnError verifies non-zero exit on repository failure. -func TestTaskListExitOnError(t *testing.T) { - if testing.Short() { - t.Skip("skipping smoke test in short mode") - } - - // Create a temp directory that has no .taskwing/memory structure - tmpDir, err := os.MkdirTemp("", "taskwing-smoke-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - // Run the CLI with the temp dir as the working directory - // This should fail because there's no memory.db - cmd := exec.Command("go", "run", ".", "task", "list") - cmd.Dir = tmpDir - cmd.Env = append(os.Environ(), "HOME="+tmpDir) // Prevent using real home config - - output, err := cmd.CombinedOutput() - - // We expect an error (non-zero exit) because the repo doesn't exist - if err == nil { - t.Log("Command output:", string(output)) - // Note: If this passes, it might mean the command gracefully handles missing repos - // by showing "No plans found" which is acceptable behavior - if strings.Contains(string(output), "No plans found") { - t.Log("Command succeeded with 'No plans found' - this is acceptable behavior") - return - } - // If no error and no "No plans found", something unexpected happened - t.Log("Command succeeded unexpectedly without error") - } else { - // Verify exit error - exitErr, ok := err.(*exec.ExitError) - if !ok { - t.Fatalf("expected exec.ExitError, got %T: %v", err, err) - } - if exitErr.ExitCode() == 0 { - t.Error("expected non-zero exit code on failure") - } - - // Verify error message contains useful context - outputStr := string(output) - if outputStr == "" { - t.Error("expected error message in output, got empty") - } - - // Should have some indication of the failure - hasContext := strings.Contains(strings.ToLower(outputStr), "error") || - strings.Contains(strings.ToLower(outputStr), "failed") || - strings.Contains(strings.ToLower(outputStr), "no such file") - if !hasContext { - t.Logf("Output should contain error context: %s", outputStr) - } - } -} - -// TestTaskListVerboseError verifies verbose mode provides additional context. -func TestTaskListVerboseError(t *testing.T) { - if testing.Short() { - t.Skip("skipping smoke test in short mode") - } - - // The --verbose flag should be available - flag := taskListCmd.Root().PersistentFlags().Lookup("verbose") - if flag == nil { - // Check if it's inherited from root - flag = rootCmd.PersistentFlags().Lookup("verbose") - } - if flag == nil { - t.Log("verbose flag not found on taskListCmd, may be on root") - } - - // Test that the flag exists by checking the root command - if rootCmd.PersistentFlags().Lookup("verbose") == nil { - t.Error("expected --verbose flag to be registered on root command") - } -} - -// TestTaskShowAcceptsPrefix verifies task show command accepts ID prefixes. -func TestTaskShowAcceptsPrefix(t *testing.T) { - // Verify the command help mentions prefix support - longHelp := taskShowCmd.Long - if !strings.Contains(longHelp, "prefix") { - t.Error("task show long help should mention prefix support") - } - - // Verify RunE is set (for proper error handling) - if taskShowCmd.RunE == nil { - t.Fatal("taskShowCmd.RunE should be set") - } - - // Verify Args requires exactly 1 argument - if taskShowCmd.Args == nil { - t.Error("taskShowCmd.Args should be set") - } -} - -// TestTaskShowHelpExamples verifies help shows prefix examples. -func TestTaskShowHelpExamples(t *testing.T) { - longHelp := taskShowCmd.Long - - // Should have examples showing different ID formats - if !strings.Contains(longHelp, "task-abc") { - t.Error("task show help should have prefix example") - } - if !strings.Contains(longHelp, "auto-prepended") || !strings.Contains(longHelp, "abc") { - t.Error("task show help should mention auto-prepending task- prefix") - } -} - -// TestTaskListFormatting verifies that task list uses consistent ID formatting. -func TestTaskListFormatting(t *testing.T) { - // Verify util package is used for ID formatting by checking it can be imported - // and ShortID works correctly with TaskIDLength - testID := "task-abcdef12" - shortID := util.ShortID(testID, util.TaskIDLength) - - // TaskIDLength is 13, so full ID should be preserved - if shortID != testID { - t.Errorf("ShortID with TaskIDLength should preserve full ID, got %q want %q", shortID, testID) - } - - // Verify TaskIDLength constant is correct - if util.TaskIDLength != 13 { - t.Errorf("TaskIDLength = %d, want 13", util.TaskIDLength) - } - - // Verify PlanIDLength constant is correct - if util.PlanIDLength != 13 { - t.Errorf("PlanIDLength = %d, want 13", util.PlanIDLength) - } -} - -// TestTaskListJSONIncludesPlanStatus verifies plan_status is in JSON output struct. -func TestTaskListJSONIncludesPlanStatus(t *testing.T) { - // This test verifies that the JSON struct definition includes plan_status. - // We can't easily run the full command without a real database, - // so we verify the struct definition by checking the source. - - // The taskJSON struct in runTaskList should have plan_status field. - // This is a static verification that the field exists. - - // Create a sample JSON structure matching expected output - type taskJSON struct { - ID string `json:"id"` - PlanID string `json:"plan_id"` - PlanStatus string `json:"plan_status"` - Title string `json:"title"` - Description string `json:"description"` - Status string `json:"status"` - Priority int `json:"priority"` - Agent string `json:"assigned_agent"` - Acceptance []string `json:"acceptance_criteria"` - Validation []string `json:"validation_steps"` - Scope string `json:"scope"` - Keywords []string `json:"keywords"` - SuggestedRecallQueries []string `json:"suggestedRecallQueries"` - } - - // Verify we can create and marshal a sample - sample := taskJSON{ - ID: "task-123", - PlanID: "plan-456", - PlanStatus: "active", - Title: "Test Task", - Status: "pending", - } - - // Marshal to verify plan_status is included - data, err := json.Marshal(sample) - if err != nil { - t.Fatalf("failed to marshal sample: %v", err) - } - - jsonStr := string(data) - if !strings.Contains(jsonStr, `"plan_status"`) { - t.Errorf("JSON output should contain 'plan_status', got: %s", jsonStr) - } - if !strings.Contains(jsonStr, `"active"`) { - t.Errorf("JSON output should contain plan_status value 'active', got: %s", jsonStr) - } -} - -// === Integration Tests with SQLite === - -// setupTestRepo creates a temp directory with a repository seeded with test data. -// Returns the repo and a cleanup function. -func setupTestRepo(t *testing.T) (*memory.Repository, func()) { - t.Helper() - - tmpDir, err := os.MkdirTemp("", "taskwing-integration-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - - repo, err := memory.NewDefaultRepository(tmpDir) - if err != nil { - _ = os.RemoveAll(tmpDir) - t.Fatalf("failed to create repository: %v", err) - } - - cleanup := func() { - _ = repo.Close() - _ = os.RemoveAll(tmpDir) - } - - return repo, cleanup -} - -// seedTestData creates test plans and tasks in the repository. -func seedTestData(t *testing.T, repo *memory.Repository) (activePlan *task.Plan, archivedPlan *task.Plan) { - t.Helper() - - // Create an active plan with tasks - activePlan = &task.Plan{ - ID: "plan-active01", - Goal: "Active test plan", - Status: task.PlanStatusActive, - } - if err := repo.GetDB().CreatePlan(activePlan); err != nil { - t.Fatalf("failed to create active plan: %v", err) - } - - activeTask1 := &task.Task{ - ID: "task-act00001", - PlanID: activePlan.ID, - Title: "Active task 1", - Description: "First task in active plan", - Status: task.StatusPending, - Priority: 50, - } - if err := repo.GetDB().CreateTask(activeTask1); err != nil { - t.Fatalf("failed to create active task 1: %v", err) - } - - activeTask2 := &task.Task{ - ID: "task-act00002", - PlanID: activePlan.ID, - Title: "Active task 2", - Description: "Second task in active plan", - Status: task.StatusInProgress, - Priority: 30, - } - if err := repo.GetDB().CreateTask(activeTask2); err != nil { - t.Fatalf("failed to create active task 2: %v", err) - } - - // Create an archived plan with tasks - archivedPlan = &task.Plan{ - ID: "plan-archive1", - Goal: "Archived test plan", - Status: task.PlanStatusArchived, - } - if err := repo.GetDB().CreatePlan(archivedPlan); err != nil { - t.Fatalf("failed to create archived plan: %v", err) - } - - archivedTask := &task.Task{ - ID: "task-arch0001", - PlanID: archivedPlan.ID, - Title: "Archived task", - Description: "Task in archived plan", - Status: task.StatusCompleted, - Priority: 50, - } - if err := repo.GetDB().CreateTask(archivedTask); err != nil { - t.Fatalf("failed to create archived task: %v", err) - } - - return activePlan, archivedPlan -} - -// TestTaskListIntegration_ExcludesArchivedByDefault tests that archived plans are excluded by default. -func TestTaskListIntegration_ExcludesArchivedByDefault(t *testing.T) { - repo, cleanup := setupTestRepo(t) - defer cleanup() - - activePlan, archivedPlan := seedTestData(t, repo) - - // List all plans - plans, err := repo.GetDB().ListPlans() - if err != nil { - t.Fatalf("failed to list plans: %v", err) - } - - // Should have both plans - if len(plans) != 2 { - t.Fatalf("expected 2 plans, got %d", len(plans)) - } - - // Count tasks excluding archived plans (simulating default behavior) - var activeTasks []task.Task - for _, p := range plans { - if p.Status == task.PlanStatusArchived { - continue // This is what the CLI does by default - } - tasks, err := repo.GetDB().ListTasks(p.ID) - if err != nil { - t.Fatalf("failed to list tasks for plan %s: %v", p.ID, err) - } - activeTasks = append(activeTasks, tasks...) - } - - // Should only see active plan's tasks - if len(activeTasks) != 2 { - t.Errorf("expected 2 active tasks, got %d", len(activeTasks)) - } - - // Verify they're from the active plan - for _, tsk := range activeTasks { - if tsk.PlanID != activePlan.ID { - t.Errorf("task %s should be from active plan %s, got %s", tsk.ID, activePlan.ID, tsk.PlanID) - } - } - - // Count tasks including archived plans (simulating --include-archived) - var allTasks []task.Task - for _, p := range plans { - tasks, err := repo.GetDB().ListTasks(p.ID) - if err != nil { - t.Fatalf("failed to list tasks for plan %s: %v", p.ID, err) - } - allTasks = append(allTasks, tasks...) - } - - // Should see all 3 tasks - if len(allTasks) != 3 { - t.Errorf("expected 3 total tasks with include-archived, got %d", len(allTasks)) - } - - // Verify archived plan's task is included - foundArchived := false - for _, tsk := range allTasks { - if tsk.PlanID == archivedPlan.ID { - foundArchived = true - break - } - } - if !foundArchived { - t.Error("archived plan's tasks should be included with --include-archived") - } -} - -// TestTaskShowIntegration_PrefixResolution tests prefix resolution with real SQLite. -func TestTaskShowIntegration_PrefixResolution(t *testing.T) { - repo, cleanup := setupTestRepo(t) - defer cleanup() - - seedTestData(t, repo) - ctx := context.Background() - - t.Run("full ID resolves", func(t *testing.T) { - resolved, err := util.ResolveTaskID(ctx, repo, "task-act00001") - if err != nil { - t.Fatalf("failed to resolve full ID: %v", err) - } - if resolved != "task-act00001" { - t.Errorf("resolved = %q, want %q", resolved, "task-act00001") - } - }) - - t.Run("unique prefix resolves", func(t *testing.T) { - // "task-arch" should uniquely match "task-arch0001" - resolved, err := util.ResolveTaskID(ctx, repo, "task-arch") - if err != nil { - t.Fatalf("failed to resolve unique prefix: %v", err) - } - if resolved != "task-arch0001" { - t.Errorf("resolved = %q, want %q", resolved, "task-arch0001") - } - }) - - t.Run("ambiguous prefix errors", func(t *testing.T) { - // "task-act" matches both "task-act00001" and "task-act00002" - _, err := util.ResolveTaskID(ctx, repo, "task-act") - if err == nil { - t.Fatal("expected error for ambiguous prefix, got nil") - } - if !strings.Contains(err.Error(), "ambiguous") { - t.Errorf("error should mention 'ambiguous', got: %v", err) - } - }) - - t.Run("nonexistent prefix errors", func(t *testing.T) { - _, err := util.ResolveTaskID(ctx, repo, "task-nonexistent") - if err == nil { - t.Fatal("expected error for nonexistent prefix, got nil") - } - if !strings.Contains(err.Error(), "not found") { - t.Errorf("error should mention 'not found', got: %v", err) - } - }) - - t.Run("prefix without task- prepended", func(t *testing.T) { - // "arch" should be prepended to "task-arch" and resolve - resolved, err := util.ResolveTaskID(ctx, repo, "arch") - if err != nil { - t.Fatalf("failed to resolve prefix without task-: %v", err) - } - if resolved != "task-arch0001" { - t.Errorf("resolved = %q, want %q", resolved, "task-arch0001") - } - }) -} - -// TestPlanPrefixResolution tests plan ID prefix resolution. -func TestPlanPrefixResolution(t *testing.T) { - repo, cleanup := setupTestRepo(t) - defer cleanup() - - seedTestData(t, repo) - ctx := context.Background() - - t.Run("unique plan prefix resolves", func(t *testing.T) { - resolved, err := util.ResolvePlanID(ctx, repo, "plan-active") - if err != nil { - t.Fatalf("failed to resolve plan prefix: %v", err) - } - if resolved != "plan-active01" { - t.Errorf("resolved = %q, want %q", resolved, "plan-active01") - } - }) - - t.Run("plan prefix without plan- prepended", func(t *testing.T) { - resolved, err := util.ResolvePlanID(ctx, repo, "archive") - if err != nil { - t.Fatalf("failed to resolve plan prefix without plan-: %v", err) - } - if resolved != "plan-archive1" { - t.Errorf("resolved = %q, want %q", resolved, "plan-archive1") - } - }) -} diff --git a/docs/PRODUCT_VISION.md b/docs/PRODUCT_VISION.md index 866163e..8ea2db9 100644 --- a/docs/PRODUCT_VISION.md +++ b/docs/PRODUCT_VISION.md @@ -67,6 +67,7 @@ Brand names and logos are trademarks of their respective owners; usage here indi - `taskwing bootstrap` - `taskwing goal ""` +- `taskwing ask ""` - `taskwing task` - `taskwing plan status` - `taskwing slash` @@ -81,7 +82,7 @@ Brand names and logos are trademarks of their respective owners; usage here indi | Tool | Description | |------|-------------| -| `recall` | Retrieve project knowledge (decisions, patterns, constraints) | +| `ask` | Search project knowledge (decisions, patterns, constraints) | | `task` | Unified task lifecycle (`next`, `current`, `start`, `complete`) | | `plan` | Plan management (`clarify`, `decompose`, `expand`, `generate`, `finalize`, `audit`) | | `code` | Code intelligence (`find`, `search`, `explain`, `callers`, `impact`, `simplify`) | diff --git a/docs/TUTORIAL.md b/docs/TUTORIAL.md index 8f905d5..bf6fae4 100644 --- a/docs/TUTORIAL.md +++ b/docs/TUTORIAL.md @@ -155,6 +155,7 @@ Recommended Bedrock model IDs: - `taskwing bootstrap` - `taskwing goal ""` +- `taskwing ask ""` - `taskwing task` - `taskwing plan status` - `taskwing slash` @@ -169,7 +170,7 @@ Recommended Bedrock model IDs: | Tool | Description | |------|-------------| -| `recall` | Retrieve project knowledge (decisions, patterns, constraints) | +| `ask` | Search project knowledge (decisions, patterns, constraints) | | `task` | Unified task lifecycle (`next`, `current`, `start`, `complete`) | | `plan` | Plan management (`clarify`, `decompose`, `expand`, `generate`, `finalize`, `audit`) | | `code` | Code intelligence (`find`, `search`, `explain`, `callers`, `impact`, `simplify`) | diff --git a/docs/WORKFLOW_CONTRACT_V1.md b/docs/WORKFLOW_CONTRACT_V1.md index e320c5e..06bd799 100644 --- a/docs/WORKFLOW_CONTRACT_V1.md +++ b/docs/WORKFLOW_CONTRACT_V1.md @@ -47,5 +47,5 @@ KPI: ## Operating Policy - These gates are hard blockers for core workflow commands. -- Commands that are primarily read-only (`/tw-brief`, `/tw-status`, `/tw-explain`, `/tw-simplify`) remain lightweight but must not bypass these gates. +- Commands that are primarily read-only (`/tw-ask`, `/tw-status`, `/tw-explain`, `/tw-simplify`) remain lightweight but must not bypass these gates. - Prompt regressions against this contract are release blockers. diff --git a/docs/_partials/core_commands.md b/docs/_partials/core_commands.md index 03d7bea..ecaf254 100644 --- a/docs/_partials/core_commands.md +++ b/docs/_partials/core_commands.md @@ -1,5 +1,6 @@ - `taskwing bootstrap` - `taskwing goal ""` +- `taskwing ask ""` - `taskwing task` - `taskwing plan status` - `taskwing slash` diff --git a/docs/_partials/mcp_tools.md b/docs/_partials/mcp_tools.md index f47c5e0..3d6b05d 100644 --- a/docs/_partials/mcp_tools.md +++ b/docs/_partials/mcp_tools.md @@ -1,6 +1,6 @@ | Tool | Description | |------|-------------| -| `recall` | Retrieve project knowledge (decisions, patterns, constraints) | +| `ask` | Search project knowledge (decisions, patterns, constraints) | | `task` | Unified task lifecycle (`next`, `current`, `start`, `complete`) | | `plan` | Plan management (`clarify`, `decompose`, `expand`, `generate`, `finalize`, `audit`) | | `code` | Code intelligence (`find`, `search`, `explain`, `callers`, `impact`, `simplify`) | diff --git a/docs/architecture/ADR_CONTEXT_BINDING.md b/docs/architecture/ADR_CONTEXT_BINDING.md index 9853c6b..b5cc40c 100644 --- a/docs/architecture/ADR_CONTEXT_BINDING.md +++ b/docs/architecture/ADR_CONTEXT_BINDING.md @@ -27,11 +27,11 @@ At task creation time (`PlanApp.parseTasksFromMetadata`): ```go // internal/app/plan.go -t.EnrichAIFields() // Generates SuggestedRecallQueries +t.EnrichAIFields() // Generates SuggestedAskQueries -// Execute recall queries and embed context -if a.TaskEnricher != nil && len(t.SuggestedRecallQueries) > 0 { - if contextSummary, err := a.TaskEnricher(ctx, t.SuggestedRecallQueries); err == nil { +// Execute ask queries and embed context +if a.TaskEnricher != nil && len(t.SuggestedAskQueries) > 0 { + if contextSummary, err := a.TaskEnricher(ctx, t.SuggestedAskQueries); err == nil { t.ContextSummary = contextSummary // Embedded in task record } } @@ -47,10 +47,10 @@ At task presentation time (`FormatRichContext`): // internal/task/presentation.go if t.ContextSummary != "" { // Use pre-computed early-bound context (preferred) - recallContext = "\n" + t.ContextSummary -} else if len(t.SuggestedRecallQueries) > 0 && searchFn != nil { + askContext = "\n" + t.ContextSummary +} else if len(t.SuggestedAskQueries) > 0 && searchFn != nil { // Fallback: Fetch context dynamically - for _, query := range t.SuggestedRecallQueries { + for _, query := range t.SuggestedAskQueries { results, _ := searchFn(ctx, query, 3) // ... aggregate results } @@ -73,11 +73,11 @@ if t.ContextSummary != "" { │ ▼ │ │ 2. EnrichAIFields() generates: │ │ - Scope (inferred from keywords) │ -│ - SuggestedRecallQueries (3 queries) │ +│ - SuggestedAskQueries (3 queries) │ │ │ │ │ ▼ │ -│ 3. TaskEnricher executes ALL recall queries │ -│ - Calls RecallApp.Query() for each query │ +│ 3. TaskEnricher executes ALL ask queries │ +│ - Calls AskApp.Query() for each query │ │ - Aggregates results (deduped by summary) │ │ - Truncates content to 200 chars │ │ │ │ @@ -98,7 +98,7 @@ if t.ContextSummary != "" { │ 2. Check: Does Task.ContextSummary exist? │ │ ├── YES: Use it directly (fast path) │ │ │ │ -│ └── NO: Execute SuggestedRecallQueries (fallback) │ +│ └── NO: Execute SuggestedAskQueries (fallback) │ │ - Fetch fresh context from knowledge graph │ │ - Deduplicate by summary │ │ - Truncate content to 300 chars for display │ @@ -116,13 +116,13 @@ if t.ContextSummary != "" { |------|------| | **Reliable**: Context always available | **Staleness**: May not reflect latest knowledge | | **Fast**: No runtime queries needed | **Storage**: Increases task record size | -| **Offline-capable**: Works without recall service | **One-time**: Context frozen at creation time | +| **Offline-capable**: Works without ask service | **One-time**: Context frozen at creation time | ### Late Binding | Pros | Cons | |------|------| -| **Fresh**: Always reflects current knowledge | **Service dependency**: Requires recall service | +| **Fresh**: Always reflects current knowledge | **Service dependency**: Requires ask service | | **Lighter storage**: Queries stored, not results | **Slower**: N queries per task display | | **Adaptable**: Queries can evolve | **Unreliable**: May fail if service down | @@ -143,8 +143,8 @@ if t.ContextSummary != "" { 1. **Tasks created with embedded context** - AI assistants receive relevant architectural decisions immediately 2. **Backward compatible** - Old tasks (without `ContextSummary`) still work via late binding -3. **Resilient** - System works even if recall service is temporarily unavailable (uses cached context) -4. **Efficient** - Avoids repeated recall queries during task execution +3. **Resilient** - System works even if ask service is temporarily unavailable (uses cached context) +4. **Efficient** - Avoids repeated ask queries during task execution ### Negative @@ -168,7 +168,7 @@ if t.ContextSummary != "" { |------|----------------| | `internal/task/models.go` | `EnrichAIFields()` - scope inference, keyword extraction, query generation | | `internal/task/scope_config.go` | Configurable scope keywords via viper | -| `internal/app/plan.go` | `TaskEnricher` - executes recall queries at creation time | +| `internal/app/plan.go` | `TaskEnricher` - executes ask queries at creation time | | `internal/task/presentation.go` | `FormatRichContext()` - early binding display with late binding fallback | ### Configuration diff --git a/docs/architecture/PLANNING_FORENSIC_DOCUMENTATION.md b/docs/architecture/PLANNING_FORENSIC_DOCUMENTATION.md index fc72596..6d1ef6c 100644 --- a/docs/architecture/PLANNING_FORENSIC_DOCUMENTATION.md +++ b/docs/architecture/PLANNING_FORENSIC_DOCUMENTATION.md @@ -28,7 +28,7 @@ **What it observes:** - User-provided goals and answers - Git working directory state -- Project knowledge context (via recall) +- Project knowledge context (via ask) **What it does NOT control:** - Actual code execution by AI agents @@ -348,7 +348,7 @@ CREATE TABLE tasks ( context_summary TEXT, scope TEXT, keywords TEXT, -- JSON array - suggested_recall_queries TEXT, -- JSON array + suggested_ask_queries TEXT, -- JSON array claimed_by TEXT, claimed_at TEXT, completed_at TEXT, diff --git a/go.mod b/go.mod index d177d20..dbf5ee4 100644 --- a/go.mod +++ b/go.mod @@ -57,8 +57,8 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/bytedance/gopkg v0.1.3 // indirect - github.com/bytedance/sonic v1.14.1 // indirect - github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/bytedance/sonic v1.15.0 // indirect + github.com/bytedance/sonic/loader v0.5.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect github.com/charmbracelet/x/ansi v0.10.1 // indirect diff --git a/go.sum b/go.sum index f02e99a..7181dca 100644 --- a/go.sum +++ b/go.sum @@ -65,10 +65,10 @@ github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/mockey v1.3.0 h1:ONLRdvhqmCfr9rTasUB8ZKCfvbdD2tohOg4u+4Q/ed0= github.com/bytedance/mockey v1.3.0/go.mod h1:1BPHF9sol5R1ud/+0VEHGQq/+i2lN+GTsr3O2Q9IENY= -github.com/bytedance/sonic v1.14.1 h1:FBMC0zVz5XUmE4z9wF4Jey0An5FueFvOsTKKKtwIl7w= -github.com/bytedance/sonic v1.14.1/go.mod h1:gi6uhQLMbTdeP0muCnrjHLeCUPyb70ujhnNlhOylAFc= -github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= -github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= +github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k= +github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE= +github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= @@ -343,13 +343,15 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= diff --git a/internal/agents/impl/utility_agents_test.go b/internal/agents/impl/utility_agents_test.go deleted file mode 100644 index 51eb347..0000000 --- a/internal/agents/impl/utility_agents_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package impl - -import ( - "testing" - - "github.com/josephgoksu/TaskWing/internal/agents/core" -) - -func TestUtilityAgentsRegistered(t *testing.T) { - // Get all registered agents - registry := core.Registry() - - // Build a map for easy lookup - registered := make(map[string]bool) - for _, info := range registry { - registered[info.ID] = true - } - - // Verify our utility agents are registered - expectedAgents := []string{"simplify", "explain", "debug"} - for _, id := range expectedAgents { - if !registered[id] { - t.Errorf("Agent %q not found in registry", id) - } - } -} - -func TestSimplifyAgentInfo(t *testing.T) { - info := core.GetAgentByID("simplify") - if info == nil { - t.Fatal("simplify agent not found") - } - if info.Name != "Code Simplification" { - t.Errorf("expected name 'Code Simplification', got %q", info.Name) - } -} - -func TestExplainAgentInfo(t *testing.T) { - info := core.GetAgentByID("explain") - if info == nil { - t.Fatal("explain agent not found") - } - if info.Name != "Code Explanation" { - t.Errorf("expected name 'Code Explanation', got %q", info.Name) - } -} - -func TestDebugAgentInfo(t *testing.T) { - info := core.GetAgentByID("debug") - if info == nil { - t.Fatal("debug agent not found") - } - if info.Name != "Debug Helper" { - t.Errorf("expected name 'Debug Helper', got %q", info.Name) - } -} diff --git a/internal/agents/tools/budget_test.go b/internal/agents/tools/budget_test.go deleted file mode 100644 index 4f3da31..0000000 --- a/internal/agents/tools/budget_test.go +++ /dev/null @@ -1,104 +0,0 @@ -package tools - -import "testing" - -func TestNewContextBudget(t *testing.T) { - budget := NewContextBudget(10000) - if budget.Total() != 10000 { - t.Errorf("Expected total budget 10000, got %d", budget.Total()) - } - if budget.Used() != 0 { - t.Errorf("Expected used tokens 0, got %d", budget.Used()) - } -} - -func TestNewSafeContextBudget_UnderLimit(t *testing.T) { - // Request less than MaxSafeContextBudget - should get requested amount - budget := NewSafeContextBudget(50000) - if budget.Total() != 50000 { - t.Errorf("Expected total budget 50000, got %d", budget.Total()) - } -} - -func TestNewSafeContextBudget_OverLimit(t *testing.T) { - // Request more than MaxSafeContextBudget - should be capped - budget := NewSafeContextBudget(500000) // 500k tokens - if budget.Total() != MaxSafeContextBudget { - t.Errorf("Expected total budget %d (MaxSafeContextBudget), got %d", MaxSafeContextBudget, budget.Total()) - } -} - -func TestNewSafeContextBudget_GeminiScenario(t *testing.T) { - // Simulate Gemini's 1M token limit with 50% budget - geminiLimit := 1_000_000 - requestedBudget := int(float64(geminiLimit) * 0.5) // 500k tokens - - budget := NewSafeContextBudget(requestedBudget) - - // Should be capped at MaxSafeContextBudget - if budget.Total() != MaxSafeContextBudget { - t.Errorf("Gemini scenario: expected budget capped at %d, got %d", MaxSafeContextBudget, budget.Total()) - } -} - -func TestContextBudget_Reserve(t *testing.T) { - budget := NewContextBudget(100) - - // Should succeed - if err := budget.Reserve(50); err != nil { - t.Errorf("Reserve 50 should succeed: %v", err) - } - if budget.Used() != 50 { - t.Errorf("Expected used 50, got %d", budget.Used()) - } - - // Should succeed - if err := budget.Reserve(50); err != nil { - t.Errorf("Reserve another 50 should succeed: %v", err) - } - - // Should fail - budget exhausted - if err := budget.Reserve(1); err != ErrBudgetExceeded { - t.Errorf("Reserve when exhausted should return ErrBudgetExceeded, got: %v", err) - } -} - -func TestContextBudget_TryReserve(t *testing.T) { - budget := NewContextBudget(100) - - if !budget.TryReserve(100) { - t.Error("TryReserve 100 should succeed") - } - - if budget.TryReserve(1) { - t.Error("TryReserve when exhausted should return false") - } -} - -func TestContextBudget_IsExhausted(t *testing.T) { - budget := NewContextBudget(100) - - if budget.IsExhausted() { - t.Error("Fresh budget should not be exhausted") - } - - _ = budget.Reserve(100) - - if !budget.IsExhausted() { - t.Error("Full budget should be exhausted") - } -} - -func TestContextBudget_Remaining(t *testing.T) { - budget := NewContextBudget(100) - - if budget.Remaining() != 100 { - t.Errorf("Expected remaining 100, got %d", budget.Remaining()) - } - - _ = budget.Reserve(30) - - if budget.Remaining() != 70 { - t.Errorf("Expected remaining 70, got %d", budget.Remaining()) - } -} diff --git a/internal/agents/tools/chunker_test.go b/internal/agents/tools/chunker_test.go deleted file mode 100644 index 6937480..0000000 --- a/internal/agents/tools/chunker_test.go +++ /dev/null @@ -1,198 +0,0 @@ -package tools - -import ( - "os" - "path/filepath" - "testing" -) - -func TestNewCodeChunker(t *testing.T) { - chunker := NewCodeChunker("/test/path") - if chunker == nil { - t.Fatal("NewCodeChunker returned nil") - } - if chunker.basePath != "/test/path" { - t.Errorf("basePath = %q, want %q", chunker.basePath, "/test/path") - } -} - -func TestDefaultChunkConfig(t *testing.T) { - cfg := DefaultChunkConfig() - if cfg.MaxTokensPerChunk != 30000 { - t.Errorf("MaxTokensPerChunk = %d, want 30000", cfg.MaxTokensPerChunk) - } - if cfg.MaxFilesPerChunk != 50 { - t.Errorf("MaxFilesPerChunk = %d, want 50", cfg.MaxFilesPerChunk) - } - if !cfg.IncludeLineNumbers { - t.Error("IncludeLineNumbers should be true by default") - } -} - -func TestCodeChunker_ChunkSourceCode(t *testing.T) { - // Create temp directory with test files - tmpDir := t.TempDir() - - // Create some source files - files := map[string]string{ - "main.go": "package main\n\nfunc main() {\n\tprintln(\"hello\")\n}\n", - "handler.go": "package main\n\nfunc handleRequest() {}\n", - "internal/util.go": "package internal\n\nfunc Util() {}\n", - "internal/store.go": "package internal\n\nfunc Store() {}\n", - } - - for path, content := range files { - fullPath := filepath.Join(tmpDir, path) - if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil { - t.Fatalf("Failed to create dir: %v", err) - } - if err := os.WriteFile(fullPath, []byte(content), 0644); err != nil { - t.Fatalf("Failed to write file: %v", err) - } - } - - chunker := NewCodeChunker(tmpDir) - chunks, err := chunker.ChunkSourceCode() - if err != nil { - t.Fatalf("ChunkSourceCode failed: %v", err) - } - - if len(chunks) == 0 { - t.Fatal("Expected at least one chunk") - } - - // Verify all files were read - coverage := chunker.GetCoverage() - if len(coverage.FilesRead) != 4 { - t.Errorf("Expected 4 files read, got %d", len(coverage.FilesRead)) - } -} - -func TestCodeChunker_ChunkSourceCode_MultipleChunks(t *testing.T) { - tmpDir := t.TempDir() - - // Create a chunk config with very small limits to force multiple chunks - chunker := NewCodeChunker(tmpDir) - chunker.SetConfig(ChunkConfig{ - MaxTokensPerChunk: 100, // Very small to force chunking - MaxFilesPerChunk: 2, // Only 2 files per chunk - IncludeLineNumbers: true, - }) - - // Create 5 files - for i := 0; i < 5; i++ { - filename := filepath.Join(tmpDir, "file"+string(rune('a'+i))+".go") - content := "package main\n\nfunc Test" + string(rune('A'+i)) + "() {}\n" - if err := os.WriteFile(filename, []byte(content), 0644); err != nil { - t.Fatalf("Failed to write file: %v", err) - } - } - - chunks, err := chunker.ChunkSourceCode() - if err != nil { - t.Fatalf("ChunkSourceCode failed: %v", err) - } - - // With 5 files and max 2 per chunk, should have at least 3 chunks - if len(chunks) < 3 { - t.Errorf("Expected at least 3 chunks, got %d", len(chunks)) - } - - // Verify chunk indices are sequential - for i, chunk := range chunks { - if chunk.Index != i { - t.Errorf("Chunk %d has index %d", i, chunk.Index) - } - } -} - -func TestCodeChunker_SkipsTestFiles(t *testing.T) { - tmpDir := t.TempDir() - - files := map[string]string{ - "main.go": "package main\nfunc main() {}\n", - "main_test.go": "package main\nfunc TestMain() {}\n", - "util.spec.ts": "describe('util', () => {})", - "app.go": "package main\nfunc app() {}\n", - } - - for path, content := range files { - if err := os.WriteFile(filepath.Join(tmpDir, path), []byte(content), 0644); err != nil { - t.Fatalf("Failed to write file: %v", err) - } - } - - chunker := NewCodeChunker(tmpDir) - _, err := chunker.ChunkSourceCode() - if err != nil { - t.Fatalf("ChunkSourceCode failed: %v", err) - } - - coverage := chunker.GetCoverage() - - // Should read main.go and app.go, skip test files - if len(coverage.FilesRead) != 2 { - t.Errorf("Expected 2 files read, got %d", len(coverage.FilesRead)) - } - - // Test files should be in skipped - hasSkippedTest := false - for _, skip := range coverage.FilesSkipped { - if skip.Reason == "test file" { - hasSkippedTest = true - break - } - } - if !hasSkippedTest { - t.Error("Expected test files to be skipped") - } -} - -func TestCodeChunker_PriorityOrdering(t *testing.T) { - tmpDir := t.TempDir() - - // Create files with different priorities - // main.go should be higher priority than random.go - files := map[string]string{ - "zzz_random.go": "package main\nfunc random() {}\n", - "main.go": "package main\nfunc main() {}\n", - "internal/handler.go": "package internal\nfunc handler() {}\n", - } - - for path, content := range files { - fullPath := filepath.Join(tmpDir, path) - if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil { - t.Fatalf("Failed to create dir: %v", err) - } - if err := os.WriteFile(fullPath, []byte(content), 0644); err != nil { - t.Fatalf("Failed to write file: %v", err) - } - } - - chunker := NewCodeChunker(tmpDir) - chunks, err := chunker.ChunkSourceCode() - if err != nil { - t.Fatalf("ChunkSourceCode failed: %v", err) - } - - if len(chunks) == 0 || len(chunks[0].Files) == 0 { - t.Fatal("Expected at least one chunk with files") - } - - // First file should be main.go (priority 1) not zzz_random.go - firstFile := chunks[0].Files[0].RelPath - if firstFile != "main.go" { - t.Errorf("Expected main.go as first file (highest priority), got %s", firstFile) - } -} - -func TestCodeChunker_EmptyDirectory(t *testing.T) { - tmpDir := t.TempDir() - - chunker := NewCodeChunker(tmpDir) - _, err := chunker.ChunkSourceCode() - - if err == nil { - t.Error("Expected error for empty directory") - } -} diff --git a/internal/agents/tools/dedup_test.go b/internal/agents/tools/dedup_test.go deleted file mode 100644 index 6014d89..0000000 --- a/internal/agents/tools/dedup_test.go +++ /dev/null @@ -1,214 +0,0 @@ -package tools - -import ( - "testing" - - "github.com/josephgoksu/TaskWing/internal/agents/core" -) - -func TestNewFindingDeduplicator(t *testing.T) { - d := NewFindingDeduplicator() - if d == nil { - t.Fatal("NewFindingDeduplicator returned nil") - } - if d.similarityThreshold != 0.6 { - t.Errorf("Default threshold = %v, want 0.6", d.similarityThreshold) - } -} - -func TestFindingDeduplicator_SetSimilarityThreshold(t *testing.T) { - d := NewFindingDeduplicator() - - // Valid threshold - d.SetSimilarityThreshold(0.8) - if d.similarityThreshold != 0.8 { - t.Errorf("Threshold = %v, want 0.8", d.similarityThreshold) - } - - // Invalid thresholds should be ignored - d.SetSimilarityThreshold(0) - if d.similarityThreshold != 0.8 { - t.Error("Zero threshold should be ignored") - } - - d.SetSimilarityThreshold(1.5) - if d.similarityThreshold != 0.8 { - t.Error("Threshold > 1.0 should be ignored") - } -} - -func TestJaccardSimilarity(t *testing.T) { - tests := []struct { - a, b string - expected float64 - }{ - {"hello world", "hello world", 1.0}, - {"", "", 1.0}, - {"hello", "", 0.0}, - // After stop word removal: "hello world" vs "hello there" = {hello,world} vs {hello,there} - // Intersection=1 (hello), Union=3 (hello,world,there) -> 1/3 = 0.33 - {"hello world", "hello there", 0.33}, - // After stop word removal: "quick brown fox" vs "lazy brown dog" - // Intersection=1 (brown), Union=5 -> 1/5 = 0.2 - {"the quick brown fox", "the lazy brown dog", 0.2}, - } - - for _, tc := range tests { - got := jaccardSimilarity(tc.a, tc.b) - // Allow some floating point tolerance - if got < tc.expected-0.1 || got > tc.expected+0.1 { - t.Errorf("jaccardSimilarity(%q, %q) = %v, want ~%v", tc.a, tc.b, got, tc.expected) - } - } -} - -func TestTokenize(t *testing.T) { - tests := []struct { - input string - expected []string - }{ - {"hello world", []string{"hello", "world"}}, - {"hello-world", []string{"hello", "world"}}, - {"hello_world", []string{"hello", "world"}}, - {"the quick brown fox", []string{"quick", "brown", "fox"}}, // "the" is stop word - {"", nil}, - {"a b c", nil}, // All too short - } - - for _, tc := range tests { - got := tokenize(tc.input) - if len(got) != len(tc.expected) { - t.Errorf("tokenize(%q) = %v, want %v", tc.input, got, tc.expected) - } - } -} - -func TestDeduplicateFindings_Empty(t *testing.T) { - d := NewFindingDeduplicator() - result := d.DeduplicateFindings(nil) - if result != nil { - t.Error("Expected nil for empty input") - } -} - -func TestDeduplicateFindings_NoDuplicates(t *testing.T) { - d := NewFindingDeduplicator() - - findings := []core.Finding{ - {Type: core.FindingTypeDecision, Title: "Use PostgreSQL", Description: "Database choice"}, - {Type: core.FindingTypeDecision, Title: "Use Redis", Description: "Caching choice"}, - {Type: core.FindingTypePattern, Title: "Repository Pattern", Description: "Data access"}, - } - - result := d.DeduplicateFindings(findings) - if len(result) != 3 { - t.Errorf("Expected 3 findings, got %d", len(result)) - } -} - -func TestDeduplicateFindings_ExactDuplicates(t *testing.T) { - d := NewFindingDeduplicator() - - findings := []core.Finding{ - {Type: core.FindingTypeDecision, Title: "Use PostgreSQL", Description: "Database choice", ConfidenceScore: 0.9}, - {Type: core.FindingTypeDecision, Title: "Use PostgreSQL", Description: "Database choice", ConfidenceScore: 0.8}, - } - - result := d.DeduplicateFindings(findings) - if len(result) != 1 { - t.Errorf("Expected 1 finding after dedup, got %d", len(result)) - } - - // Should keep the higher confidence one - if result[0].ConfidenceScore != 0.9 { - t.Errorf("Expected to keep finding with higher confidence (0.9), got %v", result[0].ConfidenceScore) - } -} - -func TestDeduplicateFindings_SimilarTitles(t *testing.T) { - d := NewFindingDeduplicator() - d.SetSimilarityThreshold(0.5) // Lower threshold for this test - - // These titles share significant words after tokenization - findings := []core.Finding{ - {Type: core.FindingTypeDecision, Title: "PostgreSQL database storage decision", Description: "Primary storage choice"}, - {Type: core.FindingTypeDecision, Title: "PostgreSQL database storage choice", Description: "Data storage selection"}, - } - - result := d.DeduplicateFindings(findings) - if len(result) != 1 { - t.Errorf("Expected 1 finding after dedup of similar titles, got %d", len(result)) - } -} - -func TestDeduplicateFindings_DifferentTypes(t *testing.T) { - d := NewFindingDeduplicator() - - // Same title but different types should NOT be deduplicated - findings := []core.Finding{ - {Type: core.FindingTypeDecision, Title: "Repository Pattern", Description: "Decision to use"}, - {Type: core.FindingTypePattern, Title: "Repository Pattern", Description: "Pattern implementation"}, - } - - result := d.DeduplicateFindings(findings) - if len(result) != 2 { - t.Errorf("Different types should not be deduplicated, got %d findings", len(result)) - } -} - -func TestDeduplicateRelationships(t *testing.T) { - d := NewFindingDeduplicator() - - rels := []core.Relationship{ - {From: "A", To: "B", Relation: "depends_on"}, - {From: "A", To: "B", Relation: "depends_on"}, // duplicate - {From: "B", To: "C", Relation: "extends"}, - {From: "A", To: "B", Relation: "extends"}, // same from/to but different relation - } - - result := d.DeduplicateRelationships(rels) - if len(result) != 3 { - t.Errorf("Expected 3 unique relationships, got %d", len(result)) - } -} - -func TestDeduplicateRelationships_Empty(t *testing.T) { - d := NewFindingDeduplicator() - result := d.DeduplicateRelationships(nil) - if result != nil { - t.Error("Expected nil for empty input") - } -} - -func TestGetConfidence_Float(t *testing.T) { - d := NewFindingDeduplicator() - - finding := core.Finding{ConfidenceScore: 0.85} - got := d.getConfidence(finding) - if got != 0.85 { - t.Errorf("getConfidence = %v, want 0.85", got) - } -} - -func TestGetConfidence_String(t *testing.T) { - d := NewFindingDeduplicator() - - tests := []struct { - confidence string - expected float64 - }{ - {"high", 0.9}, - {"HIGH", 0.9}, - {"medium", 0.6}, - {"low", 0.3}, - {"unknown", 0.5}, - } - - for _, tc := range tests { - finding := core.Finding{Confidence: tc.confidence} - got := d.getConfidence(finding) - if got != tc.expected { - t.Errorf("getConfidence(%q) = %v, want %v", tc.confidence, got, tc.expected) - } - } -} diff --git a/internal/agents/verification/agent.go b/internal/agents/verification/agent.go index 3628e31..0798549 100644 --- a/internal/agents/verification/agent.go +++ b/internal/agents/verification/agent.go @@ -119,8 +119,9 @@ func (v *Agent) checkEvidence(evidence core.Evidence) core.EvidenceCheckResult { return result } - // Detect git evidence by explicit type or path prefix - if evidence.EvidenceType == "git" || strings.HasPrefix(evidence.FilePath, ".git") { + // Detect git evidence by explicit type, path prefix, or embedded .git path + // (multi-repo paths look like "serviceDir/.git/logs/HEAD") + if evidence.EvidenceType == "git" || strings.HasPrefix(evidence.FilePath, ".git") || strings.Contains(evidence.FilePath, "/.git/") { return v.verifyGitEvidence(evidence) } @@ -212,9 +213,17 @@ func (v *Agent) verifyGitEvidence(evidence core.Evidence) core.EvidenceCheckResu return result } + // Determine git working directory. + // For workspace-relative paths like "serviceDir/.git/logs/HEAD", + // extract the service directory and use it as the git root. + gitDir := v.basePath + if idx := strings.Index(evidence.FilePath, "/.git/"); idx > 0 { + gitDir = filepath.Join(v.basePath, evidence.FilePath[:idx]) + } + // Run git log to fetch recent commit history cmd := exec.Command("git", "log", "--all", "--oneline", "-500") - cmd.Dir = v.basePath + cmd.Dir = gitDir out, err := cmd.Output() if err != nil { result.ErrorMessage = "git log failed: " + err.Error() diff --git a/internal/app/recall.go b/internal/app/ask.go similarity index 93% rename from internal/app/recall.go rename to internal/app/ask.go index ed3352a..8483ce5 100644 --- a/internal/app/recall.go +++ b/internal/app/ask.go @@ -32,9 +32,9 @@ type SymbolResponse struct { Location string `json:"location"` // "file:line" for easy navigation } -// RecallResult contains the complete result of a knowledge search. +// AskResult contains the complete result of a knowledge search. // This is the canonical response type used by both CLI and MCP. -type RecallResult struct { +type AskResult struct { Query string `json:"query"` RewrittenQuery string `json:"rewritten_query,omitempty"` Pipeline string `json:"pipeline"` @@ -46,8 +46,8 @@ type RecallResult struct { Warning string `json:"warning,omitempty"` } -// RecallOptions configures the behavior of a recall query. -type RecallOptions struct { +// AskOptions configures the behavior of an ask query. +type AskOptions struct { Limit int // Maximum number of knowledge results (default: 5) SymbolLimit int // Maximum number of symbol results (default: 5) GenerateAnswer bool // Whether to generate a RAG answer @@ -62,9 +62,9 @@ type RecallOptions struct { IncludeRoot bool // When Workspace is set, also include 'root' workspace nodes (default: true) } -// DefaultRecallOptions returns sensible defaults for recall queries. -func DefaultRecallOptions() RecallOptions { - return RecallOptions{ +// DefaultAskOptions returns sensible defaults for ask queries. +func DefaultAskOptions() AskOptions { + return AskOptions{ Limit: 5, SymbolLimit: 5, GenerateAnswer: false, @@ -120,15 +120,15 @@ func ResolveWorkspace(explicitWorkspace string, autoDetect bool) (string, error) return "", nil // Empty means all workspaces } -// RecallApp provides knowledge retrieval operations. +// AskApp provides knowledge retrieval operations. // This is THE implementation - CLI and MCP both call these methods. -type RecallApp struct { +type AskApp struct { ctx *Context } -// NewRecallApp creates a new recall application service. -func NewRecallApp(ctx *Context) *RecallApp { - return &RecallApp{ctx: ctx} +// NewAskApp creates a new ask application service. +func NewAskApp(ctx *Context) *AskApp { + return &AskApp{ctx: ctx} } // Query performs semantic search with optional RAG answer generation. @@ -139,7 +139,7 @@ func NewRecallApp(ctx *Context) *RecallApp { // 4. Reranking (if enabled) // 5. Graph expansion (if enabled) // 6. Answer generation (if requested) -func (a *RecallApp) Query(ctx context.Context, query string, opts RecallOptions) (*RecallResult, error) { +func (a *AskApp) Query(ctx context.Context, query string, opts AskOptions) (*AskResult, error) { if opts.Limit <= 0 { opts.Limit = 5 } @@ -308,7 +308,7 @@ func (a *RecallApp) Query(ctx context.Context, query string, opts RecallOptions) } } - return &RecallResult{ + return &AskResult{ Query: query, RewrittenQuery: rewrittenQuery, Pipeline: pipeline, @@ -323,7 +323,7 @@ func (a *RecallApp) Query(ctx context.Context, query string, opts RecallOptions) // searchSymbols searches the code intelligence index for matching symbols. // It prioritizes public symbols over private ones. -func (a *RecallApp) searchSymbols(ctx context.Context, query string, limit int) []SymbolResponse { +func (a *AskApp) searchSymbols(ctx context.Context, query string, limit int) []SymbolResponse { // Get database handle from repository store := a.ctx.Repo.GetDB() if store == nil { @@ -379,7 +379,7 @@ func (a *RecallApp) searchSymbols(ctx context.Context, query string, limit int) // Summary returns a high-level overview of the project's knowledge base. // Use this when no query is provided. -func (a *RecallApp) Summary(ctx context.Context) (*knowledge.ProjectSummary, error) { +func (a *AskApp) Summary(ctx context.Context) (*knowledge.ProjectSummary, error) { ks := knowledge.NewService(a.ctx.Repo, a.ctx.LLMCfg) summary, err := ks.GetProjectSummary(ctx) if err != nil { @@ -390,7 +390,7 @@ func (a *RecallApp) Summary(ctx context.Context) (*knowledge.ProjectSummary, err // getRawSymbols retrieves raw codeintel.Symbol objects for source code fetching. // This is the core symbol retrieval - searchSymbols wraps it with response conversion. -func (a *RecallApp) getRawSymbols(ctx context.Context, query string, limit int) []codeintel.Symbol { +func (a *AskApp) getRawSymbols(ctx context.Context, query string, limit int) []codeintel.Symbol { store := a.ctx.Repo.GetDB() if store == nil { return nil @@ -411,7 +411,7 @@ func (a *RecallApp) getRawSymbols(ctx context.Context, query string, limit int) // generateRAGAnswer creates an answer using both knowledge nodes and code snippets. // This is the core of Code-Based RAG: answers are grounded in actual source code. // If streamWriter is provided, tokens are streamed as they arrive. -func (a *RecallApp) generateRAGAnswer(ctx context.Context, query string, nodes []knowledge.ScoredNode, snippets []CodeSnippet, streamWriter io.Writer) (string, error) { +func (a *AskApp) generateRAGAnswer(ctx context.Context, query string, nodes []knowledge.ScoredNode, snippets []CodeSnippet, streamWriter io.Writer) (string, error) { // Build context from both sources var contextParts []string diff --git a/internal/app/codeintel.go b/internal/app/codeintel.go index db7072e..1dcda6f 100644 --- a/internal/app/codeintel.go +++ b/internal/app/codeintel.go @@ -9,7 +9,7 @@ import ( ) // CodeIntelApp provides code intelligence operations through the app layer. -// This follows the same pattern as RecallApp, TaskApp, etc. +// This follows the same pattern as AskApp, TaskApp, etc. type CodeIntelApp struct { ctx *Context } diff --git a/internal/app/plan.go b/internal/app/plan.go index 332c952..d6825d1 100644 --- a/internal/app/plan.go +++ b/internal/app/plan.go @@ -112,7 +112,7 @@ type TaskPlanner interface { Close() error } -// TaskContextEnricher executes recall queries and returns aggregated context for a task. +// TaskContextEnricher executes ask queries and returns aggregated context for a task. // This is used during task creation to populate ContextSummary (early binding). // See docs/architecture/ADR_CONTEXT_BINDING.md for the full context binding design. type TaskContextEnricher func(ctx context.Context, queries []string) (string, error) @@ -130,7 +130,7 @@ type PlanApp struct { ClarifierFactory func(llm.Config) GoalsClarifier PlannerFactory func(llm.Config) TaskPlanner ContextRetriever func(ctx context.Context, ks *knowledge.Service, goal, memoryPath string) (impl.SearchStrategyResult, error) - // TaskEnricher executes recall queries to populate task ContextSummary. + // TaskEnricher executes ask queries to populate task ContextSummary. // If nil, tasks will not have embedded context (legacy behavior). TaskEnricher TaskContextEnricher } @@ -148,23 +148,23 @@ func NewPlanApp(ctx *Context) *PlanApp { }, ContextRetriever: impl.RetrieveContext, } - // Initialize default TaskEnricher using RecallApp + // Initialize default TaskEnricher using AskApp pa.TaskEnricher = pa.defaultTaskEnricher return pa } -// defaultTaskEnricher executes all recall queries and aggregates results into a context summary. +// defaultTaskEnricher executes all ask queries and aggregates results into a context summary. // This is the production implementation; tests can override TaskEnricher for mocking. func (a *PlanApp) defaultTaskEnricher(ctx context.Context, queries []string) (string, error) { if len(queries) == 0 { return "", nil } - recallApp := NewRecallApp(a.ctx) + askApp := NewAskApp(a.ctx) var contextParts []string for _, query := range queries { - result, err := recallApp.Query(ctx, query, RecallOptions{ + result, err := askApp.Query(ctx, query, AskOptions{ Limit: 3, // 3 results per query GenerateAnswer: false, IncludeSymbols: false, // Keep context focused on knowledge, not symbols @@ -1110,9 +1110,9 @@ func (a *PlanApp) parseTasksFromMetadata(ctx context.Context, metadata map[strin } t.EnrichAIFields() - // Populate ContextSummary by executing recall queries - if a.TaskEnricher != nil && len(t.SuggestedRecallQueries) > 0 { - if contextSummary, err := a.TaskEnricher(ctx, t.SuggestedRecallQueries); err == nil && contextSummary != "" { + // Populate ContextSummary by executing ask queries + if a.TaskEnricher != nil && len(t.SuggestedAskQueries) > 0 { + if contextSummary, err := a.TaskEnricher(ctx, t.SuggestedAskQueries); err == nil && contextSummary != "" { t.ContextSummary = contextSummary } } @@ -1197,9 +1197,9 @@ func (a *PlanApp) parseTasksFromMetadata(ctx context.Context, metadata map[strin } newTask.EnrichAIFields() - // Populate ContextSummary by executing recall queries - if a.TaskEnricher != nil && len(newTask.SuggestedRecallQueries) > 0 { - if contextSummary, err := a.TaskEnricher(ctx, newTask.SuggestedRecallQueries); err == nil && contextSummary != "" { + // Populate ContextSummary by executing ask queries + if a.TaskEnricher != nil && len(newTask.SuggestedAskQueries) > 0 { + if contextSummary, err := a.TaskEnricher(ctx, newTask.SuggestedAskQueries); err == nil && contextSummary != "" { newTask.ContextSummary = contextSummary } } diff --git a/internal/app/plan_enrichment_test.go b/internal/app/plan_enrichment_test.go deleted file mode 100644 index ea28453..0000000 --- a/internal/app/plan_enrichment_test.go +++ /dev/null @@ -1,420 +0,0 @@ -package app - -import ( - "context" - "strings" - "testing" - - "github.com/josephgoksu/TaskWing/internal/agents/core" - "github.com/josephgoksu/TaskWing/internal/agents/impl" - "github.com/josephgoksu/TaskWing/internal/knowledge" - "github.com/josephgoksu/TaskWing/internal/llm" - "github.com/josephgoksu/TaskWing/internal/task" -) - -// TestPlanEnrichment_ContextSummaryPopulated verifies that tasks have ContextSummary -// populated when TaskEnricher is configured and queries are generated. -func TestPlanEnrichment_ContextSummaryPopulated(t *testing.T) { - // Track which queries were executed - queriesExecuted := []string{} - - // Mock repo that captures created tasks - createdTasks := []*task.Task{} - mockRepo := &MockRepository{ - CreatePlanFunc: func(p *task.Plan) error { - return nil - }, - CreateTaskFunc: func(tsk *task.Task) error { - createdTasks = append(createdTasks, tsk) - return nil - }, - SetActivePlanFunc: func(id string) error { - return nil - }, - } - - appCtx := &Context{ - Repo: nil, - LLMCfg: llm.Config{}, - } - - app := NewPlanApp(appCtx) - app.Repo = mockRepo - - // Mock context retriever - app.ContextRetriever = func(ctx context.Context, ks *knowledge.Service, goal, memoryPath string) (impl.SearchStrategyResult, error) { - return impl.SearchStrategyResult{ - Context: "Mock Architecture Context", - Strategy: "Mock Strategy", - }, nil - } - - // Mock TaskEnricher that tracks executed queries and returns context - app.TaskEnricher = func(ctx context.Context, queries []string) (string, error) { - queriesExecuted = append(queriesExecuted, queries...) - if len(queries) == 0 { - return "", nil - } - // Return enriched context based on queries - return "## Relevant Architecture Context\n- **Test Pattern** (pattern): Use dependency injection for testability\n- **SQLite Constraint** (constraint): SQLite is the single source of truth", nil - } - - // Mock clarifier that returns ready-to-plan - app.ClarifierFactory = func(cfg llm.Config) GoalsClarifier { - return &MockClarifier{ - RunFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - return core.Output{ - Findings: []core.Finding{ - { - Type: "clarification", - Metadata: map[string]interface{}{ - "is_ready_to_plan": true, - "enriched_goal": "Implement user authentication with JWT tokens", - "goal_summary": "Auth Implementation", - "questions": []string{}, - }, - }, - }, - }, nil - }, - } - } - - // Mock planner that returns tasks with keywords (which will generate SuggestedRecallQueries) - app.PlannerFactory = func(cfg llm.Config) TaskPlanner { - return &MockPlanner{ - RunFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - tasks := []impl.PlanningTask{ - { - Title: "Design authentication schema", - Description: "Create database schema for user auth", - Priority: 100, - Keywords: []string{"auth", "database", "schema"}, - Scope: "api", - }, - { - Title: "Implement JWT middleware", - Description: "Create middleware for JWT validation", - Priority: 90, - Keywords: []string{"jwt", "middleware", "security"}, - Scope: "api", - }, - } - return core.Output{ - Findings: []core.Finding{ - { - Type: "plan", - Metadata: map[string]interface{}{ - "tasks": tasks, - }, - }, - }, - }, nil - }, - } - } - - // Execute clarify - clarifyRes, err := app.Clarify(context.Background(), ClarifyOptions{Goal: "implement user auth"}) - if err != nil { - t.Fatalf("Clarify failed: %v", err) - } - if !clarifyRes.IsReadyToPlan { - t.Fatal("Expected IsReadyToPlan to be true") - } - - // Execute generate with Save=true - genRes, err := app.Generate(context.Background(), GenerateOptions{ - Goal: "implement user auth", - ClarifySessionID: clarifyRes.ClarifySessionID, - EnrichedGoal: clarifyRes.EnrichedGoal, - Save: true, - }) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } - - // Assertions - if !genRes.Success { - t.Errorf("Expected Generate success, got failure: %s", genRes.Message) - } - if len(genRes.Tasks) != 2 { - t.Errorf("Expected 2 tasks, got %d", len(genRes.Tasks)) - } - - // Verify TaskEnricher was called (queries were executed) - if len(queriesExecuted) == 0 { - t.Error("TaskEnricher was never called - no recall queries were executed") - } - - // Verify ContextSummary is populated on tasks - for i, tsk := range genRes.Tasks { - if tsk.ContextSummary == "" { - t.Errorf("Task %d (%s) has empty ContextSummary", i, tsk.Title) - } - if !strings.Contains(tsk.ContextSummary, "Relevant Architecture Context") { - t.Errorf("Task %d ContextSummary doesn't contain expected content: %s", i, tsk.ContextSummary) - } - } -} - -// TestPlanEnrichment_MultipleQueriesAggregated verifies that multiple recall queries -// from different tasks are properly executed. -func TestPlanEnrichment_MultipleQueriesAggregated(t *testing.T) { - // Track queries per task - taskQueryCounts := make(map[int]int) - callCount := 0 - - mockRepo := &MockRepository{ - CreatePlanFunc: func(p *task.Plan) error { return nil }, - CreateTaskFunc: func(tsk *task.Task) error { return nil }, - SetActivePlanFunc: func(id string) error { return nil }, - } - - appCtx := &Context{ - Repo: nil, - LLMCfg: llm.Config{}, - } - - app := NewPlanApp(appCtx) - app.Repo = mockRepo - - app.ContextRetriever = func(ctx context.Context, ks *knowledge.Service, goal, memoryPath string) (impl.SearchStrategyResult, error) { - return impl.SearchStrategyResult{Context: "ctx", Strategy: "s"}, nil - } - - // Track each call to TaskEnricher - app.TaskEnricher = func(ctx context.Context, queries []string) (string, error) { - taskQueryCounts[callCount] = len(queries) - callCount++ - return "## Context\n- Mock result", nil - } - - app.ClarifierFactory = func(cfg llm.Config) GoalsClarifier { - return &MockClarifier{ - RunFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - return core.Output{ - Findings: []core.Finding{{ - Type: "clarification", - Metadata: map[string]interface{}{ - "is_ready_to_plan": true, - "enriched_goal": "Test goal", - "goal_summary": "Test", - "questions": []string{}, - }, - }}, - }, nil - }, - } - } - - // Tasks with varying numbers of keywords (which generate queries) - app.PlannerFactory = func(cfg llm.Config) TaskPlanner { - return &MockPlanner{ - RunFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - tasks := []impl.PlanningTask{ - {Title: "Task 1", Description: "Task 1 desc", Priority: 100, Keywords: []string{"kw1", "kw2", "kw3"}, Scope: "api"}, - {Title: "Task 2", Description: "Task 2 desc", Priority: 90, Keywords: []string{"kw4"}, Scope: "cli"}, - {Title: "Task 3", Description: "Task 3 desc", Priority: 80, Keywords: []string{}, Scope: "test"}, // No keywords - } - return core.Output{ - Findings: []core.Finding{{ - Type: "plan", - Metadata: map[string]interface{}{"tasks": tasks}, - }}, - }, nil - }, - } - } - - clarifyRes, _ := app.Clarify(context.Background(), ClarifyOptions{Goal: "test"}) - _, err := app.Generate(context.Background(), GenerateOptions{ - Goal: "test", - ClarifySessionID: clarifyRes.ClarifySessionID, - EnrichedGoal: clarifyRes.EnrichedGoal, - Save: true, - }) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } - - // Verify TaskEnricher was called for tasks with queries - // Task 1 and Task 2 have keywords, Task 3 does not - if callCount < 2 { - t.Errorf("Expected TaskEnricher to be called at least 2 times, got %d", callCount) - } -} - -// TestPlanEnrichment_NilEnricherSkipsEnrichment verifies backward compatibility -// when TaskEnricher is nil (legacy behavior). -func TestPlanEnrichment_NilEnricherSkipsEnrichment(t *testing.T) { - createdTasks := []*task.Task{} - mockRepo := &MockRepository{ - CreatePlanFunc: func(p *task.Plan) error { return nil }, - CreateTaskFunc: func(tsk *task.Task) error { - createdTasks = append(createdTasks, tsk) - return nil - }, - SetActivePlanFunc: func(id string) error { return nil }, - } - - appCtx := &Context{ - Repo: nil, - LLMCfg: llm.Config{}, - } - - app := NewPlanApp(appCtx) - app.Repo = mockRepo - // Explicitly set TaskEnricher to nil (simulating legacy behavior) - app.TaskEnricher = nil - - app.ContextRetriever = func(ctx context.Context, ks *knowledge.Service, goal, memoryPath string) (impl.SearchStrategyResult, error) { - return impl.SearchStrategyResult{Context: "ctx", Strategy: "s"}, nil - } - - app.ClarifierFactory = func(cfg llm.Config) GoalsClarifier { - return &MockClarifier{ - RunFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - return core.Output{ - Findings: []core.Finding{{ - Type: "clarification", - Metadata: map[string]interface{}{ - "is_ready_to_plan": true, - "enriched_goal": "Test", - "goal_summary": "Test", - "questions": []string{}, - }, - }}, - }, nil - }, - } - } - - app.PlannerFactory = func(cfg llm.Config) TaskPlanner { - return &MockPlanner{ - RunFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - tasks := []impl.PlanningTask{ - {Title: "Task 1", Description: "Task 1 description", Priority: 100, Keywords: []string{"kw1", "kw2"}, Scope: "api"}, - } - return core.Output{ - Findings: []core.Finding{{ - Type: "plan", - Metadata: map[string]interface{}{"tasks": tasks}, - }}, - }, nil - }, - } - } - - clarifyRes, _ := app.Clarify(context.Background(), ClarifyOptions{Goal: "test"}) - genRes, err := app.Generate(context.Background(), GenerateOptions{ - Goal: "test", - ClarifySessionID: clarifyRes.ClarifySessionID, - EnrichedGoal: clarifyRes.EnrichedGoal, - Save: true, - }) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } - - // Should still generate tasks successfully - if !genRes.Success { - t.Fatalf("Expected success even with nil TaskEnricher, got failure: %s", genRes.Message) - } - if len(genRes.Tasks) != 1 { - t.Fatalf("Expected 1 task, got %d", len(genRes.Tasks)) - } - - // ContextSummary should be empty (legacy behavior) - if genRes.Tasks[0].ContextSummary != "" { - t.Errorf("Expected empty ContextSummary with nil TaskEnricher, got: %s", genRes.Tasks[0].ContextSummary) - } - - // But SuggestedRecallQueries should still be generated (by EnrichAIFields) - if len(genRes.Tasks[0].SuggestedRecallQueries) == 0 { - t.Error("Expected SuggestedRecallQueries to be generated even without TaskEnricher") - } -} - -// TestPlanEnrichment_ContentLengthHandling verifies that very long context -// summaries are handled correctly (not testing truncation here as that's -// done in presentation layer, but ensuring no errors with long content). -func TestPlanEnrichment_ContentLengthHandling(t *testing.T) { - mockRepo := &MockRepository{ - CreatePlanFunc: func(p *task.Plan) error { return nil }, - CreateTaskFunc: func(tsk *task.Task) error { return nil }, - SetActivePlanFunc: func(id string) error { return nil }, - } - - appCtx := &Context{ - Repo: nil, - LLMCfg: llm.Config{}, - } - - app := NewPlanApp(appCtx) - app.Repo = mockRepo - - app.ContextRetriever = func(ctx context.Context, ks *knowledge.Service, goal, memoryPath string) (impl.SearchStrategyResult, error) { - return impl.SearchStrategyResult{Context: "ctx", Strategy: "s"}, nil - } - - // Return very long context - longContent := strings.Repeat("This is a very long piece of content that tests our handling of large context summaries. ", 50) - app.TaskEnricher = func(ctx context.Context, queries []string) (string, error) { - return "## Context\n" + longContent, nil - } - - app.ClarifierFactory = func(cfg llm.Config) GoalsClarifier { - return &MockClarifier{ - RunFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - return core.Output{ - Findings: []core.Finding{{ - Type: "clarification", - Metadata: map[string]interface{}{ - "is_ready_to_plan": true, - "enriched_goal": "Test", - "goal_summary": "Test", - "questions": []string{}, - }, - }}, - }, nil - }, - } - } - - app.PlannerFactory = func(cfg llm.Config) TaskPlanner { - return &MockPlanner{ - RunFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - tasks := []impl.PlanningTask{ - {Title: "Task 1", Description: "Task 1 description", Priority: 100, Keywords: []string{"test"}, Scope: "api"}, - } - return core.Output{ - Findings: []core.Finding{{ - Type: "plan", - Metadata: map[string]interface{}{"tasks": tasks}, - }}, - }, nil - }, - } - } - - clarifyRes, _ := app.Clarify(context.Background(), ClarifyOptions{Goal: "test"}) - genRes, err := app.Generate(context.Background(), GenerateOptions{ - Goal: "test", - ClarifySessionID: clarifyRes.ClarifySessionID, - EnrichedGoal: clarifyRes.EnrichedGoal, - Save: true, - }) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } - - if !genRes.Success { - t.Fatalf("Expected success with long content, got: %s", genRes.Message) - } - - // Verify full content is stored (truncation happens at presentation layer) - if len(genRes.Tasks[0].ContextSummary) < len(longContent) { - t.Error("Expected full context to be stored, but got truncated content") - } -} diff --git a/internal/app/plan_integration_test.go b/internal/app/plan_integration_test.go deleted file mode 100644 index 24a972c..0000000 --- a/internal/app/plan_integration_test.go +++ /dev/null @@ -1,349 +0,0 @@ -package app - -import ( - "context" - "path/filepath" - "strings" - "testing" - - "github.com/josephgoksu/TaskWing/internal/agents/core" - "github.com/josephgoksu/TaskWing/internal/agents/impl" - "github.com/josephgoksu/TaskWing/internal/knowledge" - "github.com/josephgoksu/TaskWing/internal/llm" - "github.com/josephgoksu/TaskWing/internal/memory" - "github.com/josephgoksu/TaskWing/internal/task" -) - -// MockClarifier -type MockClarifier struct { - RunFunc func(ctx context.Context, input core.Input) (core.Output, error) - AutoAnswerFunc func(ctx context.Context, spec string, q []string, kg string) (string, error) -} - -func (m *MockClarifier) Run(ctx context.Context, input core.Input) (core.Output, error) { - if m.RunFunc != nil { - return m.RunFunc(ctx, input) - } - return core.Output{}, nil -} -func (m *MockClarifier) AutoAnswer(ctx context.Context, spec string, q []string, kg string) (string, error) { - if m.AutoAnswerFunc != nil { - return m.AutoAnswerFunc(ctx, spec, q, kg) - } - return "", nil -} -func (m *MockClarifier) Close() error { return nil } - -// MockPlanner -type MockPlanner struct { - RunFunc func(ctx context.Context, input core.Input) (core.Output, error) -} - -func (m *MockPlanner) Run(ctx context.Context, input core.Input) (core.Output, error) { - if m.RunFunc != nil { - return m.RunFunc(ctx, input) - } - return core.Output{}, nil -} -func (m *MockPlanner) Close() error { return nil } - -func TestPlanApp_TUIFlow(t *testing.T) { - // Invocation counters - createPlanCalled := false - setActivePlanCalled := false - clarifierCalled := false - plannerCalled := false - - // 1. Setup Mock Repo - mockRepo := &MockRepository{ - CreatePlanFunc: func(p *task.Plan) error { - createPlanCalled = true - if p.Status != "active" { - t.Errorf("expected plan status active, got %s", p.Status) - } - return nil - }, - CreateTaskFunc: func(tsk *task.Task) error { - if tsk.Title == "" { - t.Error("created task has no title") - } - return nil - }, - SetActivePlanFunc: func(id string) error { - setActivePlanCalled = true - return nil - }, - } - - appCtx := &Context{ - Repo: nil, // concrete repo not needed if we override - LLMCfg: llm.Config{}, - } - - app := NewPlanApp(appCtx) - // Inject dependencies - app.Repo = mockRepo - - // Mock Context Retrieval - app.ContextRetriever = func(ctx context.Context, ks *knowledge.Service, goal, memoryPath string) (impl.SearchStrategyResult, error) { - return impl.SearchStrategyResult{ - Context: "Mock Architecture Context", - Strategy: "Mock Strategy", - }, nil - } - - // Mock TaskEnricher to avoid calling real RecallApp - app.TaskEnricher = func(ctx context.Context, queries []string) (string, error) { - return "## Mock Context\n- Test decision: Use mock pattern", nil - } - - // Mock Clarifier - app.ClarifierFactory = func(cfg llm.Config) GoalsClarifier { - return &MockClarifier{ - RunFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - clarifierCalled = true - return core.Output{ - Findings: []core.Finding{ - { - Type: "clarification", - Metadata: map[string]interface{}{ - "is_ready_to_plan": true, - "enriched_goal": "Build a flux capacitor", - "goal_summary": "Flux Capacitor", - "questions": []string{}, - }, - }, - }, - }, nil - }, - } - } - - // Mock Planner - app.PlannerFactory = func(cfg llm.Config) TaskPlanner { - return &MockPlanner{ - RunFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - plannerCalled = true - tasks := []impl.PlanningTask{ - {Title: "Task 1", Description: "Desc 1", Priority: 1, AssignedAgent: "engineer"}, - {Title: "Task 2", Description: "Desc 2", Priority: 2, AssignedAgent: "qa"}, - } - return core.Output{ - Findings: []core.Finding{ - { - Type: "plan", - Metadata: map[string]interface{}{ - "tasks": tasks, - }, - }, - }, - }, nil - }, - } - } - - // 2. Execute Clarify - clarifyRes, err := app.Clarify(context.Background(), ClarifyOptions{Goal: "build time machine"}) - if err != nil { - t.Fatalf("Clarify failed: %v", err) - } - if !clarifyRes.IsReadyToPlan { - t.Error("Expected IsReadyToPlan to be true") - } - if clarifyRes.EnrichedGoal != "Build a flux capacitor" { - t.Errorf("Expected enriched goal 'Build a flux capacitor', got '%s'", clarifyRes.EnrichedGoal) - } - - // 3. Execute Generate - genRes, err := app.Generate(context.Background(), GenerateOptions{ - Goal: "build time machine", - ClarifySessionID: clarifyRes.ClarifySessionID, - EnrichedGoal: clarifyRes.EnrichedGoal, - Save: true, - }) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } - if !genRes.Success { - t.Errorf("Expected Generate success, got failure: %s", genRes.Message) - } - if len(genRes.Tasks) != 2 { - t.Errorf("Expected 2 tasks, got %d", len(genRes.Tasks)) - } - - // 4. Verify all mocks were called - if !clarifierCalled { - t.Error("ClarifierFactory agent was never called") - } - if !plannerCalled { - t.Error("PlannerFactory agent was never called") - } - if !createPlanCalled { - t.Error("CreatePlan was never called - plan was not saved") - } - if !setActivePlanCalled { - t.Error("SetActivePlan was never called - plan was not activated") - } -} - -func TestPlanApp_ClarifySessionizedRoundsPersistTurns(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "memory.db") - repo, err := memory.NewDefaultRepository(dbPath) - if err != nil { - t.Fatalf("create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - appCtx := &Context{ - Repo: repo, - LLMCfg: llm.Config{}, - } - planApp := NewPlanApp(appCtx) - - runCount := 0 - planApp.ClarifierFactory = func(cfg llm.Config) GoalsClarifier { - return &MockClarifier{ - RunFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - runCount++ - if runCount == 2 { - history, _ := input.ExistingContext["history"].(string) - if !strings.Contains(history, "Q1?") || !strings.Contains(history, "Use streaming") { - t.Fatalf("continuation clarify call missing persisted Q/A context: %q", history) - } - } - - if runCount == 1 { - return core.Output{ - Findings: []core.Finding{{ - Type: "clarification", - Metadata: map[string]any{ - "is_ready_to_plan": false, - "enriched_goal": "Draft spec v1", - "goal_summary": "Draft summary", - "questions": []string{"Q1?"}, - }, - }}, - }, nil - } - - return core.Output{ - Findings: []core.Finding{{ - Type: "clarification", - Metadata: map[string]any{ - "is_ready_to_plan": true, - "enriched_goal": "Final enriched spec", - "goal_summary": "Final summary", - "questions": []string{}, - }, - }}, - }, nil - }, - } - } - - first, err := planApp.Clarify(context.Background(), ClarifyOptions{ - Goal: "Ship onboarding revamp", - }) - if err != nil { - t.Fatalf("first clarify failed: %v", err) - } - if first.ClarifySessionID == "" { - t.Fatal("expected clarify_session_id on first clarify round") - } - if first.IsReadyToPlan { - t.Fatal("expected first clarify round to be unresolved") - } - if first.RoundIndex != 1 { - t.Fatalf("expected first round index 1, got %d", first.RoundIndex) - } - - second, err := planApp.Clarify(context.Background(), ClarifyOptions{ - Goal: "Ship onboarding revamp", - ClarifySessionID: first.ClarifySessionID, - Answers: []ClarifyAnswer{ - {Question: "Q1?", Answer: "Use streaming"}, - }, - }) - if err != nil { - t.Fatalf("second clarify failed: %v", err) - } - if second.ClarifySessionID != first.ClarifySessionID { - t.Fatalf("expected stable session id %q, got %q", first.ClarifySessionID, second.ClarifySessionID) - } - if !second.IsReadyToPlan { - t.Fatal("expected second clarify round to be ready_to_plan") - } - if second.RoundIndex != 2 { - t.Fatalf("expected second round index 2, got %d", second.RoundIndex) - } - - session, err := repo.GetClarifySession(first.ClarifySessionID) - if err != nil { - t.Fatalf("load clarify session: %v", err) - } - if !session.IsReadyToPlan { - t.Fatal("expected persisted session to be ready_to_plan") - } - if session.RoundIndex != 2 { - t.Fatalf("expected persisted round index 2, got %d", session.RoundIndex) - } - - turns, err := repo.ListClarifyTurns(first.ClarifySessionID) - if err != nil { - t.Fatalf("list clarify turns: %v", err) - } - if len(turns) != 2 { - t.Fatalf("expected 2 persisted turns, got %d", len(turns)) - } - if len(turns[0].Questions) != 1 || turns[0].Questions[0] != "Q1?" { - t.Fatalf("unexpected round 1 questions: %+v", turns[0].Questions) - } - if len(turns[1].Answers) != 1 || turns[1].Answers[0] != "Use streaming" { - t.Fatalf("unexpected round 2 answers: %+v", turns[1].Answers) - } -} - -func TestPlanApp_GenerateBlockedUntilClarifyReady(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "memory.db") - repo, err := memory.NewDefaultRepository(dbPath) - if err != nil { - t.Fatalf("create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - session := &task.ClarifySession{ - ID: "clarify-awaiting", - Goal: "Improve planner", - EnrichedGoal: "Draft goal", - State: task.ClarifySessionStateAwaitingAnswers, - RoundIndex: 1, - MaxRounds: 5, - MaxQuestionsPerRound: 3, - CurrentQuestions: []string{"Q1?"}, - IsReadyToPlan: false, - } - if err := repo.CreateClarifySession(session); err != nil { - t.Fatalf("create clarify session: %v", err) - } - - planApp := NewPlanApp(&Context{ - Repo: repo, - LLMCfg: llm.Config{}, - }) - - result, err := planApp.Generate(context.Background(), GenerateOptions{ - Goal: "Improve planner", - ClarifySessionID: session.ID, - EnrichedGoal: "Draft goal", - Save: false, - }) - if err != nil { - t.Fatalf("generate returned unexpected error: %v", err) - } - if result.Success { - t.Fatal("expected generate to be blocked while clarify is unresolved") - } - if !strings.Contains(strings.ToLower(result.Message), "clarification is not complete") { - t.Fatalf("expected clarification gate message, got %q", result.Message) - } -} diff --git a/internal/app/plan_test.go b/internal/app/plan_test.go deleted file mode 100644 index 9c62ccf..0000000 --- a/internal/app/plan_test.go +++ /dev/null @@ -1,347 +0,0 @@ -package app - -import ( - "context" - "os/exec" - "strings" - "testing" - - "github.com/josephgoksu/TaskWing/internal/agents/core" - "github.com/josephgoksu/TaskWing/internal/agents/impl" - "github.com/josephgoksu/TaskWing/internal/knowledge" - "github.com/josephgoksu/TaskWing/internal/llm" - "github.com/josephgoksu/TaskWing/internal/task" -) - -// MockRepository implements task.Repository for testing -type MockRepository struct { - CreatePlanFunc func(p *task.Plan) error - SetActivePlanFunc func(id string) error - GetActivePlanFunc func() (*task.Plan, error) - GetPlanFunc func(id string) (*task.Plan, error) - ListPlansFunc func() ([]task.Plan, error) - ListTasksFunc func(planID string) ([]task.Task, error) - CreateTaskFunc func(t *task.Task) error - GetTaskFunc func(id string) (*task.Task, error) - UpdateTaskStatusFunc func(id string, status task.TaskStatus) error - UpdatePlanFunc func(id, goal, enrichedGoal string, status task.PlanStatus) error - DeletePlanFunc func(id string) error - SearchPlansFunc func(query string, status task.PlanStatus) ([]task.Plan, error) - UpdatePlanAuditReportFunc func(id string, status task.PlanStatus, auditReportJSON string) error -} - -func (m *MockRepository) CreatePlan(p *task.Plan) error { - if m.CreatePlanFunc != nil { - return m.CreatePlanFunc(p) - } - return nil -} -func (m *MockRepository) SetActivePlan(id string) error { - if m.SetActivePlanFunc != nil { - return m.SetActivePlanFunc(id) - } - return nil -} -func (m *MockRepository) GetActivePlan() (*task.Plan, error) { - if m.GetActivePlanFunc != nil { - return m.GetActivePlanFunc() - } - return nil, nil -} -func (m *MockRepository) GetPlan(id string) (*task.Plan, error) { - if m.GetPlanFunc != nil { - return m.GetPlanFunc(id) - } - return nil, nil -} -func (m *MockRepository) ListPlans() ([]task.Plan, error) { - if m.ListPlansFunc != nil { - return m.ListPlansFunc() - } - return nil, nil -} -func (m *MockRepository) SearchPlans(query string, status task.PlanStatus) ([]task.Plan, error) { - if m.SearchPlansFunc != nil { - return m.SearchPlansFunc(query, status) - } - return nil, nil -} -func (m *MockRepository) UpdatePlanAuditReport(id string, status task.PlanStatus, auditReportJSON string) error { - if m.UpdatePlanAuditReportFunc != nil { - return m.UpdatePlanAuditReportFunc(id, status, auditReportJSON) - } - return nil -} - -func (m *MockRepository) ListTasks(planID string) ([]task.Task, error) { - if m.ListTasksFunc != nil { - return m.ListTasksFunc(planID) - } - return nil, nil -} -func (m *MockRepository) CreateTask(t *task.Task) error { - if m.CreateTaskFunc != nil { - return m.CreateTaskFunc(t) - } - return nil -} -func (m *MockRepository) GetTask(id string) (*task.Task, error) { - if m.GetTaskFunc != nil { - return m.GetTaskFunc(id) - } - return nil, nil -} -func (m *MockRepository) UpdateTaskStatus(id string, status task.TaskStatus) error { - if m.UpdateTaskStatusFunc != nil { - return m.UpdateTaskStatusFunc(id, status) - } - return nil -} -func (m *MockRepository) UpdatePlan(id, goal, enrichedGoal string, status task.PlanStatus) error { - if m.UpdatePlanFunc != nil { - return m.UpdatePlanFunc(id, goal, enrichedGoal, status) - } - return nil -} -func (m *MockRepository) DeletePlan(id string) error { - if m.DeletePlanFunc != nil { - return m.DeletePlanFunc(id) - } - return nil -} - -func (m *MockRepository) AddDependency(taskID, dependsOn string) error { - return nil -} - -func (m *MockRepository) RemoveDependency(taskID, dependsOn string) error { - return nil -} - -// Phase repository methods (for interactive plan generation) -func (m *MockRepository) CreatePhase(p *task.Phase) error { - return nil -} - -func (m *MockRepository) GetPhase(id string) (*task.Phase, error) { - return nil, nil -} - -func (m *MockRepository) ListPhases(planID string) ([]task.Phase, error) { - return nil, nil -} - -func (m *MockRepository) UpdatePhase(p *task.Phase) error { - return nil -} - -func (m *MockRepository) UpdatePhaseStatus(id string, status task.PhaseStatus) error { - return nil -} - -func (m *MockRepository) DeletePhase(id string) error { - return nil -} - -func (m *MockRepository) CreatePhasesForPlan(planID string, phases []task.Phase) error { - return nil -} - -func (m *MockRepository) ListTasksByPhase(phaseID string) ([]task.Task, error) { - return nil, nil -} - -func (m *MockRepository) GetPlanWithPhases(id string) (*task.Plan, error) { - return nil, nil -} - -func (m *MockRepository) UpdatePlanDraftState(planID string, draftStateJSON string) error { - return nil -} - -func (m *MockRepository) UpdatePlanGenerationMode(planID string, mode task.GenerationMode) error { - return nil -} - -func TestPlanApp_Generate_Failures(t *testing.T) { - // Placeholder test -} - -func TestPlanApp_Generate_SemanticValidation(t *testing.T) { - // Test that semantic validation catches invalid file paths and shell commands - t.Run("warns on missing file paths", func(t *testing.T) { - // Create mock repo - mockRepo := &MockRepository{ - CreatePlanFunc: func(p *task.Plan) error { - p.ID = "test-plan-1" - return nil - }, - SetActivePlanFunc: func(id string) error { - return nil - }, - } - - // Create app with mocked planner that returns tasks with file references - planApp := &PlanApp{ - ctx: &Context{}, // Empty context - semantic validation doesn't need it - Repo: mockRepo, - PlannerFactory: func(cfg llm.Config) TaskPlanner { - return &mockTaskPlanner{ - tasks: []impl.PlanningTask{ - { - Title: "Task with missing file", - Description: "Modify missing/path/to/file.go to add feature", - Priority: 50, - Complexity: "medium", - AssignedAgent: "coder", - AcceptanceCriteria: []string{"File is modified"}, - ValidationSteps: []string{"go test ./..."}, - }, - }, - } - }, - ContextRetriever: func(ctx context.Context, ks *knowledge.Service, goal, memoryPath string) (impl.SearchStrategyResult, error) { - return impl.SearchStrategyResult{}, nil - }, - } - - result, err := planApp.Generate(context.Background(), GenerateOptions{ - Goal: "Test goal", - ClarifySessionID: "clarify-ephemeral", - EnrichedGoal: "Test enriched goal", - Save: true, - }) - - if err != nil { - t.Fatalf("Generate failed: %v", err) - } - if !result.Success { - t.Fatalf("Expected success, got: %s", result.Message) - } - - // Verify semantic warnings include the missing file - if len(result.SemanticWarnings) == 0 { - t.Error("Expected semantic warnings for missing file path, got none") - } - - foundFileWarning := false - for _, w := range result.SemanticWarnings { - lower := strings.ToLower(w) - if strings.Contains(lower, "missing_file") || strings.Contains(lower, "missing") { - foundFileWarning = true - break - } - } - if !foundFileWarning { - t.Errorf("Expected warning about missing file, got: %v", result.SemanticWarnings) - } - - // Verify stats were populated - if result.ValidationStats == nil { - t.Error("Expected ValidationStats to be populated") - } else if result.ValidationStats.PathsChecked == 0 { - t.Error("Expected PathsChecked > 0") - } - }) - - t.Run("reports invalid shell commands", func(t *testing.T) { - if _, err := exec.LookPath("bash"); err != nil { - t.Skip("bash not available; skipping shell validation test") - } - - mockRepo := &MockRepository{ - CreatePlanFunc: func(p *task.Plan) error { - p.ID = "test-plan-2" - return nil - }, - SetActivePlanFunc: func(id string) error { - return nil - }, - } - - planApp := &PlanApp{ - ctx: &Context{}, // Empty context - semantic validation doesn't need it - Repo: mockRepo, - PlannerFactory: func(cfg llm.Config) TaskPlanner { - return &mockTaskPlanner{ - tasks: []impl.PlanningTask{ - { - Title: "Task with invalid command", - Description: "Run the build", - Priority: 50, - Complexity: "low", - AssignedAgent: "coder", - AcceptanceCriteria: []string{"Build passes"}, - ValidationSteps: []string{"if [ -f test.txt then echo ok fi"}, // Invalid syntax - missing ] - }, - }, - } - }, - ContextRetriever: func(ctx context.Context, ks *knowledge.Service, goal, memoryPath string) (impl.SearchStrategyResult, error) { - return impl.SearchStrategyResult{}, nil - }, - } - - result, err := planApp.Generate(context.Background(), GenerateOptions{ - Goal: "Test goal", - ClarifySessionID: "clarify-ephemeral", - EnrichedGoal: "Test enriched goal", - Save: true, - }) - - if err != nil { - t.Fatalf("Generate failed: %v", err) - } - if !result.Success { - t.Fatalf("Expected success, got: %s", result.Message) - } - - // Verify semantic errors include the invalid command - if len(result.SemanticErrors) == 0 { - t.Error("Expected semantic errors for invalid shell command, got none") - } - - foundCommandError := false - for _, e := range result.SemanticErrors { - if strings.Contains(strings.ToLower(e), "invalid_command") { - foundCommandError = true - break - } - } - if !foundCommandError { - t.Errorf("Expected error about invalid command, got: %v", result.SemanticErrors) - } - - // Verify stats - if result.ValidationStats == nil { - t.Error("Expected ValidationStats to be populated") - } else if result.ValidationStats.CommandsValidated == 0 { - t.Error("Expected CommandsValidated > 0") - } - }) -} - -// mockTaskPlanner is a mock TaskPlanner for testing -type mockTaskPlanner struct { - tasks []impl.PlanningTask -} - -func (m *mockTaskPlanner) Run(ctx context.Context, input core.Input) (core.Output, error) { - return core.Output{ - AgentName: "mock-planner", - Findings: []core.Finding{ - { - Type: "plan", - Title: "Test Plan", - Description: "Test rationale", - Metadata: map[string]any{ - "tasks": m.tasks, - }, - }, - }, - }, nil -} - -func (m *mockTaskPlanner) Close() error { - return nil -} diff --git a/internal/app/recall_test.go b/internal/app/recall_test.go deleted file mode 100644 index b583d5a..0000000 --- a/internal/app/recall_test.go +++ /dev/null @@ -1,87 +0,0 @@ -package app - -import "testing" - -// TestDefaultRecallOptions verifies that default recall options have expected values. -func TestDefaultRecallOptions(t *testing.T) { - opts := DefaultRecallOptions() - - // Basic defaults - if opts.Limit != 5 { - t.Errorf("Limit = %d, want 5", opts.Limit) - } - if opts.SymbolLimit != 5 { - t.Errorf("SymbolLimit = %d, want 5", opts.SymbolLimit) - } - if opts.GenerateAnswer != false { - t.Error("GenerateAnswer should be false by default") - } - if opts.IncludeSymbols != true { - t.Error("IncludeSymbols should be true by default") - } - - // Workspace defaults - if opts.Workspace != "" { - t.Errorf("Workspace = %q, want empty string (all workspaces)", opts.Workspace) - } - if opts.IncludeRoot != true { - t.Error("IncludeRoot should be true by default") - } -} - -// TestValidateWorkspace tests workspace validation logic. -func TestValidateWorkspace(t *testing.T) { - tests := []struct { - name string - workspace string - wantErr bool - }{ - {"empty is valid", "", false}, - {"root is valid", "root", false}, - {"simple name", "osprey", false}, - {"with hyphen", "my-service", false}, - {"with underscore", "my_service", false}, - {"with numbers", "service123", false}, - {"uppercase", "MyService", false}, - {"mixed", "My-Service_123", false}, - {"invalid space", "my service", true}, - {"invalid slash", "my/service", true}, - {"invalid dot", "my.service", true}, - {"invalid colon", "my:service", true}, - {"invalid at", "my@service", true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := ValidateWorkspace(tt.workspace) - if (err != nil) != tt.wantErr { - t.Errorf("ValidateWorkspace(%q) error = %v, wantErr %v", tt.workspace, err, tt.wantErr) - } - }) - } -} - -// TestRecallOptionsWorkspaceDefaults tests that workspace filtering defaults work correctly. -func TestRecallOptionsWorkspaceDefaults(t *testing.T) { - // Default options: no workspace filter, include root - opts := DefaultRecallOptions() - - // When no workspace is specified, all workspaces should be searched - if opts.Workspace != "" { - t.Errorf("default workspace filter should be empty, got %q", opts.Workspace) - } - - // IncludeRoot should be true so that root/global knowledge is always visible - if !opts.IncludeRoot { - t.Error("IncludeRoot should default to true for workspace-aware searches") - } - - // When workspace is set, IncludeRoot=true means we get workspace+root results - opts.Workspace = "osprey" - opts.IncludeRoot = true - // This test documents the expected behavior - actual filtering is in repository layer - - // When IncludeRoot=false, we should only get workspace-specific results - opts.IncludeRoot = false - // This should exclude root nodes (implementation in repository layer) -} diff --git a/internal/app/task.go b/internal/app/task.go index 1835777..011e2bd 100644 --- a/internal/app/task.go +++ b/internal/app/task.go @@ -186,9 +186,9 @@ func (a *TaskApp) Next(ctx context.Context, opts TaskNextOptions) (*TaskResult, } // Build hint - hint := "Call recall tool with suggested queries for context before starting work." - if len(nextTask.SuggestedRecallQueries) > 0 { - hint = fmt.Sprintf("Call recall tool with queries: %v", nextTask.SuggestedRecallQueries) + hint := "Call ask tool with suggested queries for context before starting work." + if len(nextTask.SuggestedAskQueries) > 0 { + hint = fmt.Sprintf("Call ask tool with queries: %v", nextTask.SuggestedAskQueries) } // Build rich context @@ -308,9 +308,9 @@ func (a *TaskApp) Start(ctx context.Context, opts TaskStartOptions) (*TaskResult plan, _ := repo.GetPlan(startedTask.PlanID) - hint := "Call recall tool with suggested queries for relevant context." - if len(startedTask.SuggestedRecallQueries) > 0 { - hint = fmt.Sprintf("Call recall tool with queries: %v", startedTask.SuggestedRecallQueries) + hint := "Call ask tool with suggested queries for relevant context." + if len(startedTask.SuggestedAskQueries) > 0 { + hint = fmt.Sprintf("Call ask tool with queries: %v", startedTask.SuggestedAskQueries) } return &TaskResult{ @@ -649,16 +649,16 @@ func (a *TaskApp) executeGitWorkflow(plan *task.Plan, skipUnpushedCheck bool) (* return gitClient.StartPlanWorkflow(plan.ID, plan.Goal, skipUnpushedCheck) } -// buildRichContext creates markdown context for a task using RecallApp. +// buildRichContext creates markdown context for a task using AskApp. func (a *TaskApp) buildRichContext(ctx context.Context, t *task.Task, plan *task.Plan) string { if plan == nil { return "" } - // Create a search function that uses RecallApp - recallApp := NewRecallApp(a.ctx) - searchFunc := func(ctx context.Context, query string, limit int) ([]task.RecallResult, error) { - result, err := recallApp.Query(ctx, query, RecallOptions{ + // Create a search function that uses AskApp + askApp := NewAskApp(a.ctx) + searchFunc := func(ctx context.Context, query string, limit int) ([]task.AskResult, error) { + result, err := askApp.Query(ctx, query, AskOptions{ Limit: limit, GenerateAnswer: false, }) @@ -666,9 +666,9 @@ func (a *TaskApp) buildRichContext(ctx context.Context, t *task.Task, plan *task return nil, err } - var adapted []task.RecallResult + var adapted []task.AskResult for _, r := range result.Results { - adapted = append(adapted, task.RecallResult{ + adapted = append(adapted, task.AskResult{ Summary: r.Summary, Type: r.Type, Content: r.Content, diff --git a/internal/app/task_policy_test.go b/internal/app/task_policy_test.go deleted file mode 100644 index bdddb3c..0000000 --- a/internal/app/task_policy_test.go +++ /dev/null @@ -1,196 +0,0 @@ -package app - -import ( - "context" - "os" - "path/filepath" - "strings" - "testing" - - "github.com/josephgoksu/TaskWing/internal/llm" - "github.com/josephgoksu/TaskWing/internal/memory" - "github.com/josephgoksu/TaskWing/internal/task" -) - -// TestTaskComplete_PolicyEnforcement tests that policy violations block task completion. -func TestTaskComplete_PolicyEnforcement(t *testing.T) { - // Create a temporary directory for the test - tmpDir, err := os.MkdirTemp("", "taskwing-policy-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - // Create .taskwing/policies directory - policiesDir := filepath.Join(tmpDir, ".taskwing", "policies") - if err := os.MkdirAll(policiesDir, 0755); err != nil { - t.Fatalf("failed to create policies dir: %v", err) - } - - // Create a test policy that blocks .env files - testPolicy := `package taskwing.policy - -import rego.v1 - -# Block environment files -deny contains msg if { - some file in input.task.files_modified - startswith(file, ".env") - msg := sprintf("BLOCKED: Environment file '%s' is protected", [file]) -} - -# Block GOVERNANCE.md -deny contains msg if { - some file in input.task.files_modified - file == "GOVERNANCE.md" - msg := "BLOCKED: GOVERNANCE.md is a protected file" -} -` - policyPath := filepath.Join(policiesDir, "test.rego") - if err := os.WriteFile(policyPath, []byte(testPolicy), 0644); err != nil { - t.Fatalf("failed to write test policy: %v", err) - } - - // Create .taskwing/memory directory for SQLite - memoryDir := filepath.Join(tmpDir, ".taskwing", "memory") - if err := os.MkdirAll(memoryDir, 0755); err != nil { - t.Fatalf("failed to create memory dir: %v", err) - } - - // Initialize repository - repo, err := memory.NewDefaultRepository(memoryDir) - if err != nil { - t.Fatalf("failed to create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - // Create a test plan and task - testPlan := &task.Plan{ - ID: "test-plan-001", - Goal: "Test policy enforcement", - Status: task.PlanStatusActive, - } - if err := repo.CreatePlan(testPlan); err != nil { - t.Fatalf("failed to create plan: %v", err) - } - - testTask := &task.Task{ - ID: "test-task-001", - PlanID: testPlan.ID, - Title: "Test task", - Description: "A test task for policy enforcement", - Status: task.StatusInProgress, - Priority: 50, - } - if err := repo.CreateTask(testTask); err != nil { - t.Fatalf("failed to create task: %v", err) - } - - // Change to temp directory so policy engine finds .taskwing/policies - oldWd, _ := os.Getwd() - if err := os.Chdir(tmpDir); err != nil { - t.Fatalf("failed to change to temp dir: %v", err) - } - defer func() { _ = os.Chdir(oldWd) }() - - // Create TaskApp - appCtx := &Context{ - Repo: repo, - LLMCfg: llm.Config{}, - } - taskApp := NewTaskApp(appCtx) - - tests := []struct { - name string - filesModified []string - expectSuccess bool - expectMessage string - }{ - { - name: "allowed_files_should_pass", - filesModified: []string{"main.go", "README.md"}, - expectSuccess: true, - }, - { - name: "env_file_should_be_blocked", - filesModified: []string{"main.go", ".env"}, - expectSuccess: false, - expectMessage: "BLOCKED: Environment file", - }, - { - name: "env_local_should_be_blocked", - filesModified: []string{".env.local"}, - expectSuccess: false, - expectMessage: "BLOCKED: Environment file", - }, - { - name: "governance_md_should_be_blocked", - filesModified: []string{"GOVERNANCE.md"}, - expectSuccess: false, - expectMessage: "BLOCKED: GOVERNANCE.md is a protected file", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Create a new task for each test case - taskID := "task-" + tt.name - newTask := &task.Task{ - ID: taskID, - PlanID: testPlan.ID, - Title: "Test: " + tt.name, - Description: "Test task", - Status: task.StatusInProgress, - Priority: 50, - } - if err := repo.CreateTask(newTask); err != nil { - t.Fatalf("failed to create task: %v", err) - } - - // Attempt to complete the task - result, err := taskApp.Complete(context.Background(), TaskCompleteOptions{ - TaskID: taskID, - Summary: "Test completion", - FilesModified: tt.filesModified, - }) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Success != tt.expectSuccess { - t.Errorf("expected Success=%v, got %v. Message: %s", tt.expectSuccess, result.Success, result.Message) - } - - if !tt.expectSuccess && tt.expectMessage != "" { - if result.Message == "" || !strings.Contains(result.Message, tt.expectMessage) { - t.Errorf("expected message to contain %q, got %q", tt.expectMessage, result.Message) - } - } - - // Verify PolicyViolation flag is set correctly - if !tt.expectSuccess && !result.PolicyViolation { - t.Errorf("expected PolicyViolation=true for blocked task, got false") - } - if tt.expectSuccess && result.PolicyViolation { - t.Errorf("expected PolicyViolation=false for allowed task, got true") - } - - // Verify task status in database - taskFromDB, err := repo.GetTask(taskID) - if err != nil { - t.Fatalf("failed to get task: %v", err) - } - - if tt.expectSuccess { - if taskFromDB.Status != task.StatusCompleted { - t.Errorf("expected task status %s, got %s", task.StatusCompleted, taskFromDB.Status) - } - } else { - if taskFromDB.Status != task.StatusInProgress { - t.Errorf("expected task status %s (unchanged), got %s", task.StatusInProgress, taskFromDB.Status) - } - } - }) - } -} diff --git a/internal/bootstrap/initializer.go b/internal/bootstrap/initializer.go index 78e2a22..51d5e46 100644 --- a/internal/bootstrap/initializer.go +++ b/internal/bootstrap/initializer.go @@ -263,7 +263,8 @@ type SlashCommand struct { // SlashCommands is the canonical list of slash commands generated by TaskWing. // When this list changes, the version hash changes, triggering updates on next bootstrap. var SlashCommands = []SlashCommand{ - {"tw-brief", "brief", "Use when you need a compact project brief (decisions, patterns, constraints)."}, + {"tw-ask", "ask", "Use when you need to search project knowledge (decisions, patterns, constraints)."}, + {"tw-remember", "remember", "Use when you want to persist a decision, pattern, or insight to project memory."}, {"tw-next", "next", "Use when you are ready to start the next approved TaskWing task with full context."}, {"tw-done", "done", "Use when implementation is verified and you are ready to complete the current task."}, {"tw-status", "status", "Use when you need current task progress and acceptance criteria status."}, @@ -282,9 +283,48 @@ func SlashCommandNames() []string { return names } -var removedLegacySlashCommands = []string{ - "tw-context", - "tw-block", +// MCPTool describes a single MCP tool for documentation generation. +type MCPTool struct { + Name string `json:"name"` + Description string `json:"description"` +} + +// MCPTools is the canonical list of MCP tools exposed by the TaskWing MCP server. +var MCPTools = []MCPTool{ + {"ask", "Search project knowledge (decisions, patterns, constraints)"}, + {"task", "Unified task lifecycle (next, current, start, complete)"}, + {"plan", "Plan management (clarify, decompose, expand, generate, finalize, audit)"}, + {"code", "Code intelligence (find, search, explain, callers, impact, simplify)"}, + {"debug", "Diagnose issues systematically with AI-powered analysis"}, + {"remember", "Store knowledge in project memory"}, +} + +// MCPToolNames returns MCP tool names in canonical order. +func MCPToolNames() []string { + names := make([]string, 0, len(MCPTools)) + for _, tool := range MCPTools { + names = append(names, tool.Name) + } + return names +} + +// CoreCommand describes a CLI command included in documentation. +type CoreCommand struct { + Display string `json:"display"` // e.g. "taskwing goal \"\"" +} + +// CoreCommands is the curated list of CLI commands shown in documentation. +var CoreCommands = []CoreCommand{ + {"taskwing bootstrap"}, + {"taskwing goal \"\""}, + {"taskwing ask \"\""}, + {"taskwing task"}, + {"taskwing plan status"}, + {"taskwing slash"}, + {"taskwing mcp"}, + {"taskwing doctor"}, + {"taskwing config"}, + {"taskwing start"}, } // AIToolConfigVersion computes a version hash for the AI tool configuration. @@ -307,6 +347,14 @@ func AIToolConfigVersion(aiName string) string { parts = append(parts, fmt.Sprintf("cmd:%s:%s:%s", cmd.BaseName, cmd.SlashCmd, cmd.Description)) } + for _, tool := range MCPTools { + parts = append(parts, fmt.Sprintf("mcp:%s:%s", tool.Name, tool.Description)) + } + + for _, cc := range CoreCommands { + parts = append(parts, fmt.Sprintf("corecmd:%s", cc.Display)) + } + // Sort for determinism sort.Strings(parts) @@ -334,13 +382,10 @@ func expectedSlashCommandFiles(ext string) map[string]struct{} { } func managedSlashCommandBases() map[string]struct{} { - managed := make(map[string]struct{}, len(SlashCommands)+len(removedLegacySlashCommands)) + managed := make(map[string]struct{}, len(SlashCommands)) for _, cmd := range SlashCommands { managed[cmd.BaseName] = struct{}{} } - for _, base := range removedLegacySlashCommands { - managed[base] = struct{}{} - } return managed } @@ -551,7 +596,7 @@ func (i *Initializer) createSingleFileInstructions(aiName string, verbose bool) sb.WriteString("### Usage\n\n") sb.WriteString("With MCP configured, you can use TaskWing tools via:\n") - sb.WriteString("- `@mcp taskwing-mcp recall \"query\"` - Search project knowledge\n") + sb.WriteString("- `@mcp taskwing-mcp ask \"query\"` - Search project knowledge\n") sb.WriteString("- `@mcp taskwing-mcp task {\\\"action\\\":\\\"next\\\"}` - Get next task from plan (session_id auto-derived in MCP session)\n") sb.WriteString("- `@mcp taskwing-mcp remember \"content\"` - Store knowledge\n") @@ -957,9 +1002,8 @@ const ( taskwingDocMarkerEnd = "" ) -// taskwingDocSection is the complete TaskWing documentation block with markers. -// Keep this aligned with docs/_partials and scripts/sync-docs.sh contracts. -const taskwingDocSection = taskwingDocMarkerStart + ` +// taskwingDocSectionHeader is the static top portion of the documentation block. +const taskwingDocSectionHeader = ` ## TaskWing Integration @@ -990,43 +1034,10 @@ TaskWing helps me turn a goal into executed tasks with persistent context across Brand names and logos are trademarks of their respective owners; usage here indicates compatibility, not endorsement. -### Slash Commands -- /tw-brief - Use when you need a compact project brief. -- /tw-next - Use when you are ready to start the next approved task. -- /tw-done - Use when implementation is verified and ready to complete. -- /tw-plan - Use when you need to clarify a goal and build a plan. -- /tw-status - Use when you need current task progress. -- /tw-debug - Use when debugging must start from root-cause evidence. -- /tw-explain - Use when you need a deep symbol explanation. -- /tw-simplify - Use when you want to simplify code without behavior changes. - -### Core Commands - - -- taskwing bootstrap -- taskwing goal "" -- taskwing task -- taskwing plan status -- taskwing slash -- taskwing mcp -- taskwing doctor -- taskwing config -- taskwing start - - -### MCP Tools (Canonical Contract) - - -| Tool | Description | -|------|-------------| -| recall | Retrieve project knowledge (decisions, patterns, constraints) | -| task | Unified task lifecycle (next, current, start, complete) | -| plan | Plan management (clarify, decompose, expand, generate, finalize, audit) | -| code | Code intelligence (find, search, explain, callers, impact, simplify) | -| debug | Diagnose issues systematically with AI-powered analysis | -| remember | Store knowledge in project memory | - +` +// taskwingDocSectionFooter is the static bottom portion of the documentation block. +const taskwingDocSectionFooter = ` ### Autonomous Task Execution (Hooks) TaskWing integrates with Claude Code's hook system for autonomous plan execution: @@ -1046,12 +1057,55 @@ Configuration in .claude/settings.json enables auto-continuation through plans. Hook commands prefer $CLAUDE_PROJECT_DIR/bin/taskwing and fall back to taskwing in PATH. If Claude Code is already running, use /hooks to review or reload hook changes. -` + taskwingDocMarkerEnd +` + +// buildTaskwingDocSection assembles the complete TaskWing documentation block +// from the three registries (SlashCommands, CoreCommands, MCPTools). +// This is the single source of truth for documentation stamped into CLAUDE.md, +// AGENTS.md, and GEMINI.md during bootstrap. +func buildTaskwingDocSection() string { + var sb strings.Builder + + sb.WriteString(taskwingDocMarkerStart) + sb.WriteString(taskwingDocSectionHeader) + + // Slash Commands — generated from SlashCommands registry + sb.WriteString("### Slash Commands\n") + for _, cmd := range SlashCommands { + fmt.Fprintf(&sb, "- /%s - %s\n", cmd.BaseName, cmd.Description) + } + + // Core Commands — generated from CoreCommands registry + sb.WriteString("\n### Core Commands\n\n") + sb.WriteString("\n") + for _, cc := range CoreCommands { + fmt.Fprintf(&sb, "- %s\n", cc.Display) + } + sb.WriteString("\n") + + // MCP Tools — generated from MCPTools registry + sb.WriteString("\n### MCP Tools (Canonical Contract)\n\n") + sb.WriteString("\n") + sb.WriteString("| Tool | Description |\n") + sb.WriteString("|------|-------------|\n") + for _, tool := range MCPTools { + fmt.Fprintf(&sb, "| %s | %s |\n", tool.Name, tool.Description) + } + sb.WriteString("\n") + + sb.WriteString(taskwingDocSectionFooter) + sb.WriteString(taskwingDocMarkerEnd) + + return sb.String() +} func (i *Initializer) updateAgentDocs(verbose bool) error { // Always update all three agent doc files: CLAUDE.md, GEMINI.md, AGENTS.md filesToUpdate := []string{"CLAUDE.md", "GEMINI.md", "AGENTS.md"} + // Build doc section once from registries (single source of truth) + docSection := buildTaskwingDocSection() + for _, fileName := range filesToUpdate { filePath := filepath.Join(i.basePath, fileName) content, err := os.ReadFile(filePath) @@ -1076,7 +1130,7 @@ func (i *Initializer) updateAgentDocs(verbose bool) error { // Valid markers - replace content between them before := contentStr[:startIdx] after := contentStr[endIdx+len(taskwingDocMarkerEnd):] - newContent = before + taskwingDocSection + after + newContent = before + docSection + after action = "updated" } else if hasStartMarker != hasEndMarker { // Partial markers - warn and skip to avoid corruption @@ -1089,11 +1143,11 @@ func (i *Initializer) updateAgentDocs(verbose bool) error { if legacyEnd < len(contentStr) { after = contentStr[legacyEnd:] } - newContent = strings.TrimRight(before, "\n") + "\n" + taskwingDocSection + after + newContent = strings.TrimRight(before, "\n") + "\n" + docSection + after action = "migrated" } else { // No existing TaskWing content - append - newContent = strings.TrimRight(contentStr, "\n") + "\n" + taskwingDocSection + newContent = strings.TrimRight(contentStr, "\n") + "\n" + docSection action = "added" } diff --git a/internal/bootstrap/initializer_test.go b/internal/bootstrap/initializer_test.go deleted file mode 100644 index 6f8e874..0000000 --- a/internal/bootstrap/initializer_test.go +++ /dev/null @@ -1,933 +0,0 @@ -package bootstrap - -import ( - "encoding/json" - "os" - "path/filepath" - "regexp" - "strings" - "testing" -) - -func TestNewInitializer(t *testing.T) { - basePath := "/test/path" - init := NewInitializer(basePath) - - if init == nil { - t.Fatal("NewInitializer returned nil") - } - if init.basePath != basePath { - t.Errorf("basePath = %q, want %q", init.basePath, basePath) - } -} - -func TestValidAINames(t *testing.T) { - names := ValidAINames() - - // Should return all keys from aiHelpers map - if len(names) == 0 { - t.Error("ValidAINames returned empty slice") - } - - // Check that known AI names are present - expectedNames := map[string]bool{ - "claude": false, - "cursor": false, - "gemini": false, - "codex": false, - "copilot": false, - "opencode": false, - } - - for _, name := range names { - if _, ok := expectedNames[name]; ok { - expectedNames[name] = true - } - } - - for name, found := range expectedNames { - if !found { - t.Errorf("Expected AI name %q not found in ValidAINames()", name) - } - } -} - -func TestInitializer_Run_EmptyAIs(t *testing.T) { - tmpDir := t.TempDir() - init := NewInitializer(tmpDir) - - // Should not error with empty AIs - err := init.Run(false, []string{}) - if err != nil { - t.Errorf("Run with empty AIs failed: %v", err) - } - - // Should create .taskwing directory - if _, err := os.Stat(filepath.Join(tmpDir, ".taskwing")); os.IsNotExist(err) { - t.Error(".taskwing directory was not created") - } -} - -func TestInitializer_Run_InvalidAI(t *testing.T) { - tmpDir := t.TempDir() - init := NewInitializer(tmpDir) - - // Should handle invalid AI names gracefully - err := init.Run(true, []string{"invalid-ai-name"}) - if err != nil { - t.Errorf("Run with invalid AI failed: %v", err) - } -} - -func TestInitializer_Run_CreateSlashCommands(t *testing.T) { - tmpDir := t.TempDir() - init := NewInitializer(tmpDir) - - // Test with claude - err := init.Run(false, []string{"claude"}) - if err != nil { - t.Fatalf("Run failed: %v", err) - } - - // Check slash command files were created - expectedFiles := []string{ - ".claude/commands/tw-brief.md", - ".claude/commands/tw-next.md", - ".claude/commands/tw-done.md", - } - - for _, file := range expectedFiles { - path := filepath.Join(tmpDir, file) - if _, err := os.Stat(path); os.IsNotExist(err) { - t.Errorf("Expected file %s was not created", file) - } - } -} - -func TestInitializer_Run_GeminiTOML(t *testing.T) { - tmpDir := t.TempDir() - init := NewInitializer(tmpDir) - - err := init.Run(false, []string{"gemini"}) - if err != nil { - t.Fatalf("Run failed: %v", err) - } - - // Check TOML files were created for Gemini - tomlPath := filepath.Join(tmpDir, ".gemini/commands/tw-brief.toml") - content, err := os.ReadFile(tomlPath) - if err != nil { - t.Fatalf("Failed to read TOML file: %v", err) - } - - // Verify TOML content has expected fields - contentStr := string(content) - if !contains(contentStr, "description =") { - t.Error("TOML file missing description field") - } - if !contains(contentStr, "prompt =") { - t.Error("TOML file missing prompt field") - } -} - -func TestInitializer_InstallHooksConfig(t *testing.T) { - tmpDir := t.TempDir() - init := NewInitializer(tmpDir) - - err := init.InstallHooksConfig("claude", false) - if err != nil { - t.Fatalf("InstallHooksConfig failed: %v", err) - } - - // Read the created settings.json - settingsPath := filepath.Join(tmpDir, ".claude/settings.json") - content, err := os.ReadFile(settingsPath) - if err != nil { - t.Fatalf("Failed to read settings.json: %v", err) - } - - // Parse and verify JSON structure - var config HooksConfig - if err := json.Unmarshal(content, &config); err != nil { - t.Fatalf("Invalid JSON in settings.json: %v", err) - } - - if config.Hooks == nil { - t.Error("Hooks config is nil") - } - if _, ok := config.Hooks["SessionStart"]; !ok { - t.Error("Missing SessionStart hook") - } - if _, ok := config.Hooks["Stop"]; !ok { - t.Error("Missing Stop hook") - } - if _, ok := config.Hooks["SessionEnd"]; !ok { - t.Error("Missing SessionEnd hook") - } - - stopHook := config.Hooks["Stop"] - if len(stopHook) == 0 || len(stopHook[0].Hooks) == 0 { - t.Fatal("Stop hook commands missing") - } - stopCmd := stopHook[0].Hooks[0].Command - if !strings.Contains(stopCmd, "$CLAUDE_PROJECT_DIR/bin/taskwing") { - t.Errorf("Stop hook should prefer project-local binary, got: %q", stopCmd) - } - if !strings.Contains(stopCmd, "hook continue-check") { - t.Errorf("Stop hook should call continue-check, got: %q", stopCmd) - } - if stopHook[0].Hooks[0].Timeout != 0 { - t.Errorf("Stop hook should rely on default timeout (0/omitted), got: %d", stopHook[0].Hooks[0].Timeout) - } -} - -func TestInitializer_InstallHooksConfig_MalformedJSON(t *testing.T) { - tmpDir := t.TempDir() - init := NewInitializer(tmpDir) - - // Create malformed settings.json - settingsDir := filepath.Join(tmpDir, ".claude") - if err := os.MkdirAll(settingsDir, 0755); err != nil { - t.Fatalf("Failed to create dir: %v", err) - } - settingsPath := filepath.Join(settingsDir, "settings.json") - if err := os.WriteFile(settingsPath, []byte("not valid json{"), 0644); err != nil { - t.Fatalf("Failed to write malformed JSON: %v", err) - } - - // Should return error for malformed JSON - err := init.InstallHooksConfig("claude", false) - if err == nil { - t.Error("Expected error for malformed JSON, got nil") - } -} - -func TestInitializer_InstallHooksConfig_ExistingHooks(t *testing.T) { - tmpDir := t.TempDir() - init := NewInitializer(tmpDir) - - // Create valid settings.json with hooks - settingsDir := filepath.Join(tmpDir, ".claude") - if err := os.MkdirAll(settingsDir, 0755); err != nil { - t.Fatalf("Failed to create dir: %v", err) - } - settingsPath := filepath.Join(settingsDir, "settings.json") - existingConfig := `{"hooks": {"Test": []}}` - if err := os.WriteFile(settingsPath, []byte(existingConfig), 0644); err != nil { - t.Fatalf("Failed to write existing config: %v", err) - } - - // Should preserve existing hooks and add missing TaskWing defaults - err := init.InstallHooksConfig("claude", false) - if err != nil { - t.Fatalf("InstallHooksConfig failed: %v", err) - } - - // Read back and verify hooks weren't changed - content, err := os.ReadFile(settingsPath) - if err != nil { - t.Fatalf("Failed to read settings.json: %v", err) - } - - var config map[string]any - if err := json.Unmarshal(content, &config); err != nil { - t.Fatalf("Invalid JSON: %v", err) - } - - hooks, ok := config["hooks"].(map[string]any) - if !ok { - t.Fatal("Hooks field missing or wrong type") - } - if _, ok := hooks["Test"]; !ok { - t.Error("Existing Test hook was removed") - } - if _, ok := hooks["SessionStart"]; !ok { - t.Error("SessionStart hook was not added") - } - if _, ok := hooks["Stop"]; !ok { - t.Error("Stop hook was not added") - } - if _, ok := hooks["SessionEnd"]; !ok { - t.Error("SessionEnd hook was not added") - } -} - -func TestInitializer_InstallHooksConfig_RepairsWrongStopCommand(t *testing.T) { - tmpDir := t.TempDir() - init := NewInitializer(tmpDir) - - settingsDir := filepath.Join(tmpDir, ".claude") - if err := os.MkdirAll(settingsDir, 0755); err != nil { - t.Fatalf("Failed to create dir: %v", err) - } - settingsPath := filepath.Join(settingsDir, "settings.json") - existingConfig := `{ - "hooks": { - "SessionStart": [{"hooks":[{"type":"command","command":"taskwing hook session-init"}]}], - "Stop": [{"hooks":[{"type":"command","command":"echo noop"}]}], - "SessionEnd": [{"hooks":[{"type":"command","command":"taskwing hook session-end"}]}] - } - }` - if err := os.WriteFile(settingsPath, []byte(existingConfig), 0644); err != nil { - t.Fatalf("Failed to write existing config: %v", err) - } - - if err := init.InstallHooksConfig("claude", false); err != nil { - t.Fatalf("InstallHooksConfig failed: %v", err) - } - - content, err := os.ReadFile(settingsPath) - if err != nil { - t.Fatalf("Failed to read settings.json: %v", err) - } - if !strings.Contains(string(content), "hook continue-check") { - t.Fatalf("Stop hook repair should inject continue-check command, got: %s", string(content)) - } -} - -func TestInitializer_InstallHooksConfig_UnsupportedAI(t *testing.T) { - tmpDir := t.TempDir() - init := NewInitializer(tmpDir) - - // Cursor doesn't support hooks - err := init.InstallHooksConfig("cursor", false) - if err != nil { - t.Errorf("Expected nil for unsupported AI, got: %v", err) - } -} - -func TestCreateSlashCommands_AllAIs(t *testing.T) { - for aiName := range aiHelpers { - t.Run(aiName, func(t *testing.T) { - tmpDir := t.TempDir() - init := NewInitializer(tmpDir) - - err := init.CreateSlashCommands(aiName, false) - if err != nil { - t.Errorf("CreateSlashCommands(%s) failed: %v", aiName, err) - } - - // Verify commands directory exists - cfg := aiHelpers[aiName] - cmdDir := filepath.Join(tmpDir, cfg.commandsDir) - if _, err := os.Stat(cmdDir); os.IsNotExist(err) { - t.Errorf("Commands directory not created for %s", aiName) - } - }) - } -} - -func TestCreateSlashCommands_UnknownAI(t *testing.T) { - tmpDir := t.TempDir() - init := NewInitializer(tmpDir) - - // Unknown AI should return nil (no error) - err := init.CreateSlashCommands("unknown-ai", false) - if err != nil { - t.Errorf("Expected nil for unknown AI, got: %v", err) - } -} - -func TestCreateSlashCommands_PrunesRemovedLegacyCommands(t *testing.T) { - tmpDir := t.TempDir() - init := NewInitializer(tmpDir) - - cmdDir := filepath.Join(tmpDir, ".claude", "commands") - if err := os.MkdirAll(cmdDir, 0755); err != nil { - t.Fatalf("Failed to create commands dir: %v", err) - } - if err := os.WriteFile(filepath.Join(cmdDir, "tw-context.md"), []byte("legacy"), 0644); err != nil { - t.Fatalf("Failed to write tw-context.md: %v", err) - } - if err := os.WriteFile(filepath.Join(cmdDir, "tw-block.md"), []byte("legacy"), 0644); err != nil { - t.Fatalf("Failed to write tw-block.md: %v", err) - } - - if err := init.CreateSlashCommands("claude", false); err != nil { - t.Fatalf("CreateSlashCommands(claude) failed: %v", err) - } - - if _, err := os.Stat(filepath.Join(cmdDir, "tw-context.md")); !os.IsNotExist(err) { - t.Error("tw-context.md should be removed during slash command regeneration") - } - if _, err := os.Stat(filepath.Join(cmdDir, "tw-block.md")); !os.IsNotExist(err) { - t.Error("tw-block.md should be removed during slash command regeneration") - } -} - -// Helper function to check string containment -func contains(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr)) -} - -func containsHelper(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} - -// TestCopilotSingleFile tests Copilot single-file generation -func TestCopilotSingleFile(t *testing.T) { - tmpDir := t.TempDir() - init := NewInitializer(tmpDir) - - err := init.CreateSlashCommands("copilot", false) - if err != nil { - t.Fatalf("CreateSlashCommands(copilot) failed: %v", err) - } - - // Verify single file created (not a directory of files) - filePath := filepath.Join(tmpDir, ".github", "copilot-instructions.md") - content, err := os.ReadFile(filePath) - if err != nil { - t.Fatalf("Failed to read copilot-instructions.md: %v", err) - } - - // Verify marker is present - if !contains(string(content), "") { - t.Error("Missing TASKWING_MANAGED marker in copilot-instructions.md") - } - - // Verify version is present - if !contains(string(content), "\ntest content" - _ = os.WriteFile(filepath.Join(githubDir, "copilot-instructions.md"), []byte(content), 0644) - }, - expectStatus: HealthOK, - }, - { - name: "copilot - user-managed (no marker)", - aiName: "copilot", - setup: func(dir string) { - // User created their own copilot-instructions.md without TaskWing marker - githubDir := filepath.Join(dir, ".github") - _ = os.MkdirAll(githubDir, 0755) - content := "# My Custom Instructions\nDo this, not that." - _ = os.WriteFile(filepath.Join(githubDir, "copilot-instructions.md"), []byte(content), 0644) - }, - expectStatus: HealthOK, // User-managed = OK (we won't touch it) - }, - { - name: "copilot - empty file", - aiName: "copilot", - setup: func(dir string) { - // Empty file exists - githubDir := filepath.Join(dir, ".github") - _ = os.MkdirAll(githubDir, 0755) - _ = os.WriteFile(filepath.Join(githubDir, "copilot-instructions.md"), []byte(""), 0644) - }, - expectStatus: HealthOK, // Empty file = user-managed (no marker) - }, - { - name: "codex - partial (commands ok, hooks missing)", - aiName: "codex", - setup: func(dir string) { - cmdDir := filepath.Join(dir, ".codex", "commands") - _ = os.MkdirAll(cmdDir, 0755) - for _, name := range []string{"tw-brief", "tw-next", "tw-done", "tw-status", "tw-plan", "tw-debug", "tw-explain", "tw-simplify"} { - _ = os.WriteFile(filepath.Join(cmdDir, name+".md"), []byte("test"), 0644) - } - }, - expectStatus: HealthPartial, // Hooks missing - }, - { - name: "codex - ok", - aiName: "codex", - setup: func(dir string) { - cmdDir := filepath.Join(dir, ".codex", "commands") - _ = os.MkdirAll(cmdDir, 0755) - for _, name := range []string{"tw-brief", "tw-next", "tw-done", "tw-status", "tw-plan", "tw-debug", "tw-explain", "tw-simplify"} { - _ = os.WriteFile(filepath.Join(cmdDir, name+".md"), []byte("test"), 0644) - } - _ = os.WriteFile(filepath.Join(dir, ".codex", "settings.json"), []byte(`{"hooks":{"SessionStart":[{"hooks":[{"type":"command","command":"taskwing hook session-init"}]}],"Stop":[{"hooks":[{"type":"command","command":"taskwing hook continue-check --max-tasks=5 --max-minutes=30"}]}],"SessionEnd":[{"hooks":[{"type":"command","command":"taskwing hook session-end"}]}]}}`), 0644) - }, - expectStatus: HealthOK, - }, - { - name: "cursor - ok", - aiName: "cursor", - setup: func(dir string) { - cmdDir := filepath.Join(dir, ".cursor", "rules") - _ = os.MkdirAll(cmdDir, 0755) - for _, name := range []string{"tw-brief", "tw-next", "tw-done", "tw-status", "tw-plan", "tw-debug", "tw-explain", "tw-simplify"} { - _ = os.WriteFile(filepath.Join(cmdDir, name+".md"), []byte("test"), 0644) - } - }, - expectStatus: HealthOK, // Cursor doesn't need hooks - }, - { - name: "unknown - unsupported", - aiName: "unknown-ai", - setup: func(dir string) {}, - expectStatus: HealthUnsupported, - }, - } - - for _, tt := range tests { - name := tt.aiName - if tt.name != "" { - name = tt.name - } - t.Run(name, func(t *testing.T) { - tmpDir := t.TempDir() - tt.setup(tmpDir) - - health := probeAIHealth(tmpDir, tt.aiName) - if health.Status != tt.expectStatus { - t.Errorf("probeAIHealth(%q) status = %v, want %v (reason: %s)", - tt.aiName, health.Status, tt.expectStatus, health.Reason) - } - }) - } -} - -// TestDecidePlan_RepairMode_LocalPartialWithoutGlobalMCP tests that partial local configs -// without global MCP are detected for repair -func TestDecidePlan_RepairMode_LocalPartialWithoutGlobalMCP(t *testing.T) { - snapshot := &Snapshot{ - Project: ProjectHealth{Status: HealthOK, DirExists: true, MemoryDirExists: true, PlansDirExists: true}, - AIHealth: map[string]AIHealth{ - "claude": { - Name: "claude", - Status: HealthPartial, - CommandsDirExists: true, // Has local config - MarkerFileExists: true, // TaskWing created this directory - CommandFilesCount: 3, // But incomplete - GlobalMCPExists: false, // NO global MCP - Reason: "only 3/8 command files present", - }, - }, - HasAnyLocalAI: true, - ExistingLocalAI: []string{"claude"}, - HasAnyGlobalMCP: false, - } - - plan := DecidePlan(snapshot, Flags{}) - - // Should be ModeRepair, not ModeRun - if plan.Mode != ModeRepair { - t.Errorf("DecidePlan() mode = %v, want %v (partial local config should trigger repair)", - plan.Mode, ModeRepair) - } - - // Should have claude in AIsNeedingRepair - if !slices.Contains(plan.AIsNeedingRepair, "claude") { - t.Errorf("DecidePlan() AIsNeedingRepair = %v, should contain 'claude'", plan.AIsNeedingRepair) - } -} - -// TestProbeEnvironment_InvalidPath tests error handling for invalid paths -func TestProbeEnvironment_InvalidPath(t *testing.T) { - // Non-existent path - _, err := ProbeEnvironment("/non/existent/path/that/does/not/exist") - if err == nil { - t.Error("ProbeEnvironment() should return error for non-existent path") - } - - // File instead of directory - tmpFile := filepath.Join(t.TempDir(), "file.txt") - _ = os.WriteFile(tmpFile, []byte("test"), 0644) - _, err = ProbeEnvironment(tmpFile) - if err == nil { - t.Error("ProbeEnvironment() should return error for file path") - } -} - -// TestGlobalMCPDetector_Injection tests the GlobalMCPDetector injection -func TestGlobalMCPDetector_Injection(t *testing.T) { - // Reset after test - originalDetector := GlobalMCPDetector - defer func() { GlobalMCPDetector = originalDetector }() - - // Without injection - GlobalMCPDetector = nil - if checkGlobalMCPForAI("claude") { - t.Error("checkGlobalMCPForAI() should return false when no detector is injected") - } - - // With injection - GlobalMCPDetector = func(aiName string) bool { - return aiName == "claude" - } - if !checkGlobalMCPForAI("claude") { - t.Error("checkGlobalMCPForAI('claude') should return true when detector returns true") - } - if checkGlobalMCPForAI("gemini") { - t.Error("checkGlobalMCPForAI('gemini') should return false when detector returns false") - } -} - -// Helper -func containsString(s, substr string) bool { - return len(s) >= len(substr) && findSubstring(s, substr) -} - -func findSubstring(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} diff --git a/internal/bootstrap/runner_test.go b/internal/bootstrap/runner_test.go deleted file mode 100644 index 72904f8..0000000 --- a/internal/bootstrap/runner_test.go +++ /dev/null @@ -1,360 +0,0 @@ -package bootstrap - -import ( - "context" - "testing" - "time" - - "github.com/josephgoksu/TaskWing/internal/agents/core" -) - -// mockAgent is a simple agent implementation for testing -type mockAgent struct { - name string - description string - runFunc func(ctx context.Context, input core.Input) (core.Output, error) - closeFn func() -} - -func (m *mockAgent) Name() string { return m.name } -func (m *mockAgent) Description() string { return m.description } -func (m *mockAgent) Run(ctx context.Context, input core.Input) (core.Output, error) { - if m.runFunc != nil { - return m.runFunc(ctx, input) - } - return core.Output{AgentName: m.name}, nil -} - -// Close implements CloseableAgent for testing (optional, called via core.CloseAgents) -func (m *mockAgent) Close() error { - if m.closeFn != nil { - m.closeFn() - } - return nil -} - -func TestRunner_Close(t *testing.T) { - closed := false - agent := &mockAgent{ - name: "test-agent", - closeFn: func() { closed = true }, - } - - runner := &Runner{agents: []core.Agent{agent}} - runner.Close() - - if !closed { - t.Error("Agent was not closed") - } -} - -func TestRunner_Run_ContextCancelled(t *testing.T) { - agent := &mockAgent{ - name: "slow-agent", - runFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - // Simulate slow work - select { - case <-ctx.Done(): - return core.Output{}, ctx.Err() - case <-time.After(5 * time.Second): - return core.Output{AgentName: "slow-agent"}, nil - } - }, - } - - runner := &Runner{agents: []core.Agent{agent}} - defer runner.Close() - - // Create already-cancelled context - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - _, err := runner.Run(ctx, "/test/path") - if err != context.Canceled { - t.Errorf("Expected context.Canceled error, got: %v", err) - } -} - -func TestRunner_Run_Success(t *testing.T) { - finding := core.Finding{ - Type: "test", - Title: "Test Finding", - Description: "Test description", - } - - agent := &mockAgent{ - name: "success-agent", - runFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - return core.Output{ - AgentName: "success-agent", - Findings: []core.Finding{finding}, - }, nil - }, - } - - runner := &Runner{agents: []core.Agent{agent}} - defer runner.Close() - - results, err := runner.Run(context.Background(), "/test/path") - if err != nil { - t.Fatalf("Run failed: %v", err) - } - - if len(results) != 1 { - t.Fatalf("Expected 1 result, got %d", len(results)) - } - - if results[0].AgentName != "success-agent" { - t.Errorf("AgentName = %q, want %q", results[0].AgentName, "success-agent") - } - - if len(results[0].Findings) != 1 { - t.Errorf("Findings count = %d, want 1", len(results[0].Findings)) - } -} - -func TestRunner_Run_MultipleAgents(t *testing.T) { - agents := []core.Agent{ - &mockAgent{ - name: "agent-1", - runFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - return core.Output{AgentName: "agent-1"}, nil - }, - }, - &mockAgent{ - name: "agent-2", - runFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - return core.Output{AgentName: "agent-2"}, nil - }, - }, - } - - runner := &Runner{agents: agents} - defer runner.Close() - - results, err := runner.Run(context.Background(), "/test/path") - if err != nil { - t.Fatalf("Run failed: %v", err) - } - - if len(results) != 2 { - t.Fatalf("Expected 2 results, got %d", len(results)) - } -} - -func TestRunner_Run_PartialFailure(t *testing.T) { - agents := []core.Agent{ - &mockAgent{ - name: "success-agent", - runFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - return core.Output{AgentName: "success-agent"}, nil - }, - }, - &mockAgent{ - name: "fail-agent", - runFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - return core.Output{}, context.DeadlineExceeded - }, - }, - } - - runner := &Runner{agents: agents} - defer runner.Close() - - results, err := runner.Run(context.Background(), "/test/path") - // Should succeed with partial results - if err != nil { - t.Fatalf("Expected nil error for partial failure, got: %v", err) - } - - if len(results) != 1 { - t.Fatalf("Expected 1 result (partial success), got %d", len(results)) - } -} - -func TestRunner_Run_AllFailed(t *testing.T) { - agents := []core.Agent{ - &mockAgent{ - name: "fail-agent-1", - runFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - return core.Output{}, context.DeadlineExceeded - }, - }, - &mockAgent{ - name: "fail-agent-2", - runFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - return core.Output{}, context.Canceled - }, - }, - } - - runner := &Runner{agents: agents} - defer runner.Close() - - _, err := runner.Run(context.Background(), "/test/path") - if err == nil { - t.Error("Expected error when all agents fail") - } -} - -func TestRunner_Run_DurationTracking(t *testing.T) { - agent := &mockAgent{ - name: "timed-agent", - runFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - time.Sleep(10 * time.Millisecond) - return core.Output{AgentName: "timed-agent"}, nil // Duration not set - }, - } - - runner := &Runner{agents: []core.Agent{agent}} - defer runner.Close() - - results, err := runner.Run(context.Background(), "/test/path") - if err != nil { - t.Fatalf("Run failed: %v", err) - } - - if len(results) != 1 { - t.Fatalf("Expected 1 result, got %d", len(results)) - } - - // Duration should be set by runner since agent didn't set it - if results[0].Duration < 10*time.Millisecond { - t.Errorf("Duration = %v, expected at least 10ms", results[0].Duration) - } -} - -// === Workspace Tagging Tests === - -func TestRunner_RunWithOptions_WorkspacePassed(t *testing.T) { - var receivedInput core.Input - agent := &mockAgent{ - name: "workspace-agent", - runFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - receivedInput = input - return core.Output{AgentName: "workspace-agent"}, nil - }, - } - - runner := &Runner{agents: []core.Agent{agent}} - defer runner.Close() - - _, err := runner.RunWithOptions(context.Background(), "/test/path", RunOptions{Workspace: "osprey"}) - if err != nil { - t.Fatalf("RunWithOptions failed: %v", err) - } - - if receivedInput.Workspace != "osprey" { - t.Errorf("Input.Workspace = %q, want %q", receivedInput.Workspace, "osprey") - } -} - -func TestRunner_RunWithOptions_DefaultsToRoot(t *testing.T) { - var receivedInput core.Input - agent := &mockAgent{ - name: "workspace-agent", - runFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - receivedInput = input - return core.Output{AgentName: "workspace-agent"}, nil - }, - } - - runner := &Runner{agents: []core.Agent{agent}} - defer runner.Close() - - // Empty workspace should default to "root" - _, err := runner.RunWithOptions(context.Background(), "/test/path", RunOptions{Workspace: ""}) - if err != nil { - t.Fatalf("RunWithOptions failed: %v", err) - } - - if receivedInput.Workspace != "root" { - t.Errorf("Input.Workspace = %q, want %q (default)", receivedInput.Workspace, "root") - } -} - -func TestRunner_Run_UsesRootWorkspace(t *testing.T) { - var receivedInput core.Input - agent := &mockAgent{ - name: "workspace-agent", - runFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - receivedInput = input - return core.Output{AgentName: "workspace-agent"}, nil - }, - } - - runner := &Runner{agents: []core.Agent{agent}} - defer runner.Close() - - // Regular Run() should use "root" workspace - _, err := runner.Run(context.Background(), "/test/path") - if err != nil { - t.Fatalf("Run failed: %v", err) - } - - if receivedInput.Workspace != "root" { - t.Errorf("Input.Workspace = %q, want %q", receivedInput.Workspace, "root") - } -} - -func TestAgentsWorkspaceTagging(t *testing.T) { - // Test that agents receive workspace and can use it for tagging findings - tests := []struct { - name string - workspace string - wantWorkspace string - }{ - {"explicit workspace", "osprey", "osprey"}, - {"root workspace", "root", "root"}, - {"empty defaults to root", "", "root"}, - {"different service", "studio", "studio"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var capturedWorkspace string - agent := &mockAgent{ - name: "tagging-agent", - runFunc: func(ctx context.Context, input core.Input) (core.Output, error) { - capturedWorkspace = input.Workspace - // Agent can use input.Workspace to tag findings - return core.Output{ - AgentName: "tagging-agent", - Findings: []core.Finding{ - { - Type: "decision", - Title: "Test Decision", - Description: "A test decision", - Metadata: map[string]any{ - "workspace": input.Workspace, // Agents can tag findings - }, - }, - }, - }, nil - }, - } - - runner := &Runner{agents: []core.Agent{agent}} - defer runner.Close() - - results, err := runner.RunWithOptions(context.Background(), "/test/path", RunOptions{Workspace: tt.workspace}) - if err != nil { - t.Fatalf("RunWithOptions failed: %v", err) - } - - if capturedWorkspace != tt.wantWorkspace { - t.Errorf("captured workspace = %q, want %q", capturedWorkspace, tt.wantWorkspace) - } - - // Verify finding has workspace metadata - if len(results) > 0 && len(results[0].Findings) > 0 { - finding := results[0].Findings[0] - if ws, ok := finding.Metadata["workspace"].(string); ok { - if ws != tt.wantWorkspace { - t.Errorf("finding.Metadata[workspace] = %q, want %q", ws, tt.wantWorkspace) - } - } - } - }) - } -} diff --git a/internal/bootstrap/service.go b/internal/bootstrap/service.go index 76c886c..dc42026 100644 --- a/internal/bootstrap/service.go +++ b/internal/bootstrap/service.go @@ -76,6 +76,25 @@ func (s *Service) RunMultiRepoAnalysis(ctx context.Context, ws *project.Workspac // Aggregate findings - workspace tagging happens at agent level via Input.Workspace // We still set metadata["service"] for backward compatibility with ingestion findings := core.AggregateFindings(results) + + // Make evidence paths workspace-relative so verification resolves correctly. + // Evidence paths from agents are relative to servicePath, but verification + // uses s.basePath (workspace root). Prefixing with the service directory + // makes filepath.Join(workspaceRoot, "serviceDir/path") resolve correctly. + serviceRelPath, relErr := filepath.Rel(s.basePath, servicePath) + if relErr != nil { + serviceErrors = append(serviceErrors, fmt.Sprintf("%s: compute relative path: %s", serviceName, relErr.Error())) + continue + } + for i := range findings { + for j := range findings[i].Evidence { + ev := &findings[i].Evidence[j] + if ev.FilePath != "" && !filepath.IsAbs(ev.FilePath) { + ev.FilePath = filepath.Join(serviceRelPath, ev.FilePath) + } + } + } + for i := range findings { findings[i].Title = fmt.Sprintf("[%s] %s", serviceName, findings[i].Title) if findings[i].Metadata == nil { diff --git a/internal/bootstrap/service_test.go b/internal/bootstrap/service_test.go deleted file mode 100644 index 4aa3ec9..0000000 --- a/internal/bootstrap/service_test.go +++ /dev/null @@ -1,111 +0,0 @@ -package bootstrap - -import ( - "testing" - - "github.com/josephgoksu/TaskWing/internal/llm" -) - -func TestNewService(t *testing.T) { - basePath := "/test/path" - cfg := llm.Config{ - Provider: "openai", - Model: "gpt-4", - } - - svc := NewService(basePath, cfg) - - if svc == nil { - t.Fatal("NewService returned nil") - } - if svc.basePath != basePath { - t.Errorf("basePath = %q, want %q", svc.basePath, basePath) - } - if svc.initializer == nil { - t.Error("initializer is nil") - } -} - -func TestBootstrapResult(t *testing.T) { - result := &BootstrapResult{ - FindingsCount: 5, - Warnings: []string{"warning1", "warning2"}, - Errors: nil, - } - - if result.FindingsCount != 5 { - t.Errorf("FindingsCount = %d, want 5", result.FindingsCount) - } - if len(result.Warnings) != 2 { - t.Errorf("Warnings count = %d, want 2", len(result.Warnings)) - } - if result.Errors != nil { - t.Error("Errors should be nil") - } -} - -func TestJoinMax(t *testing.T) { - tests := []struct { - name string - parts []string - n int - expect string - }{ - { - name: "empty", - parts: []string{}, - n: 3, - expect: "", - }, - { - name: "fewer than n", - parts: []string{"a", "b"}, - n: 3, - expect: "a, b", - }, - { - name: "exactly n", - parts: []string{"a", "b", "c"}, - n: 3, - expect: "a, b, c", - }, - { - name: "more than n", - parts: []string{"a", "b", "c", "d", "e"}, - n: 3, - expect: "a, b, c, ...", - }, - { - name: "single item", - parts: []string{"only"}, - n: 3, - expect: "only", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := joinMax(tt.parts, tt.n) - if got != tt.expect { - t.Errorf("joinMax(%v, %d) = %q, want %q", tt.parts, tt.n, got, tt.expect) - } - }) - } -} - -func TestService_InitializeProject(t *testing.T) { - tmpDir := t.TempDir() - svc := NewService(tmpDir, llm.Config{}) - - // Test with empty AIs - err := svc.InitializeProject(false, []string{}) - if err != nil { - t.Errorf("InitializeProject with empty AIs failed: %v", err) - } - - // Test with valid AI - err = svc.InitializeProject(false, []string{"claude"}) - if err != nil { - t.Errorf("InitializeProject with claude failed: %v", err) - } -} diff --git a/internal/bootstrap/slash_descriptions_test.go b/internal/bootstrap/slash_descriptions_test.go deleted file mode 100644 index 61feff8..0000000 --- a/internal/bootstrap/slash_descriptions_test.go +++ /dev/null @@ -1,14 +0,0 @@ -package bootstrap - -import ( - "strings" - "testing" -) - -func TestSlashCommandsDescriptions_AreTriggerFocused(t *testing.T) { - for _, cmd := range SlashCommands { - if !strings.HasPrefix(cmd.Description, "Use when ") { - t.Fatalf("slash command %q description must start with 'Use when ': %q", cmd.BaseName, cmd.Description) - } - } -} diff --git a/internal/bootstrap/slash_parity_test.go b/internal/bootstrap/slash_parity_test.go deleted file mode 100644 index 71d6ef8..0000000 --- a/internal/bootstrap/slash_parity_test.go +++ /dev/null @@ -1,59 +0,0 @@ -package bootstrap - -import ( - "os" - "path/filepath" - "strings" - "testing" -) - -func TestCreateSlashCommands_CrossAssistantDescriptionParity(t *testing.T) { - tmpDir := t.TempDir() - init := NewInitializer(tmpDir) - - assistants := []string{"claude", "codex", "opencode"} - for _, ai := range assistants { - if err := init.CreateSlashCommands(ai, false); err != nil { - t.Fatalf("CreateSlashCommands(%s): %v", ai, err) - } - } - - for _, cmd := range SlashCommands { - expected := cmd.Description - - paths := map[string]string{ - "claude": filepath.Join(tmpDir, ".claude", "commands", cmd.BaseName+".md"), - "codex": filepath.Join(tmpDir, ".codex", "commands", cmd.BaseName+".md"), - "opencode": filepath.Join(tmpDir, ".opencode", "commands", cmd.BaseName+".md"), - } - - for ai, path := range paths { - desc, err := readCommandDescription(path) - if err != nil { - t.Fatalf("read description %s (%s): %v", ai, cmd.BaseName, err) - } - if desc != expected { - t.Fatalf("%s description mismatch for %s: got %q want %q", ai, cmd.BaseName, desc, expected) - } - } - } -} - -func readCommandDescription(path string) (string, error) { - data, err := os.ReadFile(path) - if err != nil { - return "", err - } - - for _, line := range strings.Split(string(data), "\n") { - line = strings.TrimSpace(line) - if !strings.HasPrefix(line, "description:") { - continue - } - desc := strings.TrimSpace(strings.TrimPrefix(line, "description:")) - desc = strings.Trim(desc, "\"") - return desc, nil - } - - return "", os.ErrNotExist -} diff --git a/internal/brief/brief.go b/internal/brief/brief.go index 0e782b1..3ee6a3b 100644 --- a/internal/brief/brief.go +++ b/internal/brief/brief.go @@ -16,7 +16,7 @@ import ( // No node IDs, file paths, or embeddings are included. // // This function is used by: -// - /tw-brief slash command +// - /tw-ask slash command (project knowledge brief) // - SessionStart hook auto-injection func GenerateCompactBrief(repo *memory.Repository) (string, error) { nodes, err := repo.ListNodes("") diff --git a/internal/codeintel/query_test.go b/internal/codeintel/query_test.go deleted file mode 100644 index f3f4b68..0000000 --- a/internal/codeintel/query_test.go +++ /dev/null @@ -1,627 +0,0 @@ -package codeintel - -import ( - "context" - "database/sql" - "testing" - - "github.com/josephgoksu/TaskWing/internal/llm" - _ "modernc.org/sqlite" -) - -// setupTestDB creates an in-memory SQLite database with the required schema. -func setupTestDB(t *testing.T) *sql.DB { - t.Helper() - - db, err := sql.Open("sqlite", ":memory:") - if err != nil { - t.Fatalf("open test db: %v", err) - } - - // Create required tables - schema := ` - CREATE TABLE IF NOT EXISTS symbols ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL, - kind TEXT NOT NULL, - file_path TEXT NOT NULL, - start_line INTEGER NOT NULL, - end_line INTEGER NOT NULL, - signature TEXT, - doc_comment TEXT, - module_path TEXT, - visibility TEXT NOT NULL DEFAULT 'private', - language TEXT NOT NULL DEFAULT 'unknown', - file_hash TEXT, - embedding BLOB, - last_modified TEXT NOT NULL DEFAULT (datetime('now')), - UNIQUE(name, file_path, start_line) - ); - - CREATE TABLE IF NOT EXISTS symbol_relations ( - from_symbol_id INTEGER NOT NULL, - to_symbol_id INTEGER NOT NULL, - relation_type TEXT NOT NULL, - call_site_line INTEGER, - metadata TEXT, - PRIMARY KEY (from_symbol_id, to_symbol_id, relation_type) - ); - - CREATE VIRTUAL TABLE IF NOT EXISTS symbols_fts USING fts5( - name, signature, doc_comment, module_path, - content='', - content_rowid='id' - ); - - CREATE TABLE IF NOT EXISTS dependencies ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL, - version TEXT NOT NULL, - ecosystem TEXT NOT NULL, - lockfile_ref TEXT NOT NULL, - resolved TEXT, - integrity TEXT, - is_dev INTEGER DEFAULT 0, - source TEXT, - extras TEXT, - last_modified TEXT NOT NULL DEFAULT (datetime('now')), - UNIQUE(name, version, lockfile_ref) - ); - - CREATE VIRTUAL TABLE IF NOT EXISTS dependencies_fts USING fts5( - name, version, ecosystem, - content='', - content_rowid='id' - ); - ` - - if _, err := db.Exec(schema); err != nil { - t.Fatalf("create schema: %v", err) - } - - return db -} - -// TestQueryService_FindSymbol tests symbol lookup by ID. -func TestQueryService_FindSymbol(t *testing.T) { - db := setupTestDB(t) - defer func() { _ = db.Close() }() - - repo := NewRepository(db) - qs := NewQueryService(repo, llm.Config{}) - ctx := context.Background() - - // Create a test symbol - sym := &Symbol{ - Name: "TestFunc", - Kind: SymbolFunction, - FilePath: "internal/app/service.go", - StartLine: 10, - EndLine: 20, - Signature: "func TestFunc(ctx context.Context) error", - Visibility: "public", - Language: "go", - } - - id, err := repo.UpsertSymbol(ctx, sym) - if err != nil { - t.Fatalf("upsert symbol: %v", err) - } - - t.Run("successful lookup", func(t *testing.T) { - found, err := qs.FindSymbol(ctx, id) - if err != nil { - t.Errorf("FindSymbol() error = %v", err) - return - } - - if found.Name != "TestFunc" { - t.Errorf("FindSymbol() name = %q, want %q", found.Name, "TestFunc") - } - if found.Kind != SymbolFunction { - t.Errorf("FindSymbol() kind = %q, want %q", found.Kind, SymbolFunction) - } - if found.FilePath != "internal/app/service.go" { - t.Errorf("FindSymbol() filepath = %q, want %q", found.FilePath, "internal/app/service.go") - } - }) - - t.Run("not found returns error", func(t *testing.T) { - _, err := qs.FindSymbol(ctx, 99999) - if err == nil { - t.Errorf("FindSymbol() expected error for non-existent ID, got nil") - } - }) -} - -// TestQueryService_FindSymbolByName tests symbol lookup by name. -func TestQueryService_FindSymbolByName(t *testing.T) { - db := setupTestDB(t) - defer func() { _ = db.Close() }() - - repo := NewRepository(db) - qs := NewQueryService(repo, llm.Config{}) - ctx := context.Background() - - // Create test symbols with same name in different files - symbols := []*Symbol{ - { - Name: "Handler", - Kind: SymbolFunction, - FilePath: "cmd/api/handler.go", - StartLine: 15, - EndLine: 30, - Visibility: "public", - Language: "go", - }, - { - Name: "Handler", - Kind: SymbolFunction, - FilePath: "cmd/web/handler.go", - StartLine: 10, - EndLine: 25, - Visibility: "public", - Language: "go", - }, - { - Name: "Handler", - Kind: SymbolFunction, - FilePath: "pkg/handler.ts", - StartLine: 5, - EndLine: 20, - Visibility: "public", - Language: "typescript", - }, - } - - for _, sym := range symbols { - if _, err := repo.UpsertSymbol(ctx, sym); err != nil { - t.Fatalf("upsert symbol: %v", err) - } - } - - t.Run("finds all symbols with name", func(t *testing.T) { - found, err := qs.FindSymbolByName(ctx, "Handler") - if err != nil { - t.Errorf("FindSymbolByName() error = %v", err) - return - } - - if len(found) != 3 { - t.Errorf("FindSymbolByName() count = %d, want 3", len(found)) - } - }) - - t.Run("finds symbols filtered by language", func(t *testing.T) { - found, err := qs.FindSymbolByNameAndLang(ctx, "Handler", "go") - if err != nil { - t.Errorf("FindSymbolByNameAndLang() error = %v", err) - return - } - - if len(found) != 2 { - t.Errorf("FindSymbolByNameAndLang() count = %d, want 2", len(found)) - } - - for _, s := range found { - if s.Language != "go" { - t.Errorf("FindSymbolByNameAndLang() found language %q, want go", s.Language) - } - } - }) - - t.Run("not found returns empty slice", func(t *testing.T) { - found, err := qs.FindSymbolByName(ctx, "NonExistentFunction") - if err != nil { - t.Errorf("FindSymbolByName() error = %v", err) - return - } - - if len(found) != 0 { - t.Errorf("FindSymbolByName() count = %d, want 0", len(found)) - } - }) -} - -// TestQueryService_AnalyzeImpact tests impact analysis with dependency graph traversal. -func TestQueryService_AnalyzeImpact(t *testing.T) { - db := setupTestDB(t) - defer func() { _ = db.Close() }() - - repo := NewRepository(db) - qs := NewQueryService(repo, llm.Config{}) - ctx := context.Background() - - // Create a call graph: - // main() -> handler() -> service() -> repository() - // -> cache() -> repository() - // - // Changing repository() should impact: service (depth 1), cache (depth 1), - // handler (depth 2), main (depth 3) - // - // Note: repository appears at multiple depths via different paths, - // but should be deduplicated to show only at lowest depth. - - symbols := map[string]*Symbol{ - "main": { - Name: "main", Kind: SymbolFunction, FilePath: "cmd/main.go", - StartLine: 10, EndLine: 20, Visibility: "private", Language: "go", - }, - "handler": { - Name: "handler", Kind: SymbolFunction, FilePath: "internal/api/handler.go", - StartLine: 15, EndLine: 30, Visibility: "public", Language: "go", - }, - "service": { - Name: "service", Kind: SymbolFunction, FilePath: "internal/app/service.go", - StartLine: 20, EndLine: 40, Visibility: "public", Language: "go", - }, - "cache": { - Name: "cache", Kind: SymbolFunction, FilePath: "internal/cache/cache.go", - StartLine: 10, EndLine: 25, Visibility: "public", Language: "go", - }, - "repository": { - Name: "repository", Kind: SymbolFunction, FilePath: "internal/store/repo.go", - StartLine: 30, EndLine: 50, Visibility: "public", Language: "go", - }, - } - - // Insert symbols and collect IDs - ids := make(map[string]uint32) - for name, sym := range symbols { - id, err := repo.UpsertSymbol(ctx, sym) - if err != nil { - t.Fatalf("upsert symbol %s: %v", name, err) - } - ids[name] = id - } - - // Create call relationships - relations := []struct { - from, to string - }{ - {"main", "handler"}, - {"handler", "service"}, - {"handler", "cache"}, - {"service", "repository"}, - {"cache", "repository"}, - } - - for _, rel := range relations { - err := repo.UpsertRelation(ctx, &SymbolRelation{ - FromSymbolID: ids[rel.from], - ToSymbolID: ids[rel.to], - RelationType: RelationCalls, - }) - if err != nil { - t.Fatalf("upsert relation %s->%s: %v", rel.from, rel.to, err) - } - } - - t.Run("finds all affected symbols", func(t *testing.T) { - analysis, err := qs.AnalyzeImpact(ctx, ids["repository"], 5) - if err != nil { - t.Errorf("AnalyzeImpact() error = %v", err) - return - } - - if analysis.Source.Name != "repository" { - t.Errorf("AnalyzeImpact() source = %q, want repository", analysis.Source.Name) - } - - // Should find: service (d1), cache (d1), handler (d2), main (d3) - if analysis.AffectedCount != 4 { - t.Errorf("AnalyzeImpact() count = %d, want 4", analysis.AffectedCount) - t.Logf("Affected symbols:") - for _, node := range analysis.Affected { - t.Logf(" - %s (depth %d)", node.Symbol.Name, node.Depth) - } - } - - // Verify depth grouping - if len(analysis.ByDepth[1]) != 2 { - t.Errorf("AnalyzeImpact() depth 1 count = %d, want 2", len(analysis.ByDepth[1])) - } - if len(analysis.ByDepth[2]) != 1 { - t.Errorf("AnalyzeImpact() depth 2 count = %d, want 1", len(analysis.ByDepth[2])) - } - if len(analysis.ByDepth[3]) != 1 { - t.Errorf("AnalyzeImpact() depth 3 count = %d, want 1", len(analysis.ByDepth[3])) - } - }) - - t.Run("respects max depth limit", func(t *testing.T) { - analysis, err := qs.AnalyzeImpact(ctx, ids["repository"], 1) - if err != nil { - t.Errorf("AnalyzeImpact() error = %v", err) - return - } - - // With maxDepth=1, should only find: service (d1), cache (d1) - if analysis.AffectedCount != 2 { - t.Errorf("AnalyzeImpact() with maxDepth=1 count = %d, want 2", analysis.AffectedCount) - } - }) - - t.Run("handles symbol with no callers", func(t *testing.T) { - analysis, err := qs.AnalyzeImpact(ctx, ids["main"], 5) - if err != nil { - t.Errorf("AnalyzeImpact() error = %v", err) - return - } - - // main has no callers - if analysis.AffectedCount != 0 { - t.Errorf("AnalyzeImpact() for main count = %d, want 0", analysis.AffectedCount) - } - }) -} - -// TestQueryService_AnalyzeImpact_CyclicGraph tests impact analysis handles cycles correctly. -func TestQueryService_AnalyzeImpact_CyclicGraph(t *testing.T) { - db := setupTestDB(t) - defer func() { _ = db.Close() }() - - repo := NewRepository(db) - qs := NewQueryService(repo, llm.Config{}) - ctx := context.Background() - - // Create a cyclic call graph: - // A -> B -> C -> A (cycle) - // - // When analyzing C, we should find B (d1), A (d2) - // The cycle should be handled without infinite recursion. - - symbols := []*Symbol{ - {Name: "A", Kind: SymbolFunction, FilePath: "a.go", StartLine: 1, EndLine: 10, Visibility: "public", Language: "go"}, - {Name: "B", Kind: SymbolFunction, FilePath: "b.go", StartLine: 1, EndLine: 10, Visibility: "public", Language: "go"}, - {Name: "C", Kind: SymbolFunction, FilePath: "c.go", StartLine: 1, EndLine: 10, Visibility: "public", Language: "go"}, - } - - ids := make(map[string]uint32) - for _, sym := range symbols { - id, err := repo.UpsertSymbol(ctx, sym) - if err != nil { - t.Fatalf("upsert symbol %s: %v", sym.Name, err) - } - ids[sym.Name] = id - } - - // Create cyclic relationships: A->B, B->C, C->A - relations := []struct{ from, to string }{ - {"A", "B"}, - {"B", "C"}, - {"C", "A"}, - } - - for _, rel := range relations { - err := repo.UpsertRelation(ctx, &SymbolRelation{ - FromSymbolID: ids[rel.from], - ToSymbolID: ids[rel.to], - RelationType: RelationCalls, - }) - if err != nil { - t.Fatalf("upsert relation: %v", err) - } - } - - t.Run("handles cyclic graph without infinite loop", func(t *testing.T) { - analysis, err := qs.AnalyzeImpact(ctx, ids["C"], 10) - if err != nil { - t.Errorf("AnalyzeImpact() error = %v", err) - return - } - - // Should complete without timeout or stack overflow - // Due to cycle, we expect B (d1), A (d2), and then the cycle continues - // but depth limit should prevent infinite recursion - if analysis.Source.Name != "C" { - t.Errorf("AnalyzeImpact() source = %q, want C", analysis.Source.Name) - } - - // B directly calls C, A calls B (which calls C) - // After that, C calls A which starts the cycle again - // The deduplication should ensure each symbol appears only once - t.Logf("Found %d affected symbols in cyclic graph", analysis.AffectedCount) - for _, node := range analysis.Affected { - t.Logf(" - %s (depth %d)", node.Symbol.Name, node.Depth) - } - }) -} - -// TestQueryService_GetCallers tests the GetCallers function. -func TestQueryService_GetCallers(t *testing.T) { - db := setupTestDB(t) - defer func() { _ = db.Close() }() - - repo := NewRepository(db) - qs := NewQueryService(repo, llm.Config{}) - ctx := context.Background() - - // Create symbols - target := &Symbol{Name: "target", Kind: SymbolFunction, FilePath: "target.go", StartLine: 1, EndLine: 10, Visibility: "public", Language: "go"} - caller1 := &Symbol{Name: "caller1", Kind: SymbolFunction, FilePath: "caller1.go", StartLine: 1, EndLine: 10, Visibility: "public", Language: "go"} - caller2 := &Symbol{Name: "caller2", Kind: SymbolFunction, FilePath: "caller2.go", StartLine: 1, EndLine: 10, Visibility: "public", Language: "go"} - - targetID, _ := repo.UpsertSymbol(ctx, target) - caller1ID, _ := repo.UpsertSymbol(ctx, caller1) - caller2ID, _ := repo.UpsertSymbol(ctx, caller2) - - // caller1 and caller2 both call target - _ = repo.UpsertRelation(ctx, &SymbolRelation{FromSymbolID: caller1ID, ToSymbolID: targetID, RelationType: RelationCalls}) - _ = repo.UpsertRelation(ctx, &SymbolRelation{FromSymbolID: caller2ID, ToSymbolID: targetID, RelationType: RelationCalls}) - - t.Run("returns all callers", func(t *testing.T) { - callers, err := qs.GetCallers(ctx, targetID) - if err != nil { - t.Errorf("GetCallers() error = %v", err) - return - } - - if len(callers) != 2 { - t.Errorf("GetCallers() count = %d, want 2", len(callers)) - } - }) - - t.Run("returns empty for symbol with no callers", func(t *testing.T) { - callers, err := qs.GetCallers(ctx, caller1ID) - if err != nil { - t.Errorf("GetCallers() error = %v", err) - return - } - - if len(callers) != 0 { - t.Errorf("GetCallers() count = %d, want 0", len(callers)) - } - }) -} - -// TestQueryService_GetCallees tests the GetCallees function. -func TestQueryService_GetCallees(t *testing.T) { - db := setupTestDB(t) - defer func() { _ = db.Close() }() - - repo := NewRepository(db) - qs := NewQueryService(repo, llm.Config{}) - ctx := context.Background() - - // Create symbols - caller := &Symbol{Name: "caller", Kind: SymbolFunction, FilePath: "caller.go", StartLine: 1, EndLine: 10, Visibility: "public", Language: "go"} - target1 := &Symbol{Name: "target1", Kind: SymbolFunction, FilePath: "target1.go", StartLine: 1, EndLine: 10, Visibility: "public", Language: "go"} - target2 := &Symbol{Name: "target2", Kind: SymbolFunction, FilePath: "target2.go", StartLine: 1, EndLine: 10, Visibility: "public", Language: "go"} - - callerID, _ := repo.UpsertSymbol(ctx, caller) - target1ID, _ := repo.UpsertSymbol(ctx, target1) - target2ID, _ := repo.UpsertSymbol(ctx, target2) - - // caller calls both target1 and target2 - _ = repo.UpsertRelation(ctx, &SymbolRelation{FromSymbolID: callerID, ToSymbolID: target1ID, RelationType: RelationCalls}) - _ = repo.UpsertRelation(ctx, &SymbolRelation{FromSymbolID: callerID, ToSymbolID: target2ID, RelationType: RelationCalls}) - - t.Run("returns all callees", func(t *testing.T) { - callees, err := qs.GetCallees(ctx, callerID) - if err != nil { - t.Errorf("GetCallees() error = %v", err) - return - } - - if len(callees) != 2 { - t.Errorf("GetCallees() count = %d, want 2", len(callees)) - } - }) - - t.Run("returns empty for symbol with no callees", func(t *testing.T) { - callees, err := qs.GetCallees(ctx, target1ID) - if err != nil { - t.Errorf("GetCallees() error = %v", err) - return - } - - if len(callees) != 0 { - t.Errorf("GetCallees() count = %d, want 0", len(callees)) - } - }) -} - -// TestQueryService_NotFoundScenarios tests various not-found scenarios. -func TestQueryService_NotFoundScenarios(t *testing.T) { - db := setupTestDB(t) - defer func() { _ = db.Close() }() - - repo := NewRepository(db) - qs := NewQueryService(repo, llm.Config{}) - ctx := context.Background() - - t.Run("FindSymbol with invalid ID returns error", func(t *testing.T) { - _, err := qs.FindSymbol(ctx, 12345) - if err == nil { - t.Error("FindSymbol() expected error for non-existent ID") - } - }) - - t.Run("FindSymbolByName with no matches returns empty", func(t *testing.T) { - result, err := qs.FindSymbolByName(ctx, "NonExistent") - if err != nil { - t.Errorf("FindSymbolByName() unexpected error: %v", err) - } - if len(result) != 0 { - t.Errorf("FindSymbolByName() expected empty slice, got %d items", len(result)) - } - }) - - t.Run("AnalyzeImpact with invalid ID returns error", func(t *testing.T) { - _, err := qs.AnalyzeImpact(ctx, 99999, 5) - if err == nil { - t.Error("AnalyzeImpact() expected error for non-existent ID") - } - }) - - t.Run("GetCallers with non-existent ID returns empty", func(t *testing.T) { - callers, err := qs.GetCallers(ctx, 99999) - if err != nil { - t.Errorf("GetCallers() unexpected error: %v", err) - } - if len(callers) != 0 { - t.Errorf("GetCallers() expected empty slice, got %d items", len(callers)) - } - }) - - t.Run("GetCallees with non-existent ID returns empty", func(t *testing.T) { - callees, err := qs.GetCallees(ctx, 99999) - if err != nil { - t.Errorf("GetCallees() unexpected error: %v", err) - } - if len(callees) != 0 { - t.Errorf("GetCallees() expected empty slice, got %d items", len(callees)) - } - }) -} - -// TestCosineSimilarity tests the cosine similarity function. -func TestCosineSimilarity(t *testing.T) { - tests := []struct { - name string - a, b []float32 - want float32 - }{ - { - name: "identical vectors", - a: []float32{1, 0, 0}, - b: []float32{1, 0, 0}, - want: 1.0, - }, - { - name: "orthogonal vectors", - a: []float32{1, 0, 0}, - b: []float32{0, 1, 0}, - want: 0.0, - }, - { - name: "opposite vectors", - a: []float32{1, 0, 0}, - b: []float32{-1, 0, 0}, - want: -1.0, - }, - { - name: "empty vectors", - a: []float32{}, - b: []float32{}, - want: 0.0, - }, - { - name: "different length vectors", - a: []float32{1, 2}, - b: []float32{1, 2, 3}, - want: 0.0, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := cosineSimilarity(tt.a, tt.b) - // Allow small floating point error - if diff := got - tt.want; diff > 0.0001 || diff < -0.0001 { - t.Errorf("cosineSimilarity() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/internal/config/llm_loader_test.go b/internal/config/llm_loader_test.go deleted file mode 100644 index ae62ce2..0000000 --- a/internal/config/llm_loader_test.go +++ /dev/null @@ -1,261 +0,0 @@ -package config - -import ( - "strings" - "testing" - - "github.com/josephgoksu/TaskWing/internal/llm" - "github.com/spf13/viper" -) - -func resetViperForTest(t *testing.T) { - t.Helper() - viper.Reset() - t.Cleanup(viper.Reset) -} - -func TestResolveBedrockBaseURL_FromRegion(t *testing.T) { - resetViperForTest(t) - viper.Set("llm.bedrock.region", "us-west-2") - - got, err := ResolveBedrockBaseURL() - if err != nil { - t.Fatalf("ResolveBedrockBaseURL() error = %v", err) - } - want := "https://bedrock-runtime.us-west-2.amazonaws.com/openai/v1" - t.Logf("resolved Bedrock baseURL: %s", got) - if got != want { - t.Fatalf("ResolveBedrockBaseURL() = %q, want %q", got, want) - } -} - -func TestResolveBedrockBaseURL_MissingRegion(t *testing.T) { - resetViperForTest(t) - t.Setenv("AWS_REGION", "") - t.Setenv("AWS_DEFAULT_REGION", "") - - _, err := ResolveBedrockBaseURL() - if err == nil { - t.Fatal("ResolveBedrockBaseURL() error = nil, want missing-region error") - } - if !strings.Contains(err.Error(), "llm.bedrock.region") { - t.Fatalf("ResolveBedrockBaseURL() error = %v, want llm.bedrock.region guidance", err) - } -} - -func TestValidateBedrockBaseURL(t *testing.T) { - tests := []struct { - name string - url string - wantErr bool - }{ - { - name: "valid bedrock endpoint", - url: "https://bedrock-runtime.us-east-1.amazonaws.com/openai/v1", - wantErr: false, - }, - { - name: "reject non bedrock host", - url: "https://api.openai.com/v1", - wantErr: true, - }, - { - name: "reject invalid path", - url: "https://bedrock-runtime.us-east-1.amazonaws.com/v1", - wantErr: true, - }, - } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - err := ValidateBedrockBaseURL(tc.url) - if tc.wantErr && err == nil { - t.Fatalf("ValidateBedrockBaseURL(%q) error = nil, want error", tc.url) - } - if !tc.wantErr && err != nil { - t.Fatalf("ValidateBedrockBaseURL(%q) error = %v", tc.url, err) - } - }) - } -} - -func TestResolveBedrockRegion_FromEnvVar(t *testing.T) { - resetViperForTest(t) - - // Simulate TASKWING_LLM_BEDROCK_REGION env var via Viper auto-bind. - // Viper maps TASKWING_LLM_BEDROCK_REGION → llm.bedrock.region - // when SetEnvPrefix("TASKWING") + SetEnvKeyReplacer("." → "_") is configured. - // We test the underlying resolution directly via Viper Set to verify the path. - t.Setenv("AWS_REGION", "") - t.Setenv("AWS_DEFAULT_REGION", "") - viper.Set("llm.bedrock.region", "eu-west-1") - - region := ResolveBedrockRegion() - if region != "eu-west-1" { - t.Fatalf("ResolveBedrockRegion() = %q, want %q", region, "eu-west-1") - } - - got, err := ResolveBedrockBaseURL() - if err != nil { - t.Fatalf("ResolveBedrockBaseURL() error = %v", err) - } - want := "https://bedrock-runtime.eu-west-1.amazonaws.com/openai/v1" - if got != want { - t.Fatalf("ResolveBedrockBaseURL() = %q, want %q", got, want) - } -} - -func TestResolveBedrockRegion_FallsBackToAWSRegion(t *testing.T) { - resetViperForTest(t) - - // No Viper config, should fall back to AWS_REGION env var - t.Setenv("AWS_REGION", "ap-southeast-1") - t.Setenv("AWS_DEFAULT_REGION", "") - - region := ResolveBedrockRegion() - if region != "ap-southeast-1" { - t.Fatalf("ResolveBedrockRegion() = %q, want %q", region, "ap-southeast-1") - } -} - -func TestResolveBedrockRegion_FallsBackToAWSDefaultRegion(t *testing.T) { - resetViperForTest(t) - - t.Setenv("AWS_REGION", "") - t.Setenv("AWS_DEFAULT_REGION", "us-west-1") - - region := ResolveBedrockRegion() - if region != "us-west-1" { - t.Fatalf("ResolveBedrockRegion() = %q, want %q", region, "us-west-1") - } -} - -func TestLoadLLMConfig_BedrockEmbeddingDefault(t *testing.T) { - resetViperForTest(t) - viper.Set("llm.provider", "bedrock") - viper.Set("llm.model", "us.anthropic.claude-sonnet-4-5-20250929-v1:0") - viper.Set("llm.bedrock.region", "us-east-1") - viper.Set("llm.apiKeys.bedrock", "test-bedrock-key") - - cfg, err := LoadLLMConfig() - if err != nil { - t.Fatalf("LoadLLMConfig() error = %v", err) - } - if cfg.EmbeddingModel != llm.DefaultBedrockEmbeddingModel { - t.Fatalf("LoadLLMConfig() EmbeddingModel = %q, want %q", cfg.EmbeddingModel, llm.DefaultBedrockEmbeddingModel) - } -} - -func TestLoadLLMConfig_Bedrock(t *testing.T) { - resetViperForTest(t) - viper.Set("llm.provider", "bedrock") - viper.Set("llm.model", "us.anthropic.claude-sonnet-4-5-20250929-v1:0") - viper.Set("llm.bedrock.region", "us-east-1") - viper.Set("llm.apiKeys.bedrock", "test-bedrock-key") - - cfg, err := LoadLLMConfig() - if err != nil { - t.Fatalf("LoadLLMConfig() error = %v", err) - } - if cfg.Provider != llm.ProviderBedrock { - t.Fatalf("LoadLLMConfig() provider = %q, want %q", cfg.Provider, llm.ProviderBedrock) - } - if cfg.BaseURL != "https://bedrock-runtime.us-east-1.amazonaws.com/openai/v1" { - t.Fatalf("LoadLLMConfig() baseURL = %q", cfg.BaseURL) - } - if cfg.APIKey != "test-bedrock-key" { - t.Fatalf("LoadLLMConfig() apiKey mismatch") - } -} - -// ============================================ -// TaskWing managed provider tests -// ============================================ - -func TestResolveProviderBaseURL_TaskWing_Default(t *testing.T) { - resetViperForTest(t) - got, err := ResolveProviderBaseURL(llm.ProviderTaskWing) - if err != nil { - t.Fatalf("ResolveProviderBaseURL(taskwing) error = %v", err) - } - if got != llm.DefaultTaskWingURL { - t.Fatalf("ResolveProviderBaseURL(taskwing) = %q, want %q", got, llm.DefaultTaskWingURL) - } -} - -func TestResolveProviderBaseURL_TaskWing_Custom(t *testing.T) { - resetViperForTest(t) - customURL := "https://custom.inference.example.com/v1" - viper.Set("llm.taskwing.base_url", customURL) - got, err := ResolveProviderBaseURL(llm.ProviderTaskWing) - if err != nil { - t.Fatalf("ResolveProviderBaseURL(taskwing) error = %v", err) - } - if got != customURL { - t.Fatalf("ResolveProviderBaseURL(taskwing) = %q, want %q", got, customURL) - } -} - -func TestParseModelSpec_TaskWingProvider(t *testing.T) { - resetViperForTest(t) - t.Setenv("TASKWING_API_KEY", "tw-test-key") - - cfg, err := ParseModelSpec("taskwing:karluk", llm.RoleBootstrap) - if err != nil { - t.Fatalf("ParseModelSpec(taskwing:karluk) error = %v", err) - } - if cfg.Provider != llm.ProviderTaskWing { - t.Fatalf("provider = %q, want %q", cfg.Provider, llm.ProviderTaskWing) - } - if cfg.Model != "karluk" { - t.Fatalf("model = %q, want karluk", cfg.Model) - } - if cfg.APIKey != "tw-test-key" { - t.Fatalf("apiKey = %q, want tw-test-key", cfg.APIKey) - } - if cfg.BaseURL != llm.DefaultTaskWingURL { - t.Fatalf("baseURL = %q, want %q", cfg.BaseURL, llm.DefaultTaskWingURL) - } -} - -// Regression test: a stale llm.baseURL pointing to localhost must not leak into -// cloud providers (OpenAI, Anthropic, Gemini). This was the root cause of bootstrap -// agents hitting localhost:11434 despite being configured for OpenAI. -func TestResolveProviderBaseURL_OpenAI_IgnoresLocalhostBaseURL(t *testing.T) { - resetViperForTest(t) - viper.Set("llm.baseURL", "http://localhost:11434") - - got, err := ResolveProviderBaseURL(llm.ProviderOpenAI) - if err != nil { - t.Fatalf("ResolveProviderBaseURL(openai) error = %v", err) - } - if got != "" { - t.Fatalf("ResolveProviderBaseURL(openai) = %q, want empty (should ignore localhost baseURL)", got) - } -} - -func TestResolveProviderBaseURL_OpenAI_AllowsCustomEndpoint(t *testing.T) { - resetViperForTest(t) - customURL := "https://my-proxy.example.com/v1" - viper.Set("llm.baseURL", customURL) - - got, err := ResolveProviderBaseURL(llm.ProviderOpenAI) - if err != nil { - t.Fatalf("ResolveProviderBaseURL(openai) error = %v", err) - } - if got != customURL { - t.Fatalf("ResolveProviderBaseURL(openai) = %q, want %q", got, customURL) - } -} - -func TestParseModelSpec_TaskWingNoKey(t *testing.T) { - resetViperForTest(t) - t.Setenv("TASKWING_API_KEY", "") - - _, err := ParseModelSpec("taskwing:karluk", llm.RoleBootstrap) - if err == nil { - t.Fatal("ParseModelSpec(taskwing:...) should fail without TASKWING_API_KEY") - } - if !strings.Contains(err.Error(), "TASKWING_API_KEY") { - t.Fatalf("error should mention TASKWING_API_KEY, got: %v", err) - } -} diff --git a/internal/config/paths_test.go b/internal/config/paths_test.go deleted file mode 100644 index c75dbbd..0000000 --- a/internal/config/paths_test.go +++ /dev/null @@ -1,180 +0,0 @@ -package config - -import ( - "errors" - "testing" - - "github.com/josephgoksu/TaskWing/internal/project" -) - -func TestSetProjectContext_NilReturnsError(t *testing.T) { - // Clear any existing context - ClearProjectContext() - - err := SetProjectContext(nil) - if err == nil { - t.Fatal("expected error for nil context, got nil") - } - - // Verify error message is helpful - if err.Error() != "SetProjectContext called with nil context" { - t.Errorf("unexpected error message: %s", err.Error()) - } -} - -func TestSetProjectContext_ValidContext(t *testing.T) { - // Clear any existing context - ClearProjectContext() - defer ClearProjectContext() - - ctx := &project.Context{ - RootPath: "/test/path", - MarkerType: project.MarkerGit, - } - - err := SetProjectContext(ctx) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - // Verify context was set - got := GetProjectContext() - if got == nil { - t.Fatal("expected context to be set") - } - if got.RootPath != ctx.RootPath { - t.Errorf("expected RootPath %q, got %q", ctx.RootPath, got.RootPath) - } -} - -func TestGetProjectContextOrError_NotSet(t *testing.T) { - ClearProjectContext() - - ctx, err := GetProjectContextOrError() - if err == nil { - t.Fatal("expected error when context not set") - } - if !errors.Is(err, ErrProjectContextNotSet) { - t.Errorf("expected ErrProjectContextNotSet, got: %v", err) - } - if ctx != nil { - t.Error("expected nil context") - } -} - -func TestGetProjectContextOrError_Set(t *testing.T) { - ClearProjectContext() - defer ClearProjectContext() - - expected := &project.Context{RootPath: "/test"} - _ = SetProjectContext(expected) - - ctx, err := GetProjectContextOrError() - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if ctx != expected { - t.Error("context does not match expected") - } -} - -func TestGetProjectRoot_NotSet(t *testing.T) { - ClearProjectContext() - - root, err := GetProjectRoot() - if err == nil { - t.Fatal("expected error when context not set") - } - if !errors.Is(err, ErrProjectContextNotSet) { - t.Errorf("expected ErrProjectContextNotSet, got: %v", err) - } - if root != "" { - t.Errorf("expected empty root, got: %s", root) - } -} - -func TestGetProjectRoot_EmptyRootPath(t *testing.T) { - ClearProjectContext() - defer ClearProjectContext() - - ctx := &project.Context{RootPath: ""} - _ = SetProjectContext(ctx) - - root, err := GetProjectRoot() - if err == nil { - t.Fatal("expected error for empty RootPath") - } - if root != "" { - t.Errorf("expected empty root, got: %s", root) - } -} - -func TestGetProjectRoot_Valid(t *testing.T) { - ClearProjectContext() - defer ClearProjectContext() - - expected := "/my/project" - ctx := &project.Context{RootPath: expected} - _ = SetProjectContext(ctx) - - root, err := GetProjectRoot() - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if root != expected { - t.Errorf("expected %q, got %q", expected, root) - } -} - -func TestGetMemoryBasePath_NotSet(t *testing.T) { - ClearProjectContext() - - path, err := GetMemoryBasePath() - if err == nil { - t.Fatal("expected error when context not set") - } - if !errors.Is(err, ErrProjectContextNotSet) { - t.Errorf("expected ErrProjectContextNotSet, got: %v", err) - } - if path != "" { - t.Errorf("expected empty path, got: %s", path) - } -} - -func TestGetMemoryBasePathOrGlobal_FallsBackToGlobal(t *testing.T) { - ClearProjectContext() - - // Should fall back to global without error - path, err := GetMemoryBasePathOrGlobal() - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if path == "" { - t.Error("expected non-empty path") - } - // Should contain "memory" in the path - if len(path) < 6 || path[len(path)-6:] != "memory" { - t.Errorf("expected path to end with 'memory', got: %s", path) - } -} - -func TestGetMemoryBasePathOrGlobal_GlobalDirError(t *testing.T) { - ClearProjectContext() - - // Save original function - original := GetGlobalConfigDir - defer func() { GetGlobalConfigDir = original }() - - // Mock to return error - GetGlobalConfigDir = func() (string, error) { - return "", errors.New("test error: cannot get home dir") - } - - path, err := GetMemoryBasePathOrGlobal() - if err == nil { - t.Fatal("expected error when global config dir fails") - } - if path != "" { - t.Errorf("expected empty path on error, got: %s", path) - } -} diff --git a/internal/config/prompts.go b/internal/config/prompts.go index d9256ad..ed309f3 100644 --- a/internal/config/prompts.go +++ b/internal/config/prompts.go @@ -369,7 +369,7 @@ ESSENTIAL complexity (business requirements) from ACCIDENTAL complexity (tech de - Patterns marked "legacy" or "deprecated" in comments **Why This Matters:** -When AI agents recall these patterns, high-debt items will include warnings like: +When AI agents retrieve these patterns, high-debt items will include warnings like: "⚠️ TECHNICAL DEBT: Consider not propagating this pattern." This prevents AI from accidentally spreading tech debt across the codebase. diff --git a/internal/config/writer_test.go b/internal/config/writer_test.go deleted file mode 100644 index afcde4f..0000000 --- a/internal/config/writer_test.go +++ /dev/null @@ -1,393 +0,0 @@ -package config - -import ( - "os" - "path/filepath" - "strings" - "testing" - - "github.com/spf13/viper" -) - -// setupTestConfigDir overrides GetGlobalConfigDir to use a temp directory. -// Returns a cleanup function that restores the original. -func setupTestConfigDir(t *testing.T) (string, func()) { - t.Helper() - - tmpDir := t.TempDir() - original := GetGlobalConfigDir - - GetGlobalConfigDir = func() (string, error) { - return tmpDir, nil - } - - return tmpDir, func() { - GetGlobalConfigDir = original - } -} - -func TestSaveAPIKeyForProvider_Success(t *testing.T) { - tmpDir, cleanup := setupTestConfigDir(t) - defer cleanup() - - // Save a key - err := SaveAPIKeyForProvider("openai", "sk-test-key-12345") - if err != nil { - t.Fatalf("SaveAPIKeyForProvider() error = %v", err) - } - - // Verify the config file was created - configPath := filepath.Join(tmpDir, "config.yaml") - if _, err := os.Stat(configPath); os.IsNotExist(err) { - t.Fatal("config file was not created") - } - - // Read the config and verify the key is stored correctly - v := viper.New() - v.SetConfigFile(configPath) - v.SetConfigType("yaml") - if err := v.ReadInConfig(); err != nil { - t.Fatalf("failed to read config: %v", err) - } - - savedKey := v.GetString("llm.apiKeys.openai") - if savedKey != "sk-test-key-12345" { - t.Errorf("saved key = %q, want %q", savedKey, "sk-test-key-12345") - } -} - -func TestSaveAPIKeyForProvider_EmptyProvider(t *testing.T) { - _, cleanup := setupTestConfigDir(t) - defer cleanup() - - err := SaveAPIKeyForProvider("", "some-key") - if err == nil { - t.Fatal("expected error for empty provider") - } - if !strings.Contains(err.Error(), "provider cannot be empty") { - t.Errorf("error = %q, want to contain %q", err.Error(), "provider cannot be empty") - } -} - -func TestSaveAPIKeyForProvider_EmptyKey(t *testing.T) { - _, cleanup := setupTestConfigDir(t) - defer cleanup() - - err := SaveAPIKeyForProvider("openai", "") - if err == nil { - t.Fatal("expected error for empty key") - } - if !strings.Contains(err.Error(), "API key cannot be empty") { - t.Errorf("error = %q, want to contain %q", err.Error(), "API key cannot be empty") - } -} - -func TestSaveAPIKeyForProvider_MultipleProviders(t *testing.T) { - tmpDir, cleanup := setupTestConfigDir(t) - defer cleanup() - - // Save keys for multiple providers - providers := map[string]string{ - "openai": "sk-openai-key", - "anthropic": "sk-anthropic-key", - "gemini": "gemini-api-key", - } - - for provider, key := range providers { - if err := SaveAPIKeyForProvider(provider, key); err != nil { - t.Fatalf("SaveAPIKeyForProvider(%q) error = %v", provider, err) - } - } - - // Verify all keys are stored - v := viper.New() - v.SetConfigFile(filepath.Join(tmpDir, "config.yaml")) - v.SetConfigType("yaml") - if err := v.ReadInConfig(); err != nil { - t.Fatalf("failed to read config: %v", err) - } - - for provider, expectedKey := range providers { - savedKey := v.GetString("llm.apiKeys." + provider) - if savedKey != expectedKey { - t.Errorf("key for %q = %q, want %q", provider, savedKey, expectedKey) - } - } -} - -func TestSaveAPIKeyForProvider_UpdatesExistingKey(t *testing.T) { - tmpDir, cleanup := setupTestConfigDir(t) - defer cleanup() - - // Save initial key - if err := SaveAPIKeyForProvider("openai", "old-key"); err != nil { - t.Fatalf("initial save error = %v", err) - } - - // Update with new key - if err := SaveAPIKeyForProvider("openai", "new-key"); err != nil { - t.Fatalf("update save error = %v", err) - } - - // Verify the key was updated - v := viper.New() - v.SetConfigFile(filepath.Join(tmpDir, "config.yaml")) - v.SetConfigType("yaml") - if err := v.ReadInConfig(); err != nil { - t.Fatalf("failed to read config: %v", err) - } - - savedKey := v.GetString("llm.apiKeys.openai") - if savedKey != "new-key" { - t.Errorf("saved key = %q, want %q", savedKey, "new-key") - } -} - -func TestSaveAPIKeyForProvider_PreservesOtherConfig(t *testing.T) { - tmpDir, cleanup := setupTestConfigDir(t) - defer cleanup() - - // Create existing config with other settings - configPath := filepath.Join(tmpDir, "config.yaml") - existingConfig := `version: "1" -llm: - provider: gemini - model: gemini-flash -verbose: true -` - if err := os.WriteFile(configPath, []byte(existingConfig), 0600); err != nil { - t.Fatalf("failed to write initial config: %v", err) - } - - // Save API key - if err := SaveAPIKeyForProvider("openai", "sk-test"); err != nil { - t.Fatalf("SaveAPIKeyForProvider() error = %v", err) - } - - // Verify other settings were preserved - v := viper.New() - v.SetConfigFile(configPath) - v.SetConfigType("yaml") - if err := v.ReadInConfig(); err != nil { - t.Fatalf("failed to read config: %v", err) - } - - if v.GetString("llm.provider") != "gemini" { - t.Errorf("provider was modified, got %q", v.GetString("llm.provider")) - } - if v.GetString("llm.model") != "gemini-flash" { - t.Errorf("model was modified, got %q", v.GetString("llm.model")) - } - if !v.GetBool("verbose") { - t.Error("verbose setting was not preserved") - } -} - -func TestSaveBedrockRegion_Success(t *testing.T) { - tmpDir, cleanup := setupTestConfigDir(t) - defer cleanup() - - if err := SaveBedrockRegion("us-east-1"); err != nil { - t.Fatalf("SaveBedrockRegion() error = %v", err) - } - - v := viper.New() - v.SetConfigFile(filepath.Join(tmpDir, "config.yaml")) - v.SetConfigType("yaml") - if err := v.ReadInConfig(); err != nil { - t.Fatalf("failed to read config: %v", err) - } - - got := v.GetString("llm.bedrock.region") - if got != "us-east-1" { - t.Fatalf("llm.bedrock.region = %q, want %q", got, "us-east-1") - } -} - -func TestDeleteAPIKeyForProvider_Success(t *testing.T) { - tmpDir, cleanup := setupTestConfigDir(t) - defer cleanup() - - // First save a key - if err := SaveAPIKeyForProvider("openai", "sk-to-delete"); err != nil { - t.Fatalf("SaveAPIKeyForProvider() error = %v", err) - } - - // Delete it - err := DeleteAPIKeyForProvider("openai") - if err != nil { - t.Fatalf("DeleteAPIKeyForProvider() error = %v", err) - } - - // Verify the key is gone - v := viper.New() - v.SetConfigFile(filepath.Join(tmpDir, "config.yaml")) - v.SetConfigType("yaml") - if err := v.ReadInConfig(); err != nil { - t.Fatalf("failed to read config: %v", err) - } - - if v.IsSet("llm.apiKeys.openai") { - t.Error("key should have been deleted but still exists") - } -} - -func TestDeleteAPIKeyForProvider_EmptyProvider(t *testing.T) { - _, cleanup := setupTestConfigDir(t) - defer cleanup() - - err := DeleteAPIKeyForProvider("") - if err == nil { - t.Fatal("expected error for empty provider") - } - if !strings.Contains(err.Error(), "provider cannot be empty") { - t.Errorf("error = %q, want to contain %q", err.Error(), "provider cannot be empty") - } -} - -func TestDeleteAPIKeyForProvider_NoConfigFile(t *testing.T) { - _, cleanup := setupTestConfigDir(t) - defer cleanup() - - // No config file exists - should be a no-op, not an error - err := DeleteAPIKeyForProvider("openai") - if err != nil { - t.Fatalf("DeleteAPIKeyForProvider() should not error when no config exists, got: %v", err) - } -} - -func TestDeleteAPIKeyForProvider_KeyNotExists(t *testing.T) { - tmpDir, cleanup := setupTestConfigDir(t) - defer cleanup() - - // Create config without the key we'll try to delete - configPath := filepath.Join(tmpDir, "config.yaml") - existingConfig := `version: "1" -llm: - provider: gemini -` - if err := os.WriteFile(configPath, []byte(existingConfig), 0600); err != nil { - t.Fatalf("failed to write initial config: %v", err) - } - - // Delete non-existent key - should be a no-op - err := DeleteAPIKeyForProvider("openai") - if err != nil { - t.Fatalf("DeleteAPIKeyForProvider() should not error for non-existent key, got: %v", err) - } -} - -func TestDeleteAPIKeyForProvider_PreservesOtherKeys(t *testing.T) { - tmpDir, cleanup := setupTestConfigDir(t) - defer cleanup() - - // Save multiple keys - if err := SaveAPIKeyForProvider("openai", "openai-key"); err != nil { - t.Fatalf("save openai error = %v", err) - } - if err := SaveAPIKeyForProvider("anthropic", "anthropic-key"); err != nil { - t.Fatalf("save anthropic error = %v", err) - } - - // Delete only one - if err := DeleteAPIKeyForProvider("openai"); err != nil { - t.Fatalf("DeleteAPIKeyForProvider() error = %v", err) - } - - // Verify the other key is still there - v := viper.New() - v.SetConfigFile(filepath.Join(tmpDir, "config.yaml")) - v.SetConfigType("yaml") - if err := v.ReadInConfig(); err != nil { - t.Fatalf("failed to read config: %v", err) - } - - if v.IsSet("llm.apiKeys.openai") { - t.Error("openai key should have been deleted") - } - if !v.IsSet("llm.apiKeys.anthropic") { - t.Error("anthropic key should still exist") - } - if v.GetString("llm.apiKeys.anthropic") != "anthropic-key" { - t.Errorf("anthropic key = %q, want %q", v.GetString("llm.apiKeys.anthropic"), "anthropic-key") - } -} - -func TestDeleteAPIKeyForProvider_PreservesOtherConfig(t *testing.T) { - tmpDir, cleanup := setupTestConfigDir(t) - defer cleanup() - - // Create config with other settings and a key - configPath := filepath.Join(tmpDir, "config.yaml") - existingConfig := `version: "1" -llm: - provider: openai - model: gpt-4 - apiKeys: - openai: sk-to-delete -verbose: true -` - if err := os.WriteFile(configPath, []byte(existingConfig), 0600); err != nil { - t.Fatalf("failed to write initial config: %v", err) - } - - // Delete the key - if err := DeleteAPIKeyForProvider("openai"); err != nil { - t.Fatalf("DeleteAPIKeyForProvider() error = %v", err) - } - - // Verify other settings were preserved - v := viper.New() - v.SetConfigFile(configPath) - v.SetConfigType("yaml") - if err := v.ReadInConfig(); err != nil { - t.Fatalf("failed to read config: %v", err) - } - - if v.GetString("llm.provider") != "openai" { - t.Errorf("provider was modified, got %q", v.GetString("llm.provider")) - } - if v.GetString("llm.model") != "gpt-4" { - t.Errorf("model was modified, got %q", v.GetString("llm.model")) - } - if !v.GetBool("verbose") { - t.Error("verbose setting was not preserved") - } -} - -// TestAPIKeyNotInMemoryDB verifies that API key functions don't touch memory.db. -// This is a security constraint - API keys must ONLY be stored in user config files. -func TestAPIKeyNotInMemoryDB(t *testing.T) { - tmpDir, cleanup := setupTestConfigDir(t) - defer cleanup() - - // Create a fake memory.db file to verify it's not touched - memoryDir := filepath.Join(tmpDir, ".taskwing", "memory") - if err := os.MkdirAll(memoryDir, 0755); err != nil { - t.Fatalf("failed to create memory dir: %v", err) - } - - memoryDBPath := filepath.Join(memoryDir, "memory.db") - originalContent := []byte("fake db content - should not be modified") - if err := os.WriteFile(memoryDBPath, originalContent, 0600); err != nil { - t.Fatalf("failed to create fake memory.db: %v", err) - } - - // Perform API key operations - if err := SaveAPIKeyForProvider("openai", "test-key"); err != nil { - t.Fatalf("SaveAPIKeyForProvider() error = %v", err) - } - if err := DeleteAPIKeyForProvider("openai"); err != nil { - t.Fatalf("DeleteAPIKeyForProvider() error = %v", err) - } - - // Verify memory.db was not modified - content, err := os.ReadFile(memoryDBPath) - if err != nil { - t.Fatalf("failed to read memory.db: %v", err) - } - - if string(content) != string(originalContent) { - t.Error("memory.db was modified by API key operations - SECURITY VIOLATION") - } -} diff --git a/internal/git/workflow_test.go b/internal/git/workflow_test.go deleted file mode 100644 index d4f84f1..0000000 --- a/internal/git/workflow_test.go +++ /dev/null @@ -1,133 +0,0 @@ -package git - -import ( - "fmt" - "testing" -) - -func TestGenerateBranchName(t *testing.T) { - tests := []struct { - name string - planID string - planTitle string - want string - }{ - { - name: "standard title", - planID: "plan-abc12345", - planTitle: "Add OAuth2 authentication", - want: "feat/add-oauth2-authentication-c12345", - }, - { - name: "empty title falls back to ID", - planID: "plan-xyz789", - planTitle: "", - want: "feat/xyz789", - }, - { - name: "short plan ID", - planID: "abc", - planTitle: "Fix bug", - want: "feat/fix-bug-abc", - }, - { - name: "different plans same title get different branches", - planID: "plan-111111", - planTitle: "Add authentication", - want: "feat/add-authentication-111111", - }, - { - name: "different plans same title - second plan", - planID: "plan-222222", - planTitle: "Add authentication", - want: "feat/add-authentication-222222", - }, - { - name: "long title gets truncated", - planID: "plan-abc123", - planTitle: "This is a very long plan title that exceeds the maximum allowed length for branch names", - want: "feat/this-is-a-very-long-plan-title-that-exceeds-abc123", - }, - { - name: "special characters removed", - planID: "plan-test99", - planTitle: "Fix bug #123: User can't login!", - want: "feat/fix-bug-123-user-cant-login-test99", - }, - { - name: "underscores converted to hyphens", - planID: "plan-under1", - planTitle: "add_new_feature", - want: "feat/add-new-feature-under1", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := GenerateBranchName(tt.planID, tt.planTitle) - if got != tt.want { - t.Errorf("GenerateBranchName(%q, %q) = %q, want %q", - tt.planID, tt.planTitle, got, tt.want) - } - }) - } -} - -func TestGenerateBranchName_Uniqueness(t *testing.T) { - // Critical test: two plans with identical titles MUST produce different branch names - branch1 := GenerateBranchName("plan-aaaaaa", "Add authentication") - branch2 := GenerateBranchName("plan-bbbbbb", "Add authentication") - - if branch1 == branch2 { - t.Errorf("CRITICAL: Two different plans produced identical branch names: %q", branch1) - } -} - -func TestUnrelatedBranchError(t *testing.T) { - err := &UnrelatedBranchError{ - CurrentBranch: "feat/other-work", - ExpectedBranch: "feat/add-auth-abc123", - } - - // Test error message - expected := `currently on branch "feat/other-work" which is unrelated to plan branch "feat/add-auth-abc123"` - if err.Error() != expected { - t.Errorf("Error() = %q, want %q", err.Error(), expected) - } - - // Test type detection - if !IsUnrelatedBranchError(err) { - t.Error("IsUnrelatedBranchError() should return true for UnrelatedBranchError") - } - - // Test non-matching error - otherErr := fmt.Errorf("some other error") - if IsUnrelatedBranchError(otherErr) { - t.Error("IsUnrelatedBranchError() should return false for other errors") - } -} - -func TestSlugify(t *testing.T) { - tests := []struct { - input string - want string - }{ - {"Hello World", "hello-world"}, - {"UPPERCASE", "uppercase"}, - {"under_score", "under-score"}, - {"multiple spaces", "multiple-spaces"}, - {"special!@#$chars", "specialchars"}, - {"--leading-trailing--", "leading-trailing"}, - {"café résumé", "caf-rsum"}, - {"", ""}, - } - - for _, tt := range tests { - t.Run(tt.input, func(t *testing.T) { - got := Slugify(tt.input) - if got != tt.want { - t.Errorf("Slugify(%q) = %q, want %q", tt.input, got, tt.want) - } - }) - } -} diff --git a/internal/knowledge/debug_test.go b/internal/knowledge/debug_test.go deleted file mode 100644 index 138a359..0000000 --- a/internal/knowledge/debug_test.go +++ /dev/null @@ -1,752 +0,0 @@ -package knowledge - -import ( - "context" - "fmt" - "testing" - - "github.com/josephgoksu/TaskWing/internal/llm" - "github.com/josephgoksu/TaskWing/internal/memory" -) - -// MockRepository implements Repository for testing -type MockRepository struct { - nodes []memory.Node - nodeByID map[string]*memory.Node - ftsResult []memory.FTSResult -} - -func NewMockRepository() *MockRepository { - return &MockRepository{ - nodeByID: make(map[string]*memory.Node), - } -} - -func (m *MockRepository) AddNode(n memory.Node) { - m.nodes = append(m.nodes, n) - nodeCopy := n - m.nodeByID[n.ID] = &nodeCopy -} - -func (m *MockRepository) SetFTSResults(results []memory.FTSResult) { - m.ftsResult = results -} - -// Repository interface implementations -func (m *MockRepository) ListNodes(_ string) ([]memory.Node, error) { - return m.nodes, nil -} - -func (m *MockRepository) GetNode(id string) (*memory.Node, error) { - if n, ok := m.nodeByID[id]; ok { - return n, nil - } - return nil, nil -} - -func (m *MockRepository) CreateNode(_ *memory.Node) error { return nil } -func (m *MockRepository) UpsertNodeBySummary(_ memory.Node) error { return nil } -func (m *MockRepository) DeleteNodesByAgent(_ string) error { return nil } -func (m *MockRepository) DeleteNodesByFiles(_ string, _ []string) error { return nil } -func (m *MockRepository) GetNodesByFiles(_ string, _ []string) ([]memory.Node, error) { - return nil, nil -} -func (m *MockRepository) LinkNodes(_, _, _ string, _ float64, _ map[string]any) error { - return nil -} -func (m *MockRepository) GetNodeEdges(_ string) ([]memory.NodeEdge, error) { return nil, nil } -func (m *MockRepository) ListNodesWithEmbeddings() ([]memory.Node, error) { - return m.nodes, nil -} -func (m *MockRepository) SearchFTS(_ string, _ int) ([]memory.FTSResult, error) { - return m.ftsResult, nil -} -func (m *MockRepository) GetEmbeddingStats() (*memory.EmbeddingStats, error) { return nil, nil } -func (m *MockRepository) GetProjectOverview() (*memory.ProjectOverview, error) { - return nil, nil -} - -// Workspace-filtered methods for monorepo support -func (m *MockRepository) ListNodesFiltered(filter memory.NodeFilter) ([]memory.Node, error) { - if filter.Workspace == "" { - return m.nodes, nil - } - var filtered []memory.Node - for _, n := range m.nodes { - if n.Workspace == filter.Workspace { - filtered = append(filtered, n) - } else if filter.IncludeRoot && n.Workspace == "root" { - filtered = append(filtered, n) - } - } - return filtered, nil -} - -func (m *MockRepository) ListNodesWithEmbeddingsFiltered(filter memory.NodeFilter) ([]memory.Node, error) { - return m.ListNodesFiltered(filter) -} - -func (m *MockRepository) SearchFTSFiltered(_ string, _ int, filter memory.NodeFilter) ([]memory.FTSResult, error) { - if filter.Workspace == "" { - return m.ftsResult, nil - } - var filtered []memory.FTSResult - for _, r := range m.ftsResult { - if r.Node.Workspace == filter.Workspace { - filtered = append(filtered, r) - } else if filter.IncludeRoot && r.Node.Workspace == "root" { - filtered = append(filtered, r) - } - } - return filtered, nil -} - -func TestDebugRetrieval_ExactIDMatch(t *testing.T) { - repo := NewMockRepository() - - // Add a task node - taskNode := memory.Node{ - ID: "task-abc123", - Type: "task", - Summary: "Test Task", - Content: "This is a test task", - } - repo.AddNode(taskNode) - - svc := NewService(repo, llm.Config{}) - - // Test exact ID match - result, err := svc.SearchDebug(context.Background(), "task-abc123", 10) - if err != nil { - t.Fatalf("SearchDebug failed: %v", err) - } - - if len(result.Results) == 0 { - t.Fatal("Expected at least one result for exact ID match") - } - - // First result should be the exact match - first := result.Results[0] - if first.ID != "task-abc123" { - t.Errorf("Expected ID task-abc123, got %s", first.ID) - } - if !first.IsExactMatch { - t.Error("Expected IsExactMatch to be true") - } - if first.CombinedScore != 1.0 { - t.Errorf("Expected CombinedScore 1.0 for exact match, got %f", first.CombinedScore) - } - - // Check pipeline includes ExactMatch - hasExactMatch := false - for _, stage := range result.Pipeline { - if stage == "ExactMatch" { - hasExactMatch = true - break - } - } - if !hasExactMatch { - t.Error("Expected Pipeline to include ExactMatch") - } -} - -func TestDebugRetrieval_FTSMatch(t *testing.T) { - repo := NewMockRepository() - - // Add a node - node := memory.Node{ - ID: "n-test1", - Type: "decision", - Summary: "Authentication Decision", - Content: "We use JWT for authentication", - } - repo.AddNode(node) - - // Setup FTS results - repo.SetFTSResults([]memory.FTSResult{ - {Node: node, Rank: -5.0}, // BM25 rank (negative, more negative = better) - }) - - svc := NewService(repo, llm.Config{}) - - result, err := svc.SearchDebug(context.Background(), "authentication", 10) - if err != nil { - t.Fatalf("SearchDebug failed: %v", err) - } - - if len(result.Results) == 0 { - t.Fatal("Expected at least one result") - } - - first := result.Results[0] - if first.FTSScore == 0 { - t.Error("Expected non-zero FTSScore") - } - - // Check pipeline includes FTS - hasFTS := false - for _, stage := range result.Pipeline { - if stage == "FTS" { - hasFTS = true - break - } - } - if !hasFTS { - t.Error("Expected Pipeline to include FTS") - } -} - -func TestDebugRetrieval_ResponseStructure(t *testing.T) { - repo := NewMockRepository() - svc := NewService(repo, llm.Config{}) - - result, err := svc.SearchDebug(context.Background(), "test query", 10) - if err != nil { - t.Fatalf("SearchDebug failed: %v", err) - } - - // Verify response structure - if result.Query != "test query" { - t.Errorf("Expected Query 'test query', got '%s'", result.Query) - } - - if result.Timings == nil { - t.Error("Expected Timings to be initialized") - } - - // Timings should have entries for each stage - expectedTimings := []string{"exact_match", "fts", "vector", "rerank", "graph"} - for _, key := range expectedTimings { - if _, ok := result.Timings[key]; !ok { - t.Errorf("Expected Timings to have key '%s'", key) - } - } -} - -func TestDebugRetrieval_PlanIDMatch(t *testing.T) { - repo := NewMockRepository() - - // Add a plan node - planNode := memory.Node{ - ID: "plan-xyz789", - Type: "plan", - Summary: "Implementation Plan", - Content: "Plan details here", - } - repo.AddNode(planNode) - - svc := NewService(repo, llm.Config{}) - - result, err := svc.SearchDebug(context.Background(), "plan-xyz789", 10) - if err != nil { - t.Fatalf("SearchDebug failed: %v", err) - } - - if len(result.Results) == 0 { - t.Fatal("Expected at least one result for plan ID match") - } - - first := result.Results[0] - if first.ID != "plan-xyz789" { - t.Errorf("Expected ID plan-xyz789, got %s", first.ID) - } - if !first.IsExactMatch { - t.Error("Expected IsExactMatch to be true for plan ID") - } -} - -func TestDebugRetrievalResult_Fields(t *testing.T) { - // Test that DebugRetrievalResult has all expected fields - result := DebugRetrievalResult{ - ID: "test-id", - ChunkID: "test-chunk", - NodeType: "decision", - SourceFilePath: "/path/to/file.go", - SourceAgent: "test-agent", - Summary: "Test Summary", - Content: "Test Content", - FTSScore: 0.5, - VectorScore: 0.7, - CombinedScore: 0.6, - RerankScore: 0.8, - IsExactMatch: true, - IsGraphExpanded: false, - EmbeddingDimension: 1536, - } - - if result.ID != "test-id" { - t.Error("ID field not set correctly") - } - if result.VectorScore != 0.7 { - t.Error("VectorScore field not set correctly") - } - if result.EmbeddingDimension != 1536 { - t.Error("EmbeddingDimension field not set correctly") - } -} - -// === Workspace Scoping Tests === - -func TestWorkspaceFiltering_ListNodesFiltered(t *testing.T) { - repo := NewMockRepository() - - // Add nodes in different workspaces - repo.AddNode(memory.Node{ID: "n-root-1", Summary: "Root Decision", Workspace: "root"}) - repo.AddNode(memory.Node{ID: "n-root-2", Summary: "Root Pattern", Workspace: "root"}) - repo.AddNode(memory.Node{ID: "n-osprey-1", Summary: "Osprey Decision", Workspace: "osprey"}) - repo.AddNode(memory.Node{ID: "n-studio-1", Summary: "Studio Decision", Workspace: "studio"}) - - tests := []struct { - name string - filter memory.NodeFilter - wantIDs []string - wantCount int - description string - }{ - { - name: "empty workspace returns all", - filter: memory.NodeFilter{Workspace: ""}, - wantCount: 4, - description: "Empty workspace should return all nodes", - }, - { - name: "workspace only", - filter: memory.NodeFilter{Workspace: "osprey", IncludeRoot: false}, - wantIDs: []string{"n-osprey-1"}, - wantCount: 1, - description: "Should return only osprey nodes", - }, - { - name: "workspace plus root", - filter: memory.NodeFilter{Workspace: "osprey", IncludeRoot: true}, - wantCount: 3, // osprey + 2 root nodes - description: "Should return osprey + root nodes", - }, - { - name: "root workspace only", - filter: memory.NodeFilter{Workspace: "root", IncludeRoot: false}, - wantIDs: []string{"n-root-1", "n-root-2"}, - wantCount: 2, - description: "Should return only root nodes", - }, - { - name: "nonexistent workspace", - filter: memory.NodeFilter{Workspace: "nonexistent", IncludeRoot: false}, - wantCount: 0, - description: "Should return no nodes for nonexistent workspace", - }, - { - name: "nonexistent workspace with root", - filter: memory.NodeFilter{Workspace: "nonexistent", IncludeRoot: true}, - wantCount: 2, // Only root nodes - description: "Should return root nodes even for nonexistent workspace", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - nodes, err := repo.ListNodesFiltered(tt.filter) - if err != nil { - t.Fatalf("ListNodesFiltered failed: %v", err) - } - - if len(nodes) != tt.wantCount { - t.Errorf("%s: got %d nodes, want %d", tt.description, len(nodes), tt.wantCount) - } - - if len(tt.wantIDs) > 0 { - gotIDs := make(map[string]bool) - for _, n := range nodes { - gotIDs[n.ID] = true - } - for _, wantID := range tt.wantIDs { - if !gotIDs[wantID] { - t.Errorf("Expected node %s not found", wantID) - } - } - } - }) - } -} - -func TestWorkspaceFiltering_SearchFTSFiltered(t *testing.T) { - repo := NewMockRepository() - - // Add nodes for FTS - rootNode := memory.Node{ID: "n-root", Summary: "Auth Decision", Workspace: "root"} - ospreyNode := memory.Node{ID: "n-osprey", Summary: "Auth Pattern", Workspace: "osprey"} - repo.AddNode(rootNode) - repo.AddNode(ospreyNode) - - // Setup FTS results - repo.SetFTSResults([]memory.FTSResult{ - {Node: rootNode, Rank: -5.0}, - {Node: ospreyNode, Rank: -4.0}, - }) - - // Test: No filter returns all - results, err := repo.SearchFTSFiltered("auth", 10, memory.NodeFilter{}) - if err != nil { - t.Fatalf("SearchFTSFiltered failed: %v", err) - } - if len(results) != 2 { - t.Errorf("No filter: expected 2 results, got %d", len(results)) - } - - // Test: Workspace filter with IncludeRoot=true - results, err = repo.SearchFTSFiltered("auth", 10, memory.NodeFilter{Workspace: "osprey", IncludeRoot: true}) - if err != nil { - t.Fatalf("SearchFTSFiltered failed: %v", err) - } - if len(results) != 2 { - t.Errorf("Osprey+root: expected 2 results, got %d", len(results)) - } - - // Test: Workspace filter with IncludeRoot=false - results, err = repo.SearchFTSFiltered("auth", 10, memory.NodeFilter{Workspace: "osprey", IncludeRoot: false}) - if err != nil { - t.Fatalf("SearchFTSFiltered failed: %v", err) - } - if len(results) != 1 { - t.Errorf("Osprey only: expected 1 result, got %d", len(results)) - } - if len(results) > 0 && results[0].Node.ID != "n-osprey" { - t.Errorf("Expected osprey node, got %s", results[0].Node.ID) - } -} - -func TestNodeFilter_DefaultValues(t *testing.T) { - filter := memory.DefaultNodeFilter() - - if filter.Type != "" { - t.Errorf("Type = %q, want empty", filter.Type) - } - if filter.Workspace != "" { - t.Errorf("Workspace = %q, want empty", filter.Workspace) - } - if !filter.IncludeRoot { - t.Error("IncludeRoot = false, want true") - } -} - -// === SearchWithFilter Tests (Service-level workspace scoping) === - -func TestSearchWithFilter_EmptyWorkspaceReturnsAll(t *testing.T) { - repo := NewMockRepository() - - // Add nodes in different workspaces - repo.AddNode(memory.Node{ID: "n-root-1", Summary: "Root Decision", Workspace: "root", Type: "decision"}) - repo.AddNode(memory.Node{ID: "n-osprey-1", Summary: "Osprey Pattern", Workspace: "osprey", Type: "pattern"}) - repo.AddNode(memory.Node{ID: "n-studio-1", Summary: "Studio Constraint", Workspace: "studio", Type: "constraint"}) - - // Setup FTS results for all nodes - repo.SetFTSResults([]memory.FTSResult{ - {Node: memory.Node{ID: "n-root-1", Summary: "Root Decision", Workspace: "root", Type: "decision"}, Rank: -5.0}, - {Node: memory.Node{ID: "n-osprey-1", Summary: "Osprey Pattern", Workspace: "osprey", Type: "pattern"}, Rank: -4.0}, - {Node: memory.Node{ID: "n-studio-1", Summary: "Studio Constraint", Workspace: "studio", Type: "constraint"}, Rank: -3.0}, - }) - - svc := NewService(repo, llm.Config{}) - - // Empty workspace filter should return all nodes - results, err := svc.SearchWithFilter(context.Background(), "test", 10, memory.NodeFilter{ - Workspace: "", - }) - if err != nil { - t.Fatalf("SearchWithFilter failed: %v", err) - } - - if len(results) != 3 { - t.Errorf("Expected 3 results (all nodes), got %d", len(results)) - } -} - -func TestSearchWithFilter_WorkspaceWithIncludeRoot(t *testing.T) { - repo := NewMockRepository() - - // Add nodes: 2 root, 1 osprey, 1 studio - rootNode1 := memory.Node{ID: "n-root-1", Summary: "Root Decision 1", Workspace: "root", Type: "decision"} - rootNode2 := memory.Node{ID: "n-root-2", Summary: "Root Decision 2", Workspace: "root", Type: "decision"} - ospreyNode := memory.Node{ID: "n-osprey-1", Summary: "Osprey Pattern", Workspace: "osprey", Type: "pattern"} - studioNode := memory.Node{ID: "n-studio-1", Summary: "Studio Constraint", Workspace: "studio", Type: "constraint"} - - repo.AddNode(rootNode1) - repo.AddNode(rootNode2) - repo.AddNode(ospreyNode) - repo.AddNode(studioNode) - - // Setup FTS results - repo.SetFTSResults([]memory.FTSResult{ - {Node: rootNode1, Rank: -5.0}, - {Node: rootNode2, Rank: -4.5}, - {Node: ospreyNode, Rank: -4.0}, - {Node: studioNode, Rank: -3.0}, - }) - - svc := NewService(repo, llm.Config{}) - - // Search for osprey workspace WITH IncludeRoot - results, err := svc.SearchWithFilter(context.Background(), "test", 10, memory.NodeFilter{ - Workspace: "osprey", - IncludeRoot: true, - }) - if err != nil { - t.Fatalf("SearchWithFilter failed: %v", err) - } - - // Should return osprey + root nodes (3 total), not studio - if len(results) != 3 { - t.Errorf("Expected 3 results (osprey + root), got %d", len(results)) - } - - // Verify no studio nodes - for _, r := range results { - if r.Node.Workspace == "studio" { - t.Errorf("Should not include studio node, but got %s", r.Node.ID) - } - } -} - -func TestSearchWithFilter_WorkspaceWithoutRoot(t *testing.T) { - repo := NewMockRepository() - - // Add nodes - rootNode := memory.Node{ID: "n-root-1", Summary: "Root Decision", Workspace: "root", Type: "decision"} - ospreyNode := memory.Node{ID: "n-osprey-1", Summary: "Osprey Pattern", Workspace: "osprey", Type: "pattern"} - - repo.AddNode(rootNode) - repo.AddNode(ospreyNode) - - // Setup FTS results - repo.SetFTSResults([]memory.FTSResult{ - {Node: rootNode, Rank: -5.0}, - {Node: ospreyNode, Rank: -4.0}, - }) - - svc := NewService(repo, llm.Config{}) - - // Search for osprey workspace WITHOUT IncludeRoot - results, err := svc.SearchWithFilter(context.Background(), "test", 10, memory.NodeFilter{ - Workspace: "osprey", - IncludeRoot: false, - }) - if err != nil { - t.Fatalf("SearchWithFilter failed: %v", err) - } - - // Should return only osprey nodes - if len(results) != 1 { - t.Errorf("Expected 1 result (osprey only), got %d", len(results)) - } - - if len(results) > 0 && results[0].Node.Workspace != "osprey" { - t.Errorf("Expected osprey node, got workspace %s", results[0].Node.Workspace) - } -} - -func TestSearchWithFilter_LimitRespected(t *testing.T) { - repo := NewMockRepository() - - // Add 10 nodes in osprey workspace - var ftsResults []memory.FTSResult - for i := 0; i < 10; i++ { - node := memory.Node{ - ID: fmt.Sprintf("n-osprey-%d", i), - Summary: fmt.Sprintf("Osprey Decision %d", i), - Workspace: "osprey", - Type: "decision", - } - repo.AddNode(node) - ftsResults = append(ftsResults, memory.FTSResult{Node: node, Rank: float64(-10 + i)}) - } - repo.SetFTSResults(ftsResults) - - svc := NewService(repo, llm.Config{}) - - // Request only 3 results - results, err := svc.SearchWithFilter(context.Background(), "decision", 3, memory.NodeFilter{ - Workspace: "osprey", - IncludeRoot: false, - }) - if err != nil { - t.Fatalf("SearchWithFilter failed: %v", err) - } - - if len(results) > 3 { - t.Errorf("Expected at most 3 results, got %d", len(results)) - } -} - -// === matchesWorkspaceFilter Tests === - -func TestMatchesWorkspaceFilter(t *testing.T) { - tests := []struct { - name string - nodeWorkspace string - filter memory.NodeFilter - want bool - }{ - { - name: "exact match", - nodeWorkspace: "osprey", - filter: memory.NodeFilter{Workspace: "osprey"}, - want: true, - }, - { - name: "no match different workspace", - nodeWorkspace: "studio", - filter: memory.NodeFilter{Workspace: "osprey"}, - want: false, - }, - { - name: "root node with IncludeRoot true", - nodeWorkspace: "root", - filter: memory.NodeFilter{Workspace: "osprey", IncludeRoot: true}, - want: true, - }, - { - name: "root node with IncludeRoot false", - nodeWorkspace: "root", - filter: memory.NodeFilter{Workspace: "osprey", IncludeRoot: false}, - want: false, - }, - { - name: "empty workspace treated as root with IncludeRoot true", - nodeWorkspace: "", - filter: memory.NodeFilter{Workspace: "osprey", IncludeRoot: true}, - want: true, - }, - { - name: "empty workspace treated as root with IncludeRoot false", - nodeWorkspace: "", - filter: memory.NodeFilter{Workspace: "osprey", IncludeRoot: false}, - want: false, - }, - { - name: "filter for root workspace matches root", - nodeWorkspace: "root", - filter: memory.NodeFilter{Workspace: "root"}, - want: true, - }, - { - name: "empty node matches root filter", - nodeWorkspace: "", - filter: memory.NodeFilter{Workspace: "root"}, - want: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := matchesWorkspaceFilter(tt.nodeWorkspace, tt.filter) - if got != tt.want { - t.Errorf("matchesWorkspaceFilter(%q, %+v) = %v, want %v", - tt.nodeWorkspace, tt.filter, got, tt.want) - } - }) - } -} - -// === Recall Workspace Scoping Integration Tests === - -func TestRecallWorkspaceScoping_MonorepoScenario(t *testing.T) { - // Simulates a monorepo with services: api, web, common - // When working in "api" subdirectory, should see: - // - api-specific knowledge - // - global/root knowledge (if IncludeRoot=true) - // - NOT web-specific knowledge - - repo := NewMockRepository() - - // Global decisions - globalAuth := memory.Node{ID: "n-global-auth", Summary: "JWT Authentication", Workspace: "root", Type: "decision"} - globalDB := memory.Node{ID: "n-global-db", Summary: "PostgreSQL Database", Workspace: "root", Type: "decision"} - - // API-specific - apiPattern := memory.Node{ID: "n-api-rest", Summary: "REST API Pattern", Workspace: "api", Type: "pattern"} - apiConstraint := memory.Node{ID: "n-api-rate", Summary: "Rate Limiting", Workspace: "api", Type: "constraint"} - - // Web-specific - webPattern := memory.Node{ID: "n-web-react", Summary: "React Components", Workspace: "web", Type: "pattern"} - - // Common-specific - commonUtil := memory.Node{ID: "n-common-utils", Summary: "Shared Utilities", Workspace: "common", Type: "pattern"} - - for _, n := range []memory.Node{globalAuth, globalDB, apiPattern, apiConstraint, webPattern, commonUtil} { - repo.AddNode(n) - } - - // Setup FTS results - repo.SetFTSResults([]memory.FTSResult{ - {Node: globalAuth, Rank: -8.0}, - {Node: globalDB, Rank: -7.0}, - {Node: apiPattern, Rank: -6.0}, - {Node: apiConstraint, Rank: -5.0}, - {Node: webPattern, Rank: -4.0}, - {Node: commonUtil, Rank: -3.0}, - }) - - svc := NewService(repo, llm.Config{}) - - // Scenario 1: Working in api directory with IncludeRoot=true - t.Run("api workspace with root", func(t *testing.T) { - results, err := svc.SearchWithFilter(context.Background(), "pattern", 10, memory.NodeFilter{ - Workspace: "api", - IncludeRoot: true, - }) - if err != nil { - t.Fatalf("SearchWithFilter failed: %v", err) - } - - // Should have: api nodes + root nodes - workspaces := make(map[string]int) - for _, r := range results { - workspaces[r.Node.Workspace]++ - } - - if workspaces["web"] > 0 { - t.Error("Should NOT include web-specific nodes when in api workspace") - } - if workspaces["common"] > 0 { - t.Error("Should NOT include common-specific nodes when in api workspace") - } - if workspaces["api"] == 0 { - t.Error("Should include api-specific nodes") - } - if workspaces["root"] == 0 { - t.Error("Should include root nodes when IncludeRoot=true") - } - }) - - // Scenario 2: Working in api directory with IncludeRoot=false - t.Run("api workspace without root", func(t *testing.T) { - results, err := svc.SearchWithFilter(context.Background(), "pattern", 10, memory.NodeFilter{ - Workspace: "api", - IncludeRoot: false, - }) - if err != nil { - t.Fatalf("SearchWithFilter failed: %v", err) - } - - for _, r := range results { - if r.Node.Workspace != "api" { - t.Errorf("Should only include api nodes, got workspace %s", r.Node.Workspace) - } - } - }) - - // Scenario 3: Working at monorepo root (empty workspace = all) - t.Run("root directory sees all", func(t *testing.T) { - results, err := svc.SearchWithFilter(context.Background(), "pattern", 10, memory.NodeFilter{ - Workspace: "", // Empty = no filtering - }) - if err != nil { - t.Fatalf("SearchWithFilter failed: %v", err) - } - - // Should see all nodes from all workspaces - if len(results) < 6 { - t.Errorf("Expected all 6 nodes when at root, got %d", len(results)) - } - }) -} diff --git a/internal/knowledge/formatter.go b/internal/knowledge/formatter.go index c69e313..d5622c4 100644 --- a/internal/knowledge/formatter.go +++ b/internal/knowledge/formatter.go @@ -1,4 +1,4 @@ -// Package knowledge provides compact Markdown formatting for recall results. +// Package knowledge provides compact Markdown formatting for ask results. // This formatter produces token-efficient output by grouping nodes by type // and removing all JSON metadata and embedding data. package knowledge @@ -13,7 +13,7 @@ import ( "golang.org/x/text/language" ) -// CompactFormatter produces condensed Markdown output for recall results. +// CompactFormatter produces condensed Markdown output for ask results. // It groups nodes by type and strips all unnecessary metadata to minimize tokens. type CompactFormatter struct { // MaxContentLen limits content preview length per node (default: 120) diff --git a/internal/knowledge/formatter_test.go b/internal/knowledge/formatter_test.go deleted file mode 100644 index b84f159..0000000 --- a/internal/knowledge/formatter_test.go +++ /dev/null @@ -1,324 +0,0 @@ -package knowledge - -import ( - "strings" - "testing" -) - -func TestCompactFormatter_FormatNodes_GroupsByType(t *testing.T) { - formatter := DefaultCompactFormatter() - - nodes := []NodeResponse{ - {ID: "1", Type: "decision", Summary: "Use SQLite", Content: "SQLite for local storage", MatchScore: 0.9}, - {ID: "2", Type: "pattern", Summary: "Repository Pattern", Content: "Data access layer pattern", MatchScore: 0.85}, - {ID: "3", Type: "decision", Summary: "Use Go", Content: "Golang for CLI", MatchScore: 0.8}, - {ID: "4", Type: "constraint", Summary: "No CGO", Content: "CGO-free for portability", MatchScore: 0.75}, - } - - result := formatter.FormatNodes(nodes) - - // Check that types are grouped - if !strings.Contains(result, "### 📋 Decisions") { - t.Error("Expected Decisions header") - } - if !strings.Contains(result, "### 🧩 Patterns") { - t.Error("Expected Patterns header") - } - if !strings.Contains(result, "### ⚠️ Constraints") { - t.Error("Expected Constraints header") - } - - // Check ordering: Decisions should come before Patterns - decisionsIdx := strings.Index(result, "Decisions") - patternsIdx := strings.Index(result, "Patterns") - constraintsIdx := strings.Index(result, "Constraints") - - if decisionsIdx > patternsIdx { - t.Error("Decisions should come before Patterns") - } - if patternsIdx > constraintsIdx { - t.Error("Patterns should come before Constraints") - } -} - -func TestCompactFormatter_FormatNodes_NoJSONOrEmbeddings(t *testing.T) { - formatter := DefaultCompactFormatter() - - nodes := []NodeResponse{ - { - ID: "1", - Type: "decision", - Summary: "Test Decision", - Content: "Some content here", - MatchScore: 0.9, - ConfidenceScore: 0.95, - // These fields should NOT appear in output - Evidence: []EvidenceRef{{File: "test.go", Lines: "10-20"}}, - }, - } - - result := formatter.FormatNodes(nodes) - - // Should not contain JSON-like patterns - if strings.Contains(result, `"id"`) || strings.Contains(result, `"type"`) { - t.Error("Output should not contain JSON keys") - } - if strings.Contains(result, "0.9") || strings.Contains(result, "0.95") { - t.Error("Output should not contain raw scores") - } - if strings.Contains(result, "[{") || strings.Contains(result, "}]") { - t.Error("Output should not contain JSON array notation") - } -} - -func TestCompactFormatter_FormatNodes_TokenReduction(t *testing.T) { - formatter := DefaultCompactFormatter() - - // Create 5 typical nodes (standard test case) - nodes := []NodeResponse{ - {ID: "n1", Type: "decision", Summary: "Embedded Database", Content: "modernc.org/sqlite - Pure Go implementation of SQLite (CGO-free).", MatchScore: 0.9}, - {ID: "n2", Type: "decision", Summary: "CLI Framework", Content: "Cobra and Viper for command-line interface and configuration.", MatchScore: 0.85}, - {ID: "n3", Type: "pattern", Summary: "Repository Pattern", Content: "Data access abstraction through repository interfaces.", MatchScore: 0.8}, - {ID: "n4", Type: "constraint", Summary: "No External Services", Content: "All storage must be local, no network dependencies.", MatchScore: 0.75}, - {ID: "n5", Type: "feature", Summary: "Semantic Search", Content: "Vector embeddings for natural language queries.", MatchScore: 0.7}, - } - - // Verbose format (simulating JSON-like output) - var verboseLen int - for _, n := range nodes { - // Simulate JSON: {"id":"n1","type":"decision","summary":"...","content":"...","match_score":0.9} - verbose := `{"id":"` + n.ID + `","type":"` + n.Type + `","summary":"` + n.Summary + `","content":"` + n.Content + `","match_score":` + "0.9" + `}` - verboseLen += len(verbose) - } - - // Compact format - result := formatter.FormatNodes(nodes) - compactLen := len(result) - - // Compact should be at least 30% smaller than verbose JSON - reduction := float64(verboseLen-compactLen) / float64(verboseLen) * 100 - if reduction < 30 { - t.Errorf("Expected at least 30%% token reduction, got %.1f%% (verbose: %d, compact: %d)", - reduction, verboseLen, compactLen) - } - - // Log actual reduction for visibility - t.Logf("Token reduction: %.1f%% (verbose: %d chars, compact: %d chars)", reduction, verboseLen, compactLen) -} - -func TestCompactFormatter_FormatWithAnswer(t *testing.T) { - formatter := DefaultCompactFormatter() - - answer := "The codebase uses SQLite for persistence." - nodes := []NodeResponse{ - {ID: "1", Type: "decision", Summary: "SQLite", Content: "Local storage", MatchScore: 0.9}, - } - symbols := []SymbolMatch{ - {Name: "NewSQLiteStore", Kind: "function", Location: "store.go:45"}, - } - - result := formatter.FormatWithAnswer(answer, nodes, symbols) - - // Check sections exist in correct order - answerIdx := strings.Index(result, "## Answer") - decisionsIdx := strings.Index(result, "Decisions") - symbolsIdx := strings.Index(result, "## Symbols") - - if answerIdx == -1 { - t.Error("Expected Answer section") - } - if decisionsIdx == -1 { - t.Error("Expected Decisions section") - } - if symbolsIdx == -1 { - t.Error("Expected Symbols section") - } - - // Order: Answer -> Decisions -> Symbols - if answerIdx > decisionsIdx { - t.Error("Answer should come before Decisions") - } - if decisionsIdx > symbolsIdx { - t.Error("Decisions should come before Symbols") - } - - // Check symbol format - if !strings.Contains(result, "`NewSQLiteStore`") { - t.Error("Expected backtick-wrapped symbol name") - } -} - -func TestCompactFormatter_MaxNodesPerType(t *testing.T) { - formatter := &CompactFormatter{ - MaxContentLen: 100, - MaxNodesPerType: 2, // Limit to 2 - } - - // Create 5 decisions - nodes := []NodeResponse{ - {ID: "1", Type: "decision", Summary: "D1", MatchScore: 0.9}, - {ID: "2", Type: "decision", Summary: "D2", MatchScore: 0.8}, - {ID: "3", Type: "decision", Summary: "D3", MatchScore: 0.7}, - {ID: "4", Type: "decision", Summary: "D4", MatchScore: 0.6}, - {ID: "5", Type: "decision", Summary: "D5", MatchScore: 0.5}, - } - - result := formatter.FormatNodes(nodes) - - // Should only show top 2 by score - if !strings.Contains(result, "D1") || !strings.Contains(result, "D2") { - t.Error("Expected top 2 decisions (D1, D2)") - } - if strings.Contains(result, "D3") || strings.Contains(result, "D4") || strings.Contains(result, "D5") { - t.Error("Should not include decisions beyond MaxNodesPerType limit") - } -} - -func TestCleanContentPreview(t *testing.T) { - tests := []struct { - name string - content string - summary string - maxLen int - expected string - }{ - { - name: "removes summary prefix", - content: "Use SQLite: For local persistence", - summary: "Use SQLite", - maxLen: 100, - expected: "For local persistence", - }, - { - name: "truncates long content", - content: "This is a very long content that should be truncated because it exceeds the maximum length", - summary: "Short", - maxLen: 30, - expected: "This is a very long content th...", - }, - { - name: "removes newlines", - content: "Line one\nLine two\nLine three", - summary: "Different", - maxLen: 100, - expected: "Line one Line two Line three", - }, - { - name: "returns empty for same content and summary", - content: "Same text", - summary: "Same text", - maxLen: 100, - expected: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := cleanContentPreview(tt.content, tt.summary, tt.maxLen) - if result != tt.expected { - t.Errorf("got %q, want %q", result, tt.expected) - } - }) - } -} - -func TestNormalizeType(t *testing.T) { - tests := []struct { - input string - expected string - }{ - {"decision", "decision"}, - {"decisions", "decision"}, - {"Decision", "decision"}, - {"DECISION", "decision"}, - {"architectural_decision", "decision"}, - {"pattern", "pattern"}, - {"patterns", "pattern"}, - {"constraint", "constraint"}, - {"feature", "feature"}, - {"docs", "documentation"}, - {"doc", "documentation"}, - {"unknown_type", "unknown_type"}, - } - - for _, tt := range tests { - t.Run(tt.input, func(t *testing.T) { - result := normalizeType(tt.input) - if result != tt.expected { - t.Errorf("normalizeType(%q) = %q, want %q", tt.input, result, tt.expected) - } - }) - } -} - -func TestTokenEstimate(t *testing.T) { - // ~4 chars per token heuristic - text := "This is a test string for token estimation" - estimate := TokenEstimate(text) - - // 43 chars / 4 = ~10 tokens - if estimate < 10 || estimate > 12 { - t.Errorf("TokenEstimate returned %d, expected ~10-11", estimate) - } -} - -func TestFormatNodes_EmptyInput(t *testing.T) { - formatter := DefaultCompactFormatter() - - result := formatter.FormatNodes(nil) - if result != "No results found." { - t.Errorf("Expected 'No results found.', got %q", result) - } - - result = formatter.FormatNodes([]NodeResponse{}) - if result != "No results found." { - t.Errorf("Expected 'No results found.', got %q", result) - } -} - -func TestCompactFormatter_ShowEvidence(t *testing.T) { - formatter := &CompactFormatter{ - MaxContentLen: 100, - MaxNodesPerType: 5, - ShowEvidence: true, // Enable evidence - } - - nodes := []NodeResponse{ - { - ID: "1", - Type: "decision", - Summary: "Test", - Content: "Content", - Evidence: []EvidenceRef{{File: "main.go", Lines: "10-20"}}, - }, - } - - result := formatter.FormatNodes(nodes) - - if !strings.Contains(result, "main.go:10-20") { - t.Error("Expected evidence reference when ShowEvidence is true") - } -} - -func TestCompactFormatter_DebtWarning(t *testing.T) { - formatter := DefaultCompactFormatter() - - nodes := []NodeResponse{ - { - ID: "1", - Type: "pattern", - Summary: "Legacy Pattern", - Content: "Old way of doing things", - DebtWarning: "TECH DEBT: Consider refactoring", - }, - } - - result := formatter.FormatNodes(nodes) - - if !strings.Contains(result, "⚠️") { - t.Error("Expected debt warning icon") - } - if !strings.Contains(result, "TECH DEBT") { - t.Error("Expected debt warning text") - } -} diff --git a/internal/knowledge/ingest_test.go b/internal/knowledge/ingest_test.go deleted file mode 100644 index 87a61de..0000000 --- a/internal/knowledge/ingest_test.go +++ /dev/null @@ -1,436 +0,0 @@ -package knowledge - -import ( - "context" - "os" - "testing" - - "github.com/josephgoksu/TaskWing/internal/agents/core" - "github.com/josephgoksu/TaskWing/internal/llm" - "github.com/josephgoksu/TaskWing/internal/memory" -) - -// ============================================================================= -// Ingestion Tests -// ============================================================================= - -// TestService_IngestFindings_BasicFinding tests basic finding ingestion. -func TestService_IngestFindings_BasicFinding(t *testing.T) { - // Create temp directory for repository - tmpDir, err := os.MkdirTemp("", "taskwing-ingest-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - // Initialize real repository using NewDefaultRepository - repo, err := memory.NewDefaultRepository(tmpDir) - if err != nil { - t.Fatalf("failed to create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - // Create service with empty LLM config (embeddings disabled) - svc := NewService(repo, llm.Config{}) - - // Create a basic finding (simulating what bootstrap would produce) - findings := []core.Finding{ - { - Type: memory.NodeTypeDecision, - Title: "Test Decision", - Description: "This is a test decision for ingestion", - SourceAgent: "test-agent", - Metadata: map[string]any{ - "source": "test", - }, - }, - } - - // Ingest the finding - err = svc.IngestFindings(context.Background(), findings, nil, false) - if err != nil { - t.Fatalf("IngestFindings failed: %v", err) - } - - // Verify the node was created - nodes, err := repo.ListNodes("") - if err != nil { - t.Fatalf("ListNodes failed: %v", err) - } - - if len(nodes) == 0 { - t.Error("Expected at least one node after ingestion") - } - - // Verify node content - found := false - for _, n := range nodes { - if n.Summary == "Test Decision" && n.Type == memory.NodeTypeDecision { - found = true - if n.SourceAgent != "test-agent" { - t.Errorf("SourceAgent = %q, want %q", n.SourceAgent, "test-agent") - } - break - } - } - if !found { - t.Error("Expected to find the ingested decision node") - } -} - -// TestService_IngestFindings_OpenCodeSkillMetadata tests ingestion of a finding -// that could come from an OpenCode skill analysis. -func TestService_IngestFindings_OpenCodeSkillMetadata(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-opencode-ingest-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - repo, err := memory.NewDefaultRepository(tmpDir) - if err != nil { - t.Fatalf("failed to create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - svc := NewService(repo, llm.Config{}) - - // Simulate a finding from OpenCode skill analysis - // This tests that skill-related metadata can be properly ingested - findings := []core.Finding{ - { - Type: memory.NodeTypePattern, - Title: "OpenCode Skills Pattern", - Description: "OpenCode uses skills in .opencode/skills//SKILL.md format with YAML frontmatter", - SourceAgent: "doc-agent", - Metadata: map[string]any{ - "source": "opencode", - "skill_dir": ".opencode/skills/", - "format": "yaml-frontmatter", - }, - }, - { - Type: memory.NodeTypeConstraint, - Title: "OpenCode Skill Name Validation", - Description: "Skill names must match regex: ^[a-z0-9]+(-[a-z0-9]+)*$", - SourceAgent: "doc-agent", - Metadata: map[string]any{ - "source": "opencode", - "pattern": "^[a-z0-9]+(-[a-z0-9]+)*$", - }, - }, - } - - err = svc.IngestFindings(context.Background(), findings, nil, false) - if err != nil { - t.Fatalf("IngestFindings failed: %v", err) - } - - // Verify both nodes were created - nodes, err := repo.ListNodes("") - if err != nil { - t.Fatalf("ListNodes failed: %v", err) - } - - if len(nodes) < 2 { - t.Errorf("Expected at least 2 nodes, got %d", len(nodes)) - } - - // Check for pattern node - patternFound := false - constraintFound := false - for _, n := range nodes { - if n.Type == memory.NodeTypePattern && n.Summary == "OpenCode Skills Pattern" { - patternFound = true - } - if n.Type == memory.NodeTypeConstraint && n.Summary == "OpenCode Skill Name Validation" { - constraintFound = true - } - } - - if !patternFound { - t.Error("Pattern node not found") - } - if !constraintFound { - t.Error("Constraint node not found") - } -} - -// TestService_IngestFindings_EmptyFindings tests that empty findings is a no-op. -func TestService_IngestFindings_EmptyFindings(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-empty-ingest-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - repo, err := memory.NewDefaultRepository(tmpDir) - if err != nil { - t.Fatalf("failed to create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - svc := NewService(repo, llm.Config{}) - - // Ingest empty findings - should be a no-op - err = svc.IngestFindings(context.Background(), []core.Finding{}, nil, false) - if err != nil { - t.Errorf("IngestFindings with empty findings should not error: %v", err) - } - - err = svc.IngestFindings(context.Background(), nil, nil, false) - if err != nil { - t.Errorf("IngestFindings with nil findings should not error: %v", err) - } -} - -// TestService_IngestFindings_MultipleTypes tests ingestion of multiple finding types. -func TestService_IngestFindings_MultipleTypes(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-multi-type-ingest-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - repo, err := memory.NewDefaultRepository(tmpDir) - if err != nil { - t.Fatalf("failed to create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - svc := NewService(repo, llm.Config{}) - - // Create findings of different types - findings := []core.Finding{ - { - Type: memory.NodeTypeDecision, - Title: "Architecture Decision", - Description: "Use MVC pattern for web layer", - SourceAgent: "code-agent", - }, - { - Type: memory.NodeTypePattern, - Title: "Repository Pattern", - Description: "Data access through repository interfaces", - SourceAgent: "code-agent", - }, - { - Type: memory.NodeTypeConstraint, - Title: "No External Dependencies", - Description: "Must work offline without network access", - SourceAgent: "code-agent", - }, - { - Type: memory.NodeTypeFeature, - Title: "Semantic Search", - Description: "Search using vector embeddings", - SourceAgent: "doc-agent", - }, - { - Type: memory.NodeTypeDocumentation, - Title: "API Documentation", - Description: "OpenAPI spec for REST endpoints", - SourceAgent: "doc-agent", - }, - } - - err = svc.IngestFindings(context.Background(), findings, nil, false) - if err != nil { - t.Fatalf("IngestFindings failed: %v", err) - } - - // Verify counts by type - nodes, err := repo.ListNodes("") - if err != nil { - t.Fatalf("ListNodes failed: %v", err) - } - - typeCounts := make(map[string]int) - for _, n := range nodes { - typeCounts[n.Type]++ - } - - expectedTypes := []string{ - memory.NodeTypeDecision, - memory.NodeTypePattern, - memory.NodeTypeConstraint, - memory.NodeTypeFeature, - memory.NodeTypeDocumentation, - } - - for _, expectedType := range expectedTypes { - if typeCounts[expectedType] == 0 { - t.Errorf("Expected at least one node of type %s", expectedType) - } - } -} - -// TestService_IngestFindings_WithWorkspace tests ingestion with workspace tagging. -func TestService_IngestFindings_WithWorkspace(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-workspace-ingest-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - repo, err := memory.NewDefaultRepository(tmpDir) - if err != nil { - t.Fatalf("failed to create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - svc := NewService(repo, llm.Config{}) - - // Create findings with workspace metadata (simulating monorepo bootstrap) - // NOTE: Titles must be sufficiently distinct to avoid Jaccard similarity deduplication - // (threshold is 0.35). Using completely different titles avoids false positives. - findings := []core.Finding{ - { - Type: memory.NodeTypeDecision, - Title: "REST Endpoint Authentication Strategy", - Description: "JWT-based auth for API gateway", - SourceAgent: "code-agent", - Metadata: map[string]any{ - "service": "api", - "workspace": "api", - }, - }, - { - Type: memory.NodeTypePattern, - Title: "React Component Composition Pattern", - Description: "Higher-order components for shared UI logic", - SourceAgent: "code-agent", - Metadata: map[string]any{ - "service": "web", - "workspace": "web", - }, - }, - } - - err = svc.IngestFindings(context.Background(), findings, nil, false) - if err != nil { - t.Fatalf("IngestFindings failed: %v", err) - } - - // Verify nodes exist - nodes, err := repo.ListNodes("") - if err != nil { - t.Fatalf("ListNodes failed: %v", err) - } - - if len(nodes) < 2 { - t.Errorf("Expected at least 2 nodes, got %d", len(nodes)) - } - - // Verify both workspaces are represented - workspaces := make(map[string]bool) - for _, n := range nodes { - workspaces[n.Workspace] = true - } - if !workspaces["api"] { - t.Error("Expected a node with workspace 'api'") - } - if !workspaces["web"] { - t.Error("Expected a node with workspace 'web'") - } -} - -// ============================================================================= -// Repository Integration Tests (using NewDefaultRepository) -// ============================================================================= - -// TestNewDefaultRepository_CreateAndRetrieve tests basic repository operations. -func TestNewDefaultRepository_CreateAndRetrieve(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-repo-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - // Use NewDefaultRepository as mandated by constraints - repo, err := memory.NewDefaultRepository(tmpDir) - if err != nil { - t.Fatalf("NewDefaultRepository failed: %v", err) - } - defer func() { _ = repo.Close() }() - - // Create a node - testNode := &memory.Node{ - ID: "test-node-create-retrieve", - Content: "Test content for create/retrieve test", - Type: memory.NodeTypeDecision, - Summary: "Test Summary", - Workspace: "root", - } - - err = repo.CreateNode(testNode) - if err != nil { - t.Fatalf("CreateNode failed: %v", err) - } - - // Retrieve the node - retrieved, err := repo.GetNode("test-node-create-retrieve") - if err != nil { - t.Fatalf("GetNode failed: %v", err) - } - - if retrieved == nil { - t.Fatal("GetNode returned nil") - } - - // Verify content - if retrieved.Summary != "Test Summary" { - t.Errorf("Summary = %q, want %q", retrieved.Summary, "Test Summary") - } - if retrieved.Type != memory.NodeTypeDecision { - t.Errorf("Type = %q, want %q", retrieved.Type, memory.NodeTypeDecision) - } -} - -// TestNewDefaultRepository_SQLiteIsCanonical verifies SQLite is used as the source of truth. -func TestNewDefaultRepository_SQLiteIsCanonical(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-sqlite-canonical-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - repo, err := memory.NewDefaultRepository(tmpDir) - if err != nil { - t.Fatalf("NewDefaultRepository failed: %v", err) - } - - // Create multiple nodes - for i := 0; i < 3; i++ { - node := &memory.Node{ - ID: "node-sqlite-" + string(rune('a'+i)), - Content: "Content " + string(rune('a'+i)), - Type: memory.NodeTypeDecision, - Summary: "Summary " + string(rune('a'+i)), - } - if err := repo.CreateNode(node); err != nil { - t.Fatalf("CreateNode failed for %d: %v", i, err) - } - } - - // Close and reopen to verify persistence - if err := repo.Close(); err != nil { - t.Fatalf("Close failed: %v", err) - } - - repo2, err := memory.NewDefaultRepository(tmpDir) - if err != nil { - t.Fatalf("NewDefaultRepository (reopen) failed: %v", err) - } - defer func() { _ = repo2.Close() }() - - // Verify data persisted (SQLite is the source of truth) - nodes, err := repo2.ListNodes("") - if err != nil { - t.Fatalf("ListNodes failed: %v", err) - } - - if len(nodes) != 3 { - t.Errorf("Expected 3 nodes after reopen, got %d", len(nodes)) - } -} diff --git a/internal/knowledge/service.go b/internal/knowledge/service.go index 3f3a73c..75d9280 100644 --- a/internal/knowledge/service.go +++ b/internal/knowledge/service.go @@ -10,6 +10,7 @@ import ( "time" "github.com/cloudwego/eino/schema" + "github.com/google/uuid" "github.com/josephgoksu/TaskWing/internal/llm" "github.com/josephgoksu/TaskWing/internal/memory" ) @@ -289,14 +290,14 @@ func (s *Service) searchInternal(ctx context.Context, query string, typeFilter s minResultThreshold := float32(cfg.MinResultScoreThreshold) // Two-stage retrieval: fetch more candidates for reranking - // Stage 1 (Recall): Fetch Top-25 candidates using hybrid search - recallLimit := cfg.RerankTopK - if recallLimit <= 0 { - recallLimit = 25 // Default recall candidates + // Stage 1 (Candidate retrieval): Fetch Top-25 candidates using hybrid search + candidateLimit := cfg.RerankTopK + if candidateLimit <= 0 { + candidateLimit = 25 // Default candidates } if !cfg.RerankingEnabled { // If reranking disabled, just fetch what we need - recallLimit = limit * 2 // Fetch 2x for graph expansion buffer + candidateLimit = limit * 2 // Fetch 2x for graph expansion buffer } // Collect results from both search methods @@ -305,7 +306,7 @@ func (s *Service) searchInternal(ctx context.Context, query string, typeFilter s // 1. FTS5 keyword search (fast, no API call, always works) // Note: FTS currently searches all types. We filter later. - ftsResults, err := s.repo.SearchFTS(query, recallLimit) + ftsResults, err := s.repo.SearchFTS(query, candidateLimit) if err != nil { // FTS5 errors are logged but don't fail the search // FTS5 may be unavailable on some systems (missing extension) @@ -401,9 +402,9 @@ func (s *Service) searchInternal(ctx context.Context, query string, typeFilter s return scored[i].Score > scored[j].Score }) - // Limit to recall candidates before reranking - if len(scored) > recallLimit { - scored = scored[:recallLimit] + // Limit to candidates before reranking + if len(scored) > candidateLimit { + scored = scored[:candidateLimit] } // 4. Stage 2 (Precision): Rerank using TEI if enabled @@ -599,6 +600,7 @@ Be concise and direct. // Uses UpsertNodeBySummary for dedup (Jaccard similarity on summaries). func (s *Service) AddNode(ctx context.Context, input NodeInput) (*memory.Node, error) { node := &memory.Node{ + ID: "n-" + uuid.New().String()[:8], Content: input.Content, Type: input.Type, Summary: input.Summary, @@ -793,9 +795,9 @@ func (s *Service) SearchDebug(ctx context.Context, query string, limit int) (*De vectorWeight := float32(cfg.VectorWeight) vectorThreshold := float32(cfg.VectorScoreThreshold) - recallLimit := cfg.RerankTopK - if recallLimit <= 0 { - recallLimit = 25 + candidateLimit := cfg.RerankTopK + if candidateLimit <= 0 { + candidateLimit = 25 } // Track individual scores per node @@ -828,7 +830,7 @@ func (s *Service) SearchDebug(ctx context.Context, query string, limit int) (*De // 2. FTS5 keyword search startFTS := time.Now() - ftsResults, err := s.repo.SearchFTS(query, recallLimit) + ftsResults, err := s.repo.SearchFTS(query, candidateLimit) if err == nil && len(ftsResults) > 0 { pipeline = append(pipeline, "FTS") for _, r := range ftsResults { @@ -899,8 +901,8 @@ func (s *Service) SearchDebug(ctx context.Context, query string, limit int) (*De return scored[i].Score > scored[j].Score }) - if len(scored) > recallLimit { - scored = scored[:recallLimit] + if len(scored) > candidateLimit { + scored = scored[:candidateLimit] } response.TotalCandidates = len(scored) diff --git a/internal/knowledge/summary.go b/internal/knowledge/summary.go index c5e1b59..c7af70d 100644 --- a/internal/knowledge/summary.go +++ b/internal/knowledge/summary.go @@ -8,7 +8,7 @@ import ( // This centralizes summary logic so CLI and MCP usage remains consistent. // Includes the project overview (if available) at the top of the response. func (s *Service) GetProjectSummary(ctx context.Context) (ProjectSummary, error) { - // Fetch project overview first (prepended to all recall responses) + // Fetch project overview first (prepended to all ask responses) var overviewInfo *ProjectOverviewInfo if overview, err := s.repo.GetProjectOverview(); err == nil && overview != nil { overviewInfo = &ProjectOverviewInfo{ diff --git a/internal/llm/provider_test.go b/internal/llm/provider_test.go deleted file mode 100644 index d372b28..0000000 --- a/internal/llm/provider_test.go +++ /dev/null @@ -1,173 +0,0 @@ -package llm - -import "testing" - -func TestValidateProvider_Bedrock(t *testing.T) { - got, err := ValidateProvider("bedrock") - if err != nil { - t.Fatalf("ValidateProvider(bedrock) error = %v", err) - } - if got != ProviderBedrock { - t.Fatalf("ValidateProvider(bedrock) = %q, want %q", got, ProviderBedrock) - } -} - -func TestInferProvider_BedrockModelID(t *testing.T) { - tests := []string{ - "anthropic.claude-opus-4-6-v1", - "us.anthropic.claude-sonnet-4-5-20250929-v1:0", - "amazon.nova-pro-v1:0", - "amazon.nova-premier-v1:0", - "meta.llama4-maverick-17b-instruct-v1:0", - "meta.llama3-3-70b-instruct-v1:0", - "openai.gpt-oss-120b-1:0", - "qwen.qwen3-235b-a22b-instruct-2507-v1:0", - "google.gemma-3-27b-it-v1:0", - } - for _, modelID := range tests { - provider, ok := InferProvider(modelID) - if !ok { - t.Fatalf("InferProvider(%q) = not inferred", modelID) - } - if provider != ProviderBedrock { - t.Fatalf("InferProvider(%q) = %q, want %q", modelID, provider, ProviderBedrock) - } - } -} - -func TestGetProviders_IncludesBedrock(t *testing.T) { - providers := GetProviders() - t.Logf("providers: %+v", providers) - found := false - for _, p := range providers { - if p.ID == ProviderBedrock { - found = true - break - } - } - if !found { - t.Fatalf("GetProviders() missing %q", ProviderBedrock) - } -} - -// ============================================ -// TaskWing provider and model tests -// ============================================ - -func TestValidateProvider_TaskWing(t *testing.T) { - got, err := ValidateProvider("taskwing") - if err != nil { - t.Fatalf("ValidateProvider(taskwing) error = %v", err) - } - if got != ProviderTaskWing { - t.Fatalf("ValidateProvider(taskwing) = %q, want %q", got, ProviderTaskWing) - } -} - -func TestGetProviders_IncludesTaskWing(t *testing.T) { - providers := GetProviders() - found := false - for _, p := range providers { - if p.ID == ProviderTaskWing { - found = true - if p.EnvVar != "TASKWING_API_KEY" { - t.Fatalf("TaskWing provider EnvVar = %q, want TASKWING_API_KEY", p.EnvVar) - } - if p.DefaultModel != ModelKarluk { - t.Fatalf("TaskWing provider DefaultModel = %q, want %q", p.DefaultModel, ModelKarluk) - } - break - } - } - if !found { - t.Fatalf("GetProviders() missing %q", ProviderTaskWing) - } -} - -func TestGetModel_Karluk(t *testing.T) { - tests := []struct { - modelID string - wantNil bool - wantProv string - }{ - {ModelKarluk, false, ProviderTaskWing}, - {"karluk-7b", false, ProviderTaskWing}, // alias - } - for _, tc := range tests { - t.Run(tc.modelID, func(t *testing.T) { - m := GetModel(tc.modelID) - if tc.wantNil && m != nil { - t.Fatalf("GetModel(%q) = %v, want nil", tc.modelID, m) - } - if !tc.wantNil && m == nil { - t.Fatalf("GetModel(%q) = nil, want non-nil", tc.modelID) - } - if m != nil && m.ProviderID != tc.wantProv { - t.Fatalf("GetModel(%q).ProviderID = %q, want %q", tc.modelID, m.ProviderID, tc.wantProv) - } - }) - } -} - -func TestInferProvider_Karluk(t *testing.T) { - tests := []string{ - "karluk", - "karluk-7b", - "karluk-custom", - } - for _, modelID := range tests { - provider, ok := InferProvider(modelID) - if !ok { - t.Fatalf("InferProvider(%q) = not inferred", modelID) - } - if provider != ProviderTaskWing { - t.Fatalf("InferProvider(%q) = %q, want %q", modelID, provider, ProviderTaskWing) - } - } -} - -func TestKarluk_ManagedPricing(t *testing.T) { - m := GetModel(ModelKarluk) - if m == nil { - t.Fatal("GetModel(karluk) = nil") - } - // Managed models have $0 in registry (pricing is account-based, not per-token in registry) - cost := CalculateCost(ModelKarluk, 1_000_000, 1_000_000) - if cost != 0 { - t.Fatalf("CalculateCost(karluk, 1M, 1M) = $%.4f, want $0.00", cost) - } -} - -func TestKarluk_MaxInputTokens(t *testing.T) { - tokens := GetMaxInputTokens(ModelKarluk) - if tokens != 32_768 { - t.Fatalf("GetMaxInputTokens(karluk) = %d, want 32768", tokens) - } -} - -func TestKarluk_Categories(t *testing.T) { - m := GetModel(ModelKarluk) - if m == nil { - t.Fatal("GetModel(karluk) = nil") - } - if m.Category != CategoryBalanced { - t.Fatalf("karluk category = %q, want %q", m.Category, CategoryBalanced) - } -} - -func TestKarluk_IsDefault(t *testing.T) { - m := GetDefaultModel(ProviderTaskWing) - if m == nil { - t.Fatal("GetDefaultModel(taskwing) = nil") - } - if m.ID != ModelKarluk { - t.Fatalf("GetDefaultModel(taskwing) = %q, want %q", m.ID, ModelKarluk) - } -} - -func TestGetEnvVarForProvider_TaskWing(t *testing.T) { - envVar := GetEnvVarForProvider(ProviderTaskWing) - if envVar != "TASKWING_API_KEY" { - t.Fatalf("GetEnvVarForProvider(taskwing) = %q, want TASKWING_API_KEY", envVar) - } -} diff --git a/internal/logger/crash_test.go b/internal/logger/crash_test.go deleted file mode 100644 index 40c04e4..0000000 --- a/internal/logger/crash_test.go +++ /dev/null @@ -1,239 +0,0 @@ -package logger - -import ( - "os" - "path/filepath" - "strings" - "testing" - "time" -) - -func TestCrashHandler_SetContext(t *testing.T) { - // Reset global context - globalContext = &CrashContext{} - - SetBasePath("/tmp/test-taskwing") - SetVersion("1.0.0-test") - SetCommand("test command") - SetLastInput("test input") - SetLastPrompt("test prompt") - - globalContext.mu.RLock() - defer globalContext.mu.RUnlock() - - if globalContext.basePath != "/tmp/test-taskwing" { - t.Errorf("Expected basePath '/tmp/test-taskwing', got '%s'", globalContext.basePath) - } - if globalContext.version != "1.0.0-test" { - t.Errorf("Expected version '1.0.0-test', got '%s'", globalContext.version) - } - if globalContext.command != "test command" { - t.Errorf("Expected command 'test command', got '%s'", globalContext.command) - } - if globalContext.lastInput != "test input" { - t.Errorf("Expected lastInput 'test input', got '%s'", globalContext.lastInput) - } - if globalContext.lastPrompt != "test prompt" { - t.Errorf("Expected lastPrompt 'test prompt', got '%s'", globalContext.lastPrompt) - } -} - -func TestCrashHandler_SetLastPrompt_Truncation(t *testing.T) { - // Reset global context - globalContext = &CrashContext{} - - // Create a long prompt - longPrompt := strings.Repeat("a", 3000) - SetLastPrompt(longPrompt) - - globalContext.mu.RLock() - defer globalContext.mu.RUnlock() - - if len(globalContext.lastPrompt) > 2100 { - t.Errorf("Expected prompt to be truncated, got length %d", len(globalContext.lastPrompt)) - } - if !strings.Contains(globalContext.lastPrompt, "[truncated]") { - t.Error("Expected truncated prompt to contain '[truncated]'") - } -} - -func TestCrashHandler_CreateCrashLog(t *testing.T) { - // Reset global context - globalContext = &CrashContext{ - version: "1.0.0", - command: "test", - lastInput: "user input", - } - - log := createCrashLog("test panic") - - if log.PanicValue != "test panic" { - t.Errorf("Expected PanicValue 'test panic', got '%s'", log.PanicValue) - } - if log.Version != "1.0.0" { - t.Errorf("Expected Version '1.0.0', got '%s'", log.Version) - } - if log.Command != "test" { - t.Errorf("Expected Command 'test', got '%s'", log.Command) - } - if log.LastInput != "user input" { - t.Errorf("Expected LastInput 'user input', got '%s'", log.LastInput) - } - if log.StackTrace == "" { - t.Error("Expected non-empty StackTrace") - } - if log.GoVersion == "" { - t.Error("Expected non-empty GoVersion") - } -} - -func TestCrashHandler_FormatCrashLog(t *testing.T) { - log := CrashLog{ - Timestamp: time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC), - Version: "1.0.0", - Command: "test", - PanicValue: "test panic", - StackTrace: "goroutine 1 [running]:\nmain.main()", - LastInput: "user input", - GoVersion: "go1.24.3", - OS: "darwin", - Arch: "arm64", - } - - formatted := formatCrashLog(log) - - expectedStrings := []string{ - "TASKWING CRASH LOG", - "Timestamp: 2025-01-01T12:00:00Z", - "Version: 1.0.0", - "Command: test", - "Go: go1.24.3", - "OS/Arch: darwin/arm64", - "PANIC VALUE", - "test panic", - "STACK TRACE", - "goroutine 1 [running]", - "LAST USER INPUT", - "user input", - } - - for _, expected := range expectedStrings { - if !strings.Contains(formatted, expected) { - t.Errorf("Expected formatted log to contain '%s'", expected) - } - } -} - -func TestCrashHandler_WriteCrashLog(t *testing.T) { - // Create temp directory - tmpDir := t.TempDir() - basePath := filepath.Join(tmpDir, ".taskwing") - - // Set up context - globalContext = &CrashContext{ - basePath: basePath, - version: "1.0.0", - command: "test", - } - - log := CrashLog{ - Timestamp: time.Now(), - Version: "1.0.0", - Command: "test", - PanicValue: "test panic", - StackTrace: "test stack", - GoVersion: "go1.24", - OS: "test", - Arch: "test", - } - - err := writeCrashLog(log) - if err != nil { - t.Fatalf("writeCrashLog failed: %v", err) - } - - // Verify directory was created - crashDir := filepath.Join(basePath, CrashLogDir) - if _, err := os.Stat(crashDir); os.IsNotExist(err) { - t.Error("Expected crash log directory to be created") - } - - // Verify file was created - logs, err := ListCrashLogs() - if err != nil { - t.Fatalf("ListCrashLogs failed: %v", err) - } - if len(logs) != 1 { - t.Errorf("Expected 1 crash log, got %d", len(logs)) - } - - // Verify file content - if len(logs) > 0 { - content, err := ReadCrashLog(logs[0]) - if err != nil { - t.Fatalf("ReadCrashLog failed: %v", err) - } - if !strings.Contains(content, "test panic") { - t.Error("Expected crash log to contain panic value") - } - } -} - -func TestCrashHandler_CleanOldLogs(t *testing.T) { - // Create temp directory - tmpDir := t.TempDir() - basePath := filepath.Join(tmpDir, ".taskwing") - crashDir := filepath.Join(basePath, CrashLogDir) - - if err := os.MkdirAll(crashDir, 0755); err != nil { - t.Fatalf("Failed to create crash dir: %v", err) - } - - // Set up context - globalContext = &CrashContext{basePath: basePath} - - // Create more than MaxCrashLogs files - for i := range MaxCrashLogs + 5 { - filename := filepath.Join(crashDir, "crash_20250101_1200"+string(rune('0'+i%10))+string(rune('0'+i/10))+".log") - if err := os.WriteFile(filename, []byte("test"), 0644); err != nil { - t.Fatalf("Failed to create test file: %v", err) - } - } - - // Clean old logs - if err := cleanOldCrashLogs(crashDir); err != nil { - t.Fatalf("cleanOldCrashLogs failed: %v", err) - } - - // Verify count - logs, err := ListCrashLogs() - if err != nil { - t.Fatalf("ListCrashLogs failed: %v", err) - } - if len(logs) != MaxCrashLogs { - t.Errorf("Expected %d crash logs after cleanup, got %d", MaxCrashLogs, len(logs)) - } -} - -func TestCrashHandler_GetCrashLogPath(t *testing.T) { - globalContext = &CrashContext{basePath: "/tmp/test"} - - testTime := time.Date(2025, 1, 15, 14, 30, 45, 0, time.UTC) - path := getCrashLogPath(testTime) - - expectedPath := "/tmp/test/crash_logs/crash_20250115_143045.log" - if path != expectedPath { - t.Errorf("Expected path '%s', got '%s'", expectedPath, path) - } -} - -func TestCrashHandler_DefaultBasePath(t *testing.T) { - // Reset global context with empty basePath - globalContext = &CrashContext{} - - dir := getCrashLogDir() - expected := ".taskwing/crash_logs" - if dir != expected { - t.Errorf("Expected default dir '%s', got '%s'", expected, dir) - } -} diff --git a/internal/mcp/handlers.go b/internal/mcp/handlers.go index b7b9575..22152a9 100644 --- a/internal/mcp/handlers.go +++ b/internal/mcp/handlers.go @@ -290,16 +290,16 @@ func handleCodeSimplify(ctx context.Context, repo *memory.Repository, params Cod // Get architectural context for better simplification appCtx := app.NewContextForRole(repo, llm.RoleQuery) - recallApp := app.NewRecallApp(appCtx) + askApp := app.NewAskApp(appCtx) var kgContext string if filePath != "" { - result, err := recallApp.Query(ctx, "patterns and constraints for "+filePath, app.RecallOptions{ + result, err := askApp.Query(ctx, "patterns and constraints for "+filePath, app.AskOptions{ Limit: 3, GenerateAnswer: false, }) if err == nil && result != nil { - kgContext = formatRecallContext(result) + kgContext = formatAskContext(result) } } @@ -405,8 +405,8 @@ func validateAndResolvePath(requestedPath string, projectRoot string) (string, e return absPath, nil } -// formatRecallContext formats recall results for agent context. -func formatRecallContext(result *app.RecallResult) string { +// formatAskContext formats ask results for agent context. +func formatAskContext(result *app.AskResult) string { if result == nil || len(result.Results) == 0 { return "" } @@ -440,7 +440,7 @@ func HandleDebugTool(ctx context.Context, repo *memory.Repository, params DebugT // Get architectural context for better diagnosis appCtx := app.NewContextForRole(repo, llm.RoleQuery) - recallApp := app.NewRecallApp(appCtx) + askApp := app.NewAskApp(appCtx) var kgContext string // Build context query from problem and file path @@ -449,12 +449,12 @@ func HandleDebugTool(ctx context.Context, repo *memory.Repository, params DebugT contextQuery = params.FilePath + " " + problem } - result, err := recallApp.Query(ctx, contextQuery, app.RecallOptions{ + result, err := askApp.Query(ctx, contextQuery, app.AskOptions{ Limit: 5, GenerateAnswer: false, }) if err == nil && result != nil { - kgContext = formatRecallContext(result) + kgContext = formatAskContext(result) } // Create and run the DebugAgent diff --git a/internal/mcp/handlers_test.go b/internal/mcp/handlers_test.go deleted file mode 100644 index 8e1ed3e..0000000 --- a/internal/mcp/handlers_test.go +++ /dev/null @@ -1,725 +0,0 @@ -package mcp - -import ( - "context" - "os" - "path/filepath" - "strings" - "testing" - - "github.com/josephgoksu/TaskWing/internal/memory" -) - -func TestHandleCodeTool_InvalidAction(t *testing.T) { - params := CodeToolParams{ - Action: "invalid_action", - Query: "test", - } - - result, err := HandleCodeTool(context.Background(), nil, params) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Error == "" { - t.Error("expected error for invalid action") - } - if result.Action != "invalid_action" { - t.Errorf("expected action 'invalid_action', got %q", result.Action) - } -} - -func TestHandleCodeTool_SearchMissingQuery(t *testing.T) { - params := CodeToolParams{ - Action: CodeActionSearch, - Query: "", // missing query - } - - result, err := HandleCodeTool(context.Background(), nil, params) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Error == "" { - t.Error("expected error for missing query") - } - if result.Action != "search" { - t.Errorf("expected action 'search', got %q", result.Action) - } -} - -func TestHandleCodeTool_ExplainMissingIdentifier(t *testing.T) { - params := CodeToolParams{ - Action: CodeActionExplain, - Query: "", - SymbolID: 0, // both missing - } - - result, err := HandleCodeTool(context.Background(), nil, params) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Error == "" { - t.Error("expected error for missing query/symbol_id") - } -} - -func TestHandleCodeTool_CallersMissingIdentifier(t *testing.T) { - params := CodeToolParams{ - Action: CodeActionCallers, - Query: "", - SymbolID: 0, // both missing - } - - result, err := HandleCodeTool(context.Background(), nil, params) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Error == "" { - t.Error("expected error for missing query/symbol_id") - } -} - -func TestHandleCodeTool_ImpactMissingIdentifier(t *testing.T) { - params := CodeToolParams{ - Action: CodeActionImpact, - Query: "", - SymbolID: 0, // both missing - } - - result, err := HandleCodeTool(context.Background(), nil, params) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Error == "" { - t.Error("expected error for missing query/symbol_id") - } -} - -func TestHandleCodeTool_ActionRouting(t *testing.T) { - // Test actions that have validation before hitting the repo - // (search, explain, callers, impact all require query/symbol_id) - tests := []struct { - action CodeAction - name string - expectError bool - errorContains string - }{ - {CodeActionSearch, "search", true, "query is required"}, - {CodeActionExplain, "explain", true, "query or symbol_id is required"}, - {CodeActionCallers, "callers", true, "symbol_id or query"}, - {CodeActionImpact, "impact", true, "symbol_id or query"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - params := CodeToolParams{ - Action: tt.action, - // Intentionally missing required fields to trigger validation error - } - - result, err := HandleCodeTool(context.Background(), nil, params) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Action != tt.name { - t.Errorf("expected action %q, got %q", tt.name, result.Action) - } - - if tt.expectError && result.Error == "" { - t.Error("expected validation error") - } - }) - } -} - -// === Task Tool Handler Tests === - -func TestHandleTaskTool_InvalidAction(t *testing.T) { - params := TaskToolParams{ - Action: "invalid_action", - } - - result, err := HandleTaskTool(context.Background(), nil, params, "") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Error == "" { - t.Error("expected error for invalid action") - } - if result.Action != "invalid_action" { - t.Errorf("expected action 'invalid_action', got %q", result.Action) - } -} - -func TestHandleTaskTool_StartMissingTaskID(t *testing.T) { - params := TaskToolParams{ - Action: TaskActionStart, - TaskID: "", // missing - SessionID: "session-123", - } - - result, err := HandleTaskTool(context.Background(), nil, params, "") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Error == "" { - t.Error("expected error for missing task_id") - } - if result.Action != "start" { - t.Errorf("expected action 'start', got %q", result.Action) - } -} - -func TestHandleTaskTool_StartMissingSessionID(t *testing.T) { - params := TaskToolParams{ - Action: TaskActionStart, - TaskID: "task-123", - SessionID: "", // missing - } - - result, err := HandleTaskTool(context.Background(), nil, params, "") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Error == "" { - t.Error("expected error for missing session_id") - } -} - -func TestResolveTaskSessionID(t *testing.T) { - tests := []struct { - name string - explicit string - fallback string - want string - }{ - {name: "explicit wins", explicit: "session-explicit", fallback: "session-default", want: "session-explicit"}, - {name: "fallback used", explicit: "", fallback: "session-default", want: "session-default"}, - {name: "both empty", explicit: "", fallback: "", want: ""}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := resolveTaskSessionID(tt.explicit, tt.fallback) - if got != tt.want { - t.Fatalf("resolveTaskSessionID(%q, %q) = %q, want %q", tt.explicit, tt.fallback, got, tt.want) - } - }) - } -} - -func TestHandleTaskTool_NextUsesDefaultSessionIDWhenOmitted(t *testing.T) { - repo, err := memory.NewDefaultRepository(t.TempDir()) - if err != nil { - t.Fatalf("create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - params := TaskToolParams{ - Action: TaskActionNext, - } - - result, err := HandleTaskTool(context.Background(), repo, params, "session-from-mcp") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if strings.Contains(result.Error, "session_id is required") { - t.Fatalf("expected default session_id to be used, got error: %q", result.Error) - } -} - -func TestHandleTaskTool_CompleteMissingTaskID(t *testing.T) { - params := TaskToolParams{ - Action: TaskActionComplete, - TaskID: "", // missing - } - - result, err := HandleTaskTool(context.Background(), nil, params, "") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Error == "" { - t.Error("expected error for missing task_id") - } - if result.Action != "complete" { - t.Errorf("expected action 'complete', got %q", result.Action) - } -} - -func TestHandleTaskTool_NextMissingSessionID(t *testing.T) { - params := TaskToolParams{ - Action: TaskActionNext, - SessionID: "", // missing - } - - result, err := HandleTaskTool(context.Background(), nil, params, "") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Error == "" { - t.Error("expected error for missing session_id") - } - if !strings.Contains(result.Error, "session_id") { - t.Errorf("error should mention session_id: %s", result.Error) - } - if result.Action != "next" { - t.Errorf("expected action 'next', got %q", result.Action) - } - // Should have actionable guidance in content - if !strings.Contains(result.Content, "session") { - t.Error("content should mention session for guidance") - } -} - -func TestHandleTaskTool_CurrentMissingSessionID(t *testing.T) { - params := TaskToolParams{ - Action: TaskActionCurrent, - SessionID: "", // missing - } - - result, err := HandleTaskTool(context.Background(), nil, params, "") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Error == "" { - t.Error("expected error for missing session_id") - } - if !strings.Contains(result.Error, "session_id") { - t.Errorf("error should mention session_id: %s", result.Error) - } - if result.Action != "current" { - t.Errorf("expected action 'current', got %q", result.Action) - } -} - -func TestHandleTaskTool_ActionRouting(t *testing.T) { - // Test actions that have validation before hitting the repo - tests := []struct { - action TaskAction - name string - expectError bool - }{ - {TaskActionNext, "next", true}, // missing session_id - {TaskActionCurrent, "current", true}, // missing session_id - {TaskActionStart, "start", true}, // missing task_id - {TaskActionComplete, "complete", true}, // missing task_id - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - params := TaskToolParams{ - Action: tt.action, - // Intentionally missing required fields - } - - result, err := HandleTaskTool(context.Background(), nil, params, "") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Action != tt.name { - t.Errorf("expected action %q, got %q", tt.name, result.Action) - } - - if tt.expectError && result.Error == "" { - t.Error("expected validation error") - } - }) - } -} - -// === Plan Tool Handler Tests === - -func TestHandlePlanTool_InvalidAction(t *testing.T) { - params := PlanToolParams{ - Action: "invalid_action", - Goal: "test goal", - } - - result, err := HandlePlanTool(context.Background(), nil, params) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Error == "" { - t.Error("expected error for invalid action") - } - if result.Action != "invalid_action" { - t.Errorf("expected action 'invalid_action', got %q", result.Action) - } -} - -func TestHandlePlanTool_ClarifyMissingGoal(t *testing.T) { - params := PlanToolParams{ - Action: PlanActionClarify, - Goal: "", // missing - } - - result, err := HandlePlanTool(context.Background(), nil, params) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Error == "" { - t.Error("expected error for missing goal") - } - if result.Action != "clarify" { - t.Errorf("expected action 'clarify', got %q", result.Action) - } -} - -func TestHandlePlanTool_ClarifyFollowUpMissingAnswers(t *testing.T) { - params := PlanToolParams{ - Action: PlanActionClarify, - ClarifySessionID: "clarify-123", - } - - result, err := HandlePlanTool(context.Background(), nil, params) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if result.Error == "" { - t.Fatal("expected validation error for missing answers") - } - if !strings.Contains(result.Error, "answers") { - t.Fatalf("expected error to mention answers, got %q", result.Error) - } - if !strings.Contains(result.Content, "auto_answer") { - t.Fatalf("expected remediation to mention auto_answer, got %q", result.Content) - } -} - -func TestHandlePlanTool_ClarifyFollowUpAllowsAutoAnswerWithoutAnswers(t *testing.T) { - params := PlanToolParams{ - Action: PlanActionClarify, - ClarifySessionID: "clarify-123", - AutoAnswer: true, - } - - result, err := HandlePlanTool(context.Background(), nil, params) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if strings.Contains(result.Error, "answers are required") { - t.Fatalf("unexpected answers validation error when auto_answer=true: %q", result.Error) - } -} - -func TestHandlePlanTool_GenerateMissingGoal(t *testing.T) { - params := PlanToolParams{ - Action: PlanActionGenerate, - Goal: "", // missing - EnrichedGoal: "some enriched goal", - } - - result, err := HandlePlanTool(context.Background(), nil, params) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Error == "" { - t.Error("expected error for missing goal") - } - if result.Action != "generate" { - t.Errorf("expected action 'generate', got %q", result.Action) - } -} - -func TestHandlePlanTool_GenerateMissingEnrichedGoal(t *testing.T) { - params := PlanToolParams{ - Action: PlanActionGenerate, - Goal: "test goal", - EnrichedGoal: "", // missing - } - - result, err := HandlePlanTool(context.Background(), nil, params) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Error == "" { - t.Error("expected error for missing enriched_goal") - } - if result.Action != "generate" { - t.Errorf("expected action 'generate', got %q", result.Action) - } -} - -// TestHandlePlanTool_GenerateErrorContainsFieldDetails validates that validation errors -// contain actionable field-level details to help AI clients self-correct. -func TestHandlePlanTool_GenerateErrorContainsFieldDetails(t *testing.T) { - tests := []struct { - name string - params PlanToolParams - expectedFields []string - }{ - { - name: "missing_required_fields_lists_all", - params: PlanToolParams{ - Action: PlanActionGenerate, - // goal, enriched_goal, clarify_session_id missing - }, - expectedFields: []string{"goal", "enriched_goal", "clarify_session_id"}, - }, - { - name: "missing_goal_and_session_lists_both", - params: PlanToolParams{ - Action: PlanActionGenerate, - EnrichedGoal: "some enriched goal", - }, - expectedFields: []string{"goal", "clarify_session_id"}, - }, - { - name: "missing_enriched_goal_and_session", - params: PlanToolParams{ - Action: PlanActionGenerate, - Goal: "some goal", - }, - expectedFields: []string{"enriched_goal", "clarify_session_id"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := HandlePlanTool(context.Background(), nil, tt.params) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - // Error should list missing fields - for _, field := range tt.expectedFields { - if !strings.Contains(result.Error, field) { - t.Errorf("error should contain field %q: %s", field, result.Error) - } - } - - // Content should have actionable guidance - if !strings.Contains(result.Content, "clarify") { - t.Error("content should mention 'clarify' action for guidance") - } - }) - } -} - -func TestHandlePlanTool_ActionRouting(t *testing.T) { - // Test actions that have validation before hitting the repo - tests := []struct { - action PlanAction - name string - expectError bool - errorContains string - }{ - {PlanActionClarify, "clarify", true, "goal is required"}, - {PlanActionGenerate, "generate", true, "goal is required"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - params := PlanToolParams{ - Action: tt.action, - // Intentionally missing required fields - } - - result, err := HandlePlanTool(context.Background(), nil, params) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Action != tt.name { - t.Errorf("expected action %q, got %q", tt.name, result.Action) - } - - if tt.expectError && result.Error == "" { - t.Error("expected validation error") - } - }) - } -} - -// === Path Validation Tests === - -func TestValidateAndResolvePath_PathTraversal(t *testing.T) { - tmpDir := t.TempDir() - - tests := []struct { - name string - path string - projectRoot string - wantErr bool - errContains string - }{ - { - name: "direct traversal", - path: "../../../etc/passwd", - projectRoot: tmpDir, - wantErr: true, - errContains: "path traversal not allowed", - }, - { - name: "hidden traversal in middle", - path: "foo/../../../etc/passwd", - projectRoot: tmpDir, - wantErr: true, - errContains: "path traversal not allowed", - }, - { - name: "absolute path outside project", - path: "/etc/passwd", - projectRoot: tmpDir, - wantErr: true, - errContains: "path outside project root", - }, - { - name: "relative path no project root", - path: "foo/bar.go", - projectRoot: "", - wantErr: true, - errContains: "cannot resolve relative path", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := validateAndResolvePath(tt.path, tt.projectRoot) - if (err != nil) != tt.wantErr { - t.Errorf("validateAndResolvePath() error = %v, wantErr %v", err, tt.wantErr) - return - } - if tt.wantErr && err != nil { - if tt.errContains != "" && !stringContains(err.Error(), tt.errContains) { - t.Errorf("error %q does not contain %q", err.Error(), tt.errContains) - } - } - }) - } -} - -func TestValidateAndResolvePath_ValidPaths(t *testing.T) { - // Create a temp directory with a test file - tmpDir := t.TempDir() - testFile := filepath.Join(tmpDir, "test.go") - if err := os.WriteFile(testFile, []byte("package test"), 0644); err != nil { - t.Fatalf("failed to create test file: %v", err) - } - - // Create a subdirectory with a file - subDir := filepath.Join(tmpDir, "subdir") - if err := os.MkdirAll(subDir, 0755); err != nil { - t.Fatalf("failed to create subdir: %v", err) - } - subFile := filepath.Join(subDir, "sub.go") - if err := os.WriteFile(subFile, []byte("package sub"), 0644); err != nil { - t.Fatalf("failed to create sub file: %v", err) - } - - tests := []struct { - name string - path string - projectRoot string - wantPath string - }{ - { - name: "relative path in root", - path: "test.go", - projectRoot: tmpDir, - wantPath: testFile, - }, - { - name: "relative path in subdir", - path: "subdir/sub.go", - projectRoot: tmpDir, - wantPath: subFile, - }, - { - name: "absolute path within project", - path: testFile, - projectRoot: tmpDir, - wantPath: testFile, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := validateAndResolvePath(tt.path, tt.projectRoot) - if err != nil { - t.Errorf("validateAndResolvePath() unexpected error: %v", err) - return - } - if got != tt.wantPath { - t.Errorf("validateAndResolvePath() = %q, want %q", got, tt.wantPath) - } - }) - } -} - -func TestValidateAndResolvePath_DirectoryRejection(t *testing.T) { - tmpDir := t.TempDir() - - _, err := validateAndResolvePath(tmpDir, tmpDir) - if err == nil { - t.Error("expected error for directory path") - } - if !stringContains(err.Error(), "directory") { - t.Errorf("error %q does not mention directory", err.Error()) - } -} - -func TestHandleCodeTool_SimplifyMissingInput(t *testing.T) { - params := CodeToolParams{ - Action: CodeActionSimplify, - Code: "", // missing - FilePath: "", // missing - } - - result, err := HandleCodeTool(context.Background(), nil, params) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Error == "" { - t.Error("expected error for missing code/file_path") - } - if result.Action != "simplify" { - t.Errorf("expected action 'simplify', got %q", result.Action) - } -} - -func TestHandleCodeTool_SimplifyPathTraversal(t *testing.T) { - params := CodeToolParams{ - Action: CodeActionSimplify, - FilePath: "../../../etc/passwd", - } - - result, err := HandleCodeTool(context.Background(), nil, params) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Error == "" { - t.Error("expected error for path traversal attempt") - } - if !stringContains(result.Error, "path traversal") && !stringContains(result.Error, "invalid file path") { - t.Errorf("error %q does not mention path traversal", result.Error) - } -} - -// stringContains checks if a string contains a substring (avoids conflict with presenter_test.go) -func stringContains(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} diff --git a/internal/mcp/presenter.go b/internal/mcp/presenter.go index cc012e8..9aa9ac6 100644 --- a/internal/mcp/presenter.go +++ b/internal/mcp/presenter.go @@ -17,10 +17,10 @@ import ( "golang.org/x/text/language" ) -// FormatRecall converts a RecallResult into token-efficient Markdown. +// FormatAsk converts an AskResult into token-efficient Markdown. // Structure: Answer (if present) -> Knowledge -> Symbols // Includes debt warnings for patterns/decisions marked as technical debt. -func FormatRecall(result *app.RecallResult) string { +func FormatAsk(result *app.AskResult) string { if result == nil { return "No results found." } diff --git a/internal/mcp/presenter_test.go b/internal/mcp/presenter_test.go deleted file mode 100644 index 09fc5b4..0000000 --- a/internal/mcp/presenter_test.go +++ /dev/null @@ -1,331 +0,0 @@ -package mcp - -import ( - "testing" - - agentcore "github.com/josephgoksu/TaskWing/internal/agents/core" - agentimpl "github.com/josephgoksu/TaskWing/internal/agents/impl" - "github.com/josephgoksu/TaskWing/internal/app" - "github.com/josephgoksu/TaskWing/internal/codeintel" -) - -func TestFormatRecall_NilResult(t *testing.T) { - result := FormatRecall(nil) - if result != "No results found." { - t.Errorf("expected 'No results found.', got %q", result) - } -} - -func TestFormatRecall_EmptyResult(t *testing.T) { - result := FormatRecall(&app.RecallResult{}) - if result != "No results found." { - t.Errorf("expected 'No results found.', got %q", result) - } -} - -func TestFormatRecall_WithAnswer(t *testing.T) { - result := FormatRecall(&app.RecallResult{ - Answer: "This is the answer.", - }) - if result == "" { - t.Error("expected non-empty result") - } - if !contains(result, "## Answer") { - t.Error("expected Answer section header") - } - if !contains(result, "This is the answer.") { - t.Error("expected answer content") - } -} - -func TestFormatSymbolList_Empty(t *testing.T) { - result := FormatSymbolList(nil) - if result != "No symbols found." { - t.Errorf("expected 'No symbols found.', got %q", result) - } -} - -func TestFormatSymbolList_WithSymbols(t *testing.T) { - symbols := []codeintel.Symbol{ - {Name: "TestFunc", Kind: "function", FilePath: "test.go", StartLine: 10}, - {Name: "TestStruct", Kind: "struct", FilePath: "test.go", StartLine: 20}, - } - result := FormatSymbolList(symbols) - if result == "" { - t.Error("expected non-empty result") - } - if !contains(result, "TestFunc") { - t.Error("expected TestFunc in result") - } - if !contains(result, "function") { - t.Error("expected 'function' kind in result") - } -} - -func TestFormatSearchResults_Empty(t *testing.T) { - result := FormatSearchResults(nil) - if result != "No matching symbols found." { - t.Errorf("expected 'No matching symbols found.', got %q", result) - } -} - -func TestFormatSearchResults_WithResults(t *testing.T) { - results := []codeintel.SymbolSearchResult{ - { - Symbol: codeintel.Symbol{Name: "SearchFunc", Kind: "function", FilePath: "search.go", StartLine: 5}, - Score: 0.95, - }, - } - result := FormatSearchResults(results) - if result == "" { - t.Error("expected non-empty result") - } - if !contains(result, "SearchFunc") { - t.Error("expected SearchFunc in result") - } - if !contains(result, "Search Results") { - t.Error("expected header in result") - } -} - -func TestFormatTask_NilResult(t *testing.T) { - result := FormatTask(nil) - if result != "No task information." { - t.Errorf("expected 'No task information.', got %q", result) - } -} - -func TestFormatError(t *testing.T) { - result := FormatError("something went wrong") - if !contains(result, "Error") { - t.Error("expected Error header") - } - if !contains(result, "something went wrong") { - t.Error("expected error message") - } -} - -func TestFormatValidationError(t *testing.T) { - result := FormatValidationError("query", "query is required") - if !contains(result, "Validation Error") { - t.Error("expected Validation Error header") - } - if !contains(result, "query") { - t.Error("expected field name") - } - if !contains(result, "query is required") { - t.Error("expected error message") - } -} - -func TestFormatCallers_NilResult(t *testing.T) { - result := FormatCallers(nil) - if result == "" { - t.Error("expected non-empty result for nil input") - } -} - -func TestFormatImpact_NilResult(t *testing.T) { - result := FormatImpact(nil) - if !contains(result, "Failed") { - t.Error("expected failure message for nil input") - } -} - -func TestFormatExplainResult_NilResult(t *testing.T) { - result := FormatExplainResult(nil) - if result != "No explanation available." { - t.Errorf("expected 'No explanation available.', got %q", result) - } -} - -func TestFormatDebugResult_Empty(t *testing.T) { - result := FormatDebugResult(nil) - if result != "No debug analysis available." { - t.Errorf("expected 'No debug analysis available.', got %q", result) - } -} - -func TestFormatDebugResult_WithJSONStyleData(t *testing.T) { - // Simulate data as it would come from JSON deserialization - // (slices become []interface{}, maps become map[string]interface{}) - findings := []agentcore.Finding{ - { - Type: "debug", - Description: "Database connection timeout", - Metadata: map[string]any{ - "hypotheses": []interface{}{ - map[string]interface{}{ - "cause": "Connection pool exhausted", - "likelihood": "high", - "reasoning": "Many concurrent requests", - "code_locations": []interface{}{"db/pool.go:45"}, - }, - map[string]interface{}{ - "cause": "Network latency", - "likelihood": "medium", - "reasoning": "Remote database", - }, - }, - "investigation_steps": []interface{}{ - map[string]interface{}{ - "step": float64(1), // JSON numbers are float64 - "action": "Check connection pool", - "command": "netstat -an | grep 5432", - "expected_finding": "Many ESTABLISHED connections", - }, - }, - "quick_fixes": []interface{}{ - map[string]interface{}{ - "fix": "Increase pool size", - "when": "Pool is exhausted", - }, - }, - }, - }, - } - - result := FormatDebugResult(findings) - - if result == "No debug analysis available." { - t.Error("expected non-empty result") - } - if !contains(result, "Debug Analysis") { - t.Error("expected Debug Analysis header") - } - if !contains(result, "Connection pool exhausted") { - t.Error("expected hypothesis cause") - } - if !contains(result, "high") { - t.Error("expected likelihood") - } - if !contains(result, "Check connection pool") { - t.Error("expected investigation step") - } - if !contains(result, "Increase pool size") { - t.Error("expected quick fix") - } -} - -func TestFormatSimplifyResult_Empty(t *testing.T) { - result := FormatSimplifyResult(nil) - if result != "No simplification results." { - t.Errorf("expected 'No simplification results.', got %q", result) - } -} - -func TestFormatSimplifyResult_WithJSONStyleData(t *testing.T) { - // Simulate data as it would come from JSON deserialization - findings := []agentcore.Finding{ - { - Type: "simplification", - Description: "Simplified error handling", - Metadata: map[string]any{ - "simplified_code": "return err", - "original_lines": float64(10), // JSON numbers are float64 - "simplified_lines": float64(3), - "reduction_percentage": float64(70), - "risk_assessment": "low", - "changes": []interface{}{ - map[string]interface{}{ - "what": "Removed redundant nil check", - "why": "Error is always non-nil in this branch", - "risk": "none", - }, - map[string]interface{}{ - "what": "Consolidated error wrapping", - "why": "Multiple wrap calls were redundant", - "risk": "low", - }, - }, - }, - }, - } - - result := FormatSimplifyResult(findings) - - if result == "No simplification results." { - t.Error("expected non-empty result") - } - if !contains(result, "Code Simplification") { - t.Error("expected Code Simplification header") - } - if !contains(result, "return err") { - t.Error("expected simplified code") - } - if !contains(result, "10") && !contains(result, "3") { - t.Error("expected line counts") - } - if !contains(result, "Removed redundant nil check") { - t.Error("expected change description") - } - if !contains(result, "low") { - t.Error("expected risk assessment") - } -} - -func TestExtractSimplifyChanges_DirectType(t *testing.T) { - // Test with direct type (from agent before serialization) - direct := []agentimpl.SimplifyChange{ - {What: "removed", Why: "unused", Risk: "none"}, - } - result := extractSimplifyChanges(direct) - if len(result) != 1 { - t.Errorf("expected 1 change, got %d", len(result)) - } - if result[0].What != "removed" { - t.Errorf("expected 'removed', got %q", result[0].What) - } -} - -func TestExtractDebugHypotheses_DirectType(t *testing.T) { - // Test with direct type (from agent before serialization) - direct := []agentimpl.DebugHypothesis{ - {Cause: "bug", Likelihood: "high", Reasoning: "test"}, - } - result := extractDebugHypotheses(direct) - if len(result) != 1 { - t.Errorf("expected 1 hypothesis, got %d", len(result)) - } - if result[0].Cause != "bug" { - t.Errorf("expected 'bug', got %q", result[0].Cause) - } -} - -func TestGetIntFromMetadata(t *testing.T) { - tests := []struct { - name string - metadata map[string]any - key string - want int - }{ - {"nil map", nil, "x", 0}, - {"missing key", map[string]any{}, "x", 0}, - {"float64 value", map[string]any{"x": float64(42)}, "x", 42}, - {"int value", map[string]any{"x": 42}, "x", 42}, - {"string value", map[string]any{"x": "42"}, "x", 0}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := getIntFromMetadata(tt.metadata, tt.key) - if got != tt.want { - t.Errorf("getIntFromMetadata() = %d, want %d", got, tt.want) - } - }) - } -} - -// Helper function -func contains(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsSubstring(s, substr)) -} - -func containsSubstring(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} diff --git a/internal/mcp/types.go b/internal/mcp/types.go index 83dddd3..8fe23ca 100644 --- a/internal/mcp/types.go +++ b/internal/mcp/types.go @@ -206,7 +206,7 @@ func (p *TaskToolParams) UnmarshalJSON(data []byte) error { // === MCP Tool Parameters (non-unified) === -// ProjectContextParams defines the parameters for the recall tool. +// ProjectContextParams defines the parameters for the ask tool. type ProjectContextParams struct { Query string `json:"query,omitempty"` Answer bool `json:"answer,omitempty"` // If true, generate RAG answer using LLM diff --git a/internal/mcp/types_test.go b/internal/mcp/types_test.go deleted file mode 100644 index 74bb4b4..0000000 --- a/internal/mcp/types_test.go +++ /dev/null @@ -1,364 +0,0 @@ -package mcp - -import ( - "encoding/json" - "strings" - "testing" -) - -func TestCodeAction_IsValid(t *testing.T) { - tests := []struct { - action CodeAction - want bool - }{ - {CodeActionFind, true}, - {CodeActionSearch, true}, - {CodeActionExplain, true}, - {CodeActionCallers, true}, - {CodeActionImpact, true}, - {CodeActionSimplify, true}, - {"invalid", false}, - {"", false}, - {"FIND", false}, // case-sensitive - } - - for _, tt := range tests { - t.Run(string(tt.action), func(t *testing.T) { - if got := tt.action.IsValid(); got != tt.want { - t.Errorf("CodeAction(%q).IsValid() = %v, want %v", tt.action, got, tt.want) - } - }) - } -} - -func TestTaskAction_IsValid(t *testing.T) { - tests := []struct { - action TaskAction - want bool - }{ - {TaskActionNext, true}, - {TaskActionCurrent, true}, - {TaskActionStart, true}, - {TaskActionComplete, true}, - {"invalid", false}, - {"", false}, - } - - for _, tt := range tests { - t.Run(string(tt.action), func(t *testing.T) { - if got := tt.action.IsValid(); got != tt.want { - t.Errorf("TaskAction(%q).IsValid() = %v, want %v", tt.action, got, tt.want) - } - }) - } -} - -func TestPlanAction_IsValid(t *testing.T) { - tests := []struct { - action PlanAction - want bool - }{ - {PlanActionClarify, true}, - {PlanActionDecompose, true}, - {PlanActionExpand, true}, - {PlanActionGenerate, true}, - {PlanActionFinalize, true}, - {PlanActionAudit, true}, - {"invalid", false}, - {"", false}, - } - - for _, tt := range tests { - t.Run(string(tt.action), func(t *testing.T) { - if got := tt.action.IsValid(); got != tt.want { - t.Errorf("PlanAction(%q).IsValid() = %v, want %v", tt.action, got, tt.want) - } - }) - } -} - -func TestValidCodeActions(t *testing.T) { - actions := ValidCodeActions() - if len(actions) != 6 { - t.Errorf("ValidCodeActions() returned %d actions, want 6", len(actions)) - } -} - -func TestValidTaskActions(t *testing.T) { - actions := ValidTaskActions() - if len(actions) != 4 { - t.Errorf("ValidTaskActions() returned %d actions, want 4", len(actions)) - } -} - -func TestValidPlanActions(t *testing.T) { - actions := ValidPlanActions() - if len(actions) != 6 { - t.Errorf("ValidPlanActions() returned %d actions, want 6", len(actions)) - } -} - -// === PlanID JSON Schema Tests === - -// TestTaskToolParams_PlanIDSnakeCase tests that plan_id is correctly unmarshaled. -func TestTaskToolParams_PlanIDSnakeCase(t *testing.T) { - jsonData := `{"action":"next","plan_id":"plan-123","session_id":"sess-456"}` - - var params TaskToolParams - if err := json.Unmarshal([]byte(jsonData), ¶ms); err != nil { - t.Fatalf("Failed to unmarshal: %v", err) - } - - if params.PlanID != "plan-123" { - t.Errorf("PlanID = %q, want %q", params.PlanID, "plan-123") - } -} - -// TestTaskToolParams_RejectLegacyPlanIDAlias tests that planId is rejected. -func TestTaskToolParams_RejectLegacyPlanIDAlias(t *testing.T) { - jsonData := `{"action":"next","planId":"plan-789","session_id":"sess-456"}` - - var params TaskToolParams - err := json.Unmarshal([]byte(jsonData), ¶ms) - if err == nil { - t.Fatal("expected unmarshal error for legacy planId") - } - if !strings.Contains(err.Error(), "planId") { - t.Fatalf("unexpected error: %v", err) - } -} - -// TestTaskToolParams_RejectWhenBothPlanIDFormsProvided ensures strict rejection when legacy key is present. -func TestTaskToolParams_RejectWhenBothPlanIDFormsProvided(t *testing.T) { - jsonData := `{"action":"next","plan_id":"plan-primary","planId":"plan-alias","session_id":"sess-456"}` - - var params TaskToolParams - err := json.Unmarshal([]byte(jsonData), ¶ms) - if err == nil { - t.Fatal("expected unmarshal error for legacy planId") - } - if !strings.Contains(err.Error(), "planId") { - t.Fatalf("unexpected error: %v", err) - } -} - -// TestPlanToolParams_PlanIDSnakeCase tests that plan_id is correctly unmarshaled. -func TestPlanToolParams_PlanIDSnakeCase(t *testing.T) { - jsonData := `{"action":"audit","plan_id":"plan-123"}` - - var params PlanToolParams - if err := json.Unmarshal([]byte(jsonData), ¶ms); err != nil { - t.Fatalf("Failed to unmarshal: %v", err) - } - - if params.PlanID != "plan-123" { - t.Errorf("PlanID = %q, want %q", params.PlanID, "plan-123") - } -} - -// TestPlanToolParams_RejectLegacyPlanIDAlias tests that planId is rejected. -func TestPlanToolParams_RejectLegacyPlanIDAlias(t *testing.T) { - jsonData := `{"action":"audit","planId":"plan-789"}` - - var params PlanToolParams - err := json.Unmarshal([]byte(jsonData), ¶ms) - if err == nil { - t.Fatal("expected unmarshal error for legacy planId") - } - if !strings.Contains(err.Error(), "planId") { - t.Fatalf("unexpected error: %v", err) - } -} - -// TestPlanToolParams_RejectWhenBothPlanIDFormsProvided ensures strict rejection when legacy key is present. -func TestPlanToolParams_RejectWhenBothPlanIDFormsProvided(t *testing.T) { - jsonData := `{"action":"audit","plan_id":"plan-primary","planId":"plan-alias"}` - - var params PlanToolParams - err := json.Unmarshal([]byte(jsonData), ¶ms) - if err == nil { - t.Fatal("expected unmarshal error for legacy planId") - } - if !strings.Contains(err.Error(), "planId") { - t.Fatalf("unexpected error: %v", err) - } -} - -// TestMCPPlanIDEmptyValues tests edge cases with empty values. -func TestMCPPlanIDEmptyValues(t *testing.T) { - tests := []struct { - name string - jsonData string - wantPlanID string - }{ - { - name: "empty plan_id", - jsonData: `{"action":"next","plan_id":"","session_id":"sess-1"}`, - wantPlanID: "", - }, - { - name: "null plan_id", - jsonData: `{"action":"next","plan_id":null,"session_id":"sess-1"}`, - wantPlanID: "", - }, - { - name: "both missing", - jsonData: `{"action":"next","session_id":"sess-1"}`, - wantPlanID: "", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - var params TaskToolParams - if err := json.Unmarshal([]byte(tc.jsonData), ¶ms); err != nil { - t.Fatalf("Failed to unmarshal: %v", err) - } - - if params.PlanID != tc.wantPlanID { - t.Errorf("PlanID = %q, want %q", params.PlanID, tc.wantPlanID) - } - }) - } -} - -func TestMCPPlanIDLegacyAliasRejected(t *testing.T) { - tests := []struct { - name string - jsonData string - }{ - { - name: "task tool rejects planId", - jsonData: `{"action":"next","planId":"plan-fallback","session_id":"sess-1"}`, - }, - { - name: "plan tool rejects planId", - jsonData: `{"action":"audit","planId":"plan-fallback"}`, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - if strings.Contains(tc.name, "task tool") { - var params TaskToolParams - err := json.Unmarshal([]byte(tc.jsonData), ¶ms) - if err == nil || !strings.Contains(err.Error(), "planId") { - t.Fatalf("expected planId rejection, got: %v", err) - } - return - } - - var params PlanToolParams - err := json.Unmarshal([]byte(tc.jsonData), ¶ms) - if err == nil || !strings.Contains(err.Error(), "planId") { - t.Fatalf("expected planId rejection, got: %v", err) - } - }) - } -} - -// TestMCPParamsPreserveOtherFields ensures that custom UnmarshalJSON preserves other fields. -func TestMCPParamsPreserveOtherFields(t *testing.T) { - t.Run("TaskToolParams", func(t *testing.T) { - jsonData := `{"action":"complete","task_id":"task-abc","session_id":"sess-xyz","summary":"Done","files_modified":["a.go","b.go"]}` - - var params TaskToolParams - if err := json.Unmarshal([]byte(jsonData), ¶ms); err != nil { - t.Fatalf("Failed to unmarshal: %v", err) - } - - if params.Action != TaskActionComplete { - t.Errorf("Action = %q, want %q", params.Action, TaskActionComplete) - } - if params.TaskID != "task-abc" { - t.Errorf("TaskID = %q, want %q", params.TaskID, "task-abc") - } - if params.SessionID != "sess-xyz" { - t.Errorf("SessionID = %q, want %q", params.SessionID, "sess-xyz") - } - if params.Summary != "Done" { - t.Errorf("Summary = %q, want %q", params.Summary, "Done") - } - if len(params.FilesModified) != 2 { - t.Errorf("FilesModified length = %d, want 2", len(params.FilesModified)) - } - }) - - t.Run("PlanToolParams", func(t *testing.T) { - jsonData := `{"action":"generate","goal":"Add auth","enriched_goal":"Full auth spec","auto_answer":true}` - - var params PlanToolParams - if err := json.Unmarshal([]byte(jsonData), ¶ms); err != nil { - t.Fatalf("Failed to unmarshal: %v", err) - } - - if params.Action != PlanActionGenerate { - t.Errorf("Action = %q, want %q", params.Action, PlanActionGenerate) - } - if params.Goal != "Add auth" { - t.Errorf("Goal = %q, want %q", params.Goal, "Add auth") - } - if params.EnrichedGoal != "Full auth spec" { - t.Errorf("EnrichedGoal = %q, want %q", params.EnrichedGoal, "Full auth spec") - } - if !params.AutoAnswer { - t.Errorf("AutoAnswer = %v, want true", params.AutoAnswer) - } - }) - -} - -func TestPlanToolParams_ClarifySessionAndAnswers(t *testing.T) { - jsonData := `{ - "action":"clarify", - "clarify_session_id":"clarify-123", - "answers":[ - {"question":"Target users?","answer":"Backend team"}, - {"answer":"Must support monorepo"} - ] - }` - - var params PlanToolParams - if err := json.Unmarshal([]byte(jsonData), ¶ms); err != nil { - t.Fatalf("failed to unmarshal: %v", err) - } - - if params.ClarifySessionID != "clarify-123" { - t.Fatalf("ClarifySessionID = %q, want %q", params.ClarifySessionID, "clarify-123") - } - if got := len(params.Answers); got != 2 { - t.Fatalf("answers len = %d, want 2", got) - } - if params.Answers[0].Question != "Target users?" || params.Answers[0].Answer != "Backend team" { - t.Fatalf("first answer mismatch: %+v", params.Answers[0]) - } - if params.Answers[1].Question != "" || params.Answers[1].Answer != "Must support monorepo" { - t.Fatalf("second answer mismatch: %+v", params.Answers[1]) - } -} - -func TestPlanToolParams_RejectsLegacyHistory(t *testing.T) { - jsonData := `{ - "action":"clarify", - "goal":"Refactor API", - "history":"Q: old? A: yes" - }` - - var params PlanToolParams - err := json.Unmarshal([]byte(jsonData), ¶ms) - if err == nil { - t.Fatal("expected unmarshal error for legacy history field") - } - if got := err.Error(); got == "" || !containsAll(got, []string{"history", "clarify_session_id", "answers"}) { - t.Fatalf("unexpected error: %v", err) - } -} - -func containsAll(s string, terms []string) bool { - for _, term := range terms { - if !strings.Contains(s, term) { - return false - } - } - return true -} diff --git a/internal/mcpcfg/naming_test.go b/internal/mcpcfg/naming_test.go deleted file mode 100644 index f24a2be..0000000 --- a/internal/mcpcfg/naming_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package mcpcfg - -import "testing" - -func TestServerNameClassification(t *testing.T) { - tests := []struct { - name string - serverName string - canonical bool - legacy bool - }{ - {name: "canonical", serverName: "taskwing-mcp", canonical: true, legacy: false}, - {name: "legacy bare", serverName: "taskwing", canonical: false, legacy: true}, - {name: "legacy suffixed", serverName: "taskwing-mcp-my-project", canonical: false, legacy: true}, - {name: "non taskwing", serverName: "other-mcp", canonical: false, legacy: false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := IsCanonicalServerName(tt.serverName); got != tt.canonical { - t.Fatalf("IsCanonicalServerName(%q) = %v, want %v", tt.serverName, got, tt.canonical) - } - if got := IsLegacyServerName(tt.serverName); got != tt.legacy { - t.Fatalf("IsLegacyServerName(%q) = %v, want %v", tt.serverName, got, tt.legacy) - } - }) - } -} - -func TestExtractTaskWingServerNames(t *testing.T) { - output := ` -taskwing-mcp: /usr/local/bin/taskwing mcp -taskwing-mcp-my-project: /usr/local/bin/taskwing mcp -other-mcp: /usr/local/bin/other mcp -` - - names := ExtractTaskWingServerNames(output) - if len(names) == 0 { - t.Fatal("expected extracted names") - } - - if !ContainsCanonicalServerName(output) { - t.Fatal("expected canonical server name detection") - } - if !ContainsLegacyServerName(output) { - t.Fatal("expected legacy server name detection") - } -} diff --git a/internal/memory/migration_test.go b/internal/memory/migration_test.go deleted file mode 100644 index 2253491..0000000 --- a/internal/memory/migration_test.go +++ /dev/null @@ -1,607 +0,0 @@ -package memory - -import ( - "os" - "testing" -) - -// TestUpdateNodeWorkspace_Success tests that workspace updates work correctly. -func TestUpdateNodeWorkspace_Success(t *testing.T) { - // Create a temporary directory for the test - tmpDir, err := os.MkdirTemp("", "taskwing-migration-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - // Initialize repository - repo, err := NewDefaultRepository(tmpDir) - if err != nil { - t.Fatalf("failed to create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - // Create a test node - testNode := &Node{ - ID: "test-node-001", - Content: "Test content for migration", - Type: NodeTypeDecision, - Summary: "Test decision", - Workspace: "root", - } - if err := repo.CreateNode(testNode); err != nil { - t.Fatalf("failed to create node: %v", err) - } - - // Update workspace - if err := repo.UpdateNodeWorkspace("test-node-001", "osprey"); err != nil { - t.Fatalf("UpdateNodeWorkspace failed: %v", err) - } - - // Verify update - updated, err := repo.GetNode("test-node-001") - if err != nil { - t.Fatalf("GetNode failed: %v", err) - } - - if updated.Workspace != "osprey" { - t.Errorf("workspace = %q, want %q", updated.Workspace, "osprey") - } -} - -// TestUpdateNodeWorkspace_NotFound tests error handling for non-existent nodes. -func TestUpdateNodeWorkspace_NotFound(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-migration-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - repo, err := NewDefaultRepository(tmpDir) - if err != nil { - t.Fatalf("failed to create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - // Try to update a non-existent node - err = repo.UpdateNodeWorkspace("nonexistent-node", "osprey") - if err == nil { - t.Fatal("expected error for non-existent node") - } -} - -// TestWorkspaceDefaultsToRoot tests that new nodes default to 'root' workspace. -func TestWorkspaceDefaultsToRoot(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-migration-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - repo, err := NewDefaultRepository(tmpDir) - if err != nil { - t.Fatalf("failed to create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - // Create a node without explicit workspace - testNode := &Node{ - ID: "test-node-002", - Content: "Test content", - Type: NodeTypeDecision, - Summary: "Test", - // Workspace not set - should default to "root" - } - if err := repo.CreateNode(testNode); err != nil { - t.Fatalf("failed to create node: %v", err) - } - - // Verify default - node, err := repo.GetNode("test-node-002") - if err != nil { - t.Fatalf("GetNode failed: %v", err) - } - - // Empty workspace should be treated as root by the application - // The DB stores empty string, but business logic treats it as "root" - if node.Workspace != "" && node.Workspace != "root" { - t.Errorf("workspace = %q, want empty or 'root'", node.Workspace) - } -} - -// TestListNodesFiltered_ByWorkspace tests workspace filtering in ListNodesFiltered. -func TestListNodesFiltered_ByWorkspace(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-migration-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - repo, err := NewDefaultRepository(tmpDir) - if err != nil { - t.Fatalf("failed to create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - // Create nodes in different workspaces - nodes := []Node{ - {ID: "node-root-1", Content: "Root content 1", Type: NodeTypeDecision, Summary: "Root 1", Workspace: "root"}, - {ID: "node-root-2", Content: "Root content 2", Type: NodeTypePattern, Summary: "Root 2", Workspace: "root"}, - {ID: "node-osprey-1", Content: "Osprey content", Type: NodeTypeDecision, Summary: "Osprey", Workspace: "osprey"}, - {ID: "node-studio-1", Content: "Studio content", Type: NodeTypeFeature, Summary: "Studio", Workspace: "studio"}, - } - - for _, n := range nodes { - node := n // capture - if err := repo.CreateNode(&node); err != nil { - t.Fatalf("failed to create node %s: %v", n.ID, err) - } - } - - // Test: Filter by workspace "osprey" - filter := NodeFilter{Workspace: "osprey"} - filtered, err := repo.ListNodesFiltered(filter) - if err != nil { - t.Fatalf("ListNodesFiltered failed: %v", err) - } - - // Currently placeholder returns all nodes - this test documents expected behavior - // When filtering is implemented, this should return only osprey nodes - if len(filtered) == 0 { - t.Error("ListNodesFiltered returned no nodes") - } - - // Test: Default filter returns all nodes - defaultFilter := DefaultNodeFilter() - all, err := repo.ListNodesFiltered(defaultFilter) - if err != nil { - t.Fatalf("ListNodesFiltered with default filter failed: %v", err) - } - - if len(all) != 4 { - t.Errorf("default filter returned %d nodes, want 4", len(all)) - } -} - -// TestMarkdownMirrorAfterWorkspaceUpdate tests that markdown mirror can be rebuilt. -func TestMarkdownMirrorAfterWorkspaceUpdate(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-migration-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - repo, err := NewDefaultRepository(tmpDir) - if err != nil { - t.Fatalf("failed to create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - // Create a feature node (features get markdown files) - testNode := &Node{ - ID: "test-feature-001", - Content: "Test feature content", - Type: NodeTypeFeature, - Summary: "Test Feature", - Workspace: "root", - } - if err := repo.CreateNode(testNode); err != nil { - t.Fatalf("failed to create node: %v", err) - } - - // Update workspace - if err := repo.UpdateNodeWorkspace("test-feature-001", "osprey"); err != nil { - t.Fatalf("UpdateNodeWorkspace failed: %v", err) - } - - // Verify node was updated - node, err := repo.GetNode("test-feature-001") - if err != nil { - t.Fatalf("GetNode failed: %v", err) - } - if node.Workspace != "osprey" { - t.Fatalf("expected workspace 'osprey', got %q", node.Workspace) - } -} - -// TestNodeFilter_DefaultValues tests that DefaultNodeFilter returns expected values. -func TestNodeFilter_DefaultValues(t *testing.T) { - filter := DefaultNodeFilter() - - if filter.Type != "" { - t.Errorf("Type = %q, want empty", filter.Type) - } - if filter.Workspace != "" { - t.Errorf("Workspace = %q, want empty", filter.Workspace) - } - if !filter.IncludeRoot { - t.Error("IncludeRoot = false, want true") - } -} - -// TestSearchFTSFiltered_Workspace tests workspace filtering for full-text search. -func TestSearchFTSFiltered_Workspace(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-fts-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - repo, err := NewDefaultRepository(tmpDir) - if err != nil { - t.Fatalf("failed to create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - // Create nodes in different workspaces with searchable content - nodes := []Node{ - {ID: "n-root-auth", Content: "Authentication system using JWT tokens", Type: NodeTypeDecision, Summary: "JWT Auth", Workspace: "root"}, - {ID: "n-osprey-auth", Content: "Authentication middleware for Osprey service", Type: NodeTypeDecision, Summary: "Osprey Auth", Workspace: "osprey"}, - {ID: "n-studio-auth", Content: "Authentication flow for Studio app", Type: NodeTypePattern, Summary: "Studio Auth", Workspace: "studio"}, - {ID: "n-root-db", Content: "Database connection pooling strategy", Type: NodeTypeDecision, Summary: "DB Pool", Workspace: "root"}, - } - - for _, n := range nodes { - node := n - if err := repo.CreateNode(&node); err != nil { - t.Fatalf("failed to create node %s: %v", n.ID, err) - } - } - - tests := []struct { - name string - query string - filter NodeFilter - wantMinimum int // At least this many results expected - wantIDs []string - notWantIDs []string - }{ - { - name: "no filter returns all auth nodes", - query: "authentication", - filter: NodeFilter{}, - wantMinimum: 3, - }, - { - name: "osprey workspace only", - query: "authentication", - filter: NodeFilter{Workspace: "osprey", IncludeRoot: false}, - wantIDs: []string{"n-osprey-auth"}, - notWantIDs: []string{"n-root-auth", "n-studio-auth"}, - }, - { - name: "osprey workspace with root", - query: "authentication", - filter: NodeFilter{Workspace: "osprey", IncludeRoot: true}, - wantMinimum: 2, - wantIDs: []string{"n-osprey-auth", "n-root-auth"}, - notWantIDs: []string{"n-studio-auth"}, - }, - { - name: "root workspace only", - query: "authentication", - filter: NodeFilter{Workspace: "root", IncludeRoot: false}, - wantIDs: []string{"n-root-auth"}, - notWantIDs: []string{"n-osprey-auth", "n-studio-auth"}, - }, - { - name: "nonexistent workspace returns empty", - query: "authentication", - filter: NodeFilter{Workspace: "nonexistent", IncludeRoot: false}, - wantMinimum: 0, - }, - { - name: "nonexistent workspace with root returns root only", - query: "authentication", - filter: NodeFilter{Workspace: "nonexistent", IncludeRoot: true}, - wantIDs: []string{"n-root-auth"}, - notWantIDs: []string{"n-osprey-auth", "n-studio-auth"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - results, err := repo.SearchFTSFiltered(tt.query, 10, tt.filter) - if err != nil { - t.Fatalf("SearchFTSFiltered failed: %v", err) - } - - // Check minimum count - if tt.wantMinimum > 0 && len(results) < tt.wantMinimum { - t.Errorf("got %d results, want at least %d", len(results), tt.wantMinimum) - } - - // Build ID set for checking - gotIDs := make(map[string]bool) - for _, r := range results { - gotIDs[r.Node.ID] = true - } - - // Check expected IDs - for _, wantID := range tt.wantIDs { - if !gotIDs[wantID] { - t.Errorf("expected result %s not found", wantID) - } - } - - // Check excluded IDs - for _, notWantID := range tt.notWantIDs { - if gotIDs[notWantID] { - t.Errorf("unexpected result %s found", notWantID) - } - } - }) - } -} - -// TestListNodesFiltered_WorkspaceWithType tests workspace + type filtering. -func TestListNodesFiltered_WorkspaceWithType(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-filter-type-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - repo, err := NewDefaultRepository(tmpDir) - if err != nil { - t.Fatalf("failed to create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - // Create nodes in different workspaces with different types - nodes := []Node{ - {ID: "n-root-dec", Content: "Root decision", Type: NodeTypeDecision, Summary: "Root Dec", Workspace: "root"}, - {ID: "n-root-pat", Content: "Root pattern", Type: NodeTypePattern, Summary: "Root Pat", Workspace: "root"}, - {ID: "n-osprey-dec", Content: "Osprey decision", Type: NodeTypeDecision, Summary: "Osprey Dec", Workspace: "osprey"}, - {ID: "n-osprey-feat", Content: "Osprey feature", Type: NodeTypeFeature, Summary: "Osprey Feat", Workspace: "osprey"}, - } - - for _, n := range nodes { - node := n - if err := repo.CreateNode(&node); err != nil { - t.Fatalf("failed to create node %s: %v", n.ID, err) - } - } - - tests := []struct { - name string - filter NodeFilter - wantCount int - wantIDs []string - }{ - { - name: "osprey decisions only", - filter: NodeFilter{Workspace: "osprey", Type: NodeTypeDecision, IncludeRoot: false}, - wantCount: 1, - wantIDs: []string{"n-osprey-dec"}, - }, - { - name: "osprey decisions with root", - filter: NodeFilter{Workspace: "osprey", Type: NodeTypeDecision, IncludeRoot: true}, - wantCount: 2, - wantIDs: []string{"n-osprey-dec", "n-root-dec"}, - }, - { - name: "root patterns only", - filter: NodeFilter{Workspace: "root", Type: NodeTypePattern, IncludeRoot: false}, - wantCount: 1, - wantIDs: []string{"n-root-pat"}, - }, - { - name: "empty workspace returns all of type", - filter: NodeFilter{Type: NodeTypeDecision}, - wantCount: 2, - wantIDs: []string{"n-root-dec", "n-osprey-dec"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - nodes, err := repo.ListNodesFiltered(tt.filter) - if err != nil { - t.Fatalf("ListNodesFiltered failed: %v", err) - } - - if len(nodes) != tt.wantCount { - t.Errorf("got %d nodes, want %d", len(nodes), tt.wantCount) - } - - gotIDs := make(map[string]bool) - for _, n := range nodes { - gotIDs[n.ID] = true - } - - for _, wantID := range tt.wantIDs { - if !gotIDs[wantID] { - t.Errorf("expected node %s not found", wantID) - } - } - }) - } -} - -// TestListNodesFiltered_IncludeRootBehavior tests the IncludeRoot flag in detail. -func TestListNodesFiltered_IncludeRootBehavior(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-include-root-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - repo, err := NewDefaultRepository(tmpDir) - if err != nil { - t.Fatalf("failed to create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - // Create a mix of workspaces including empty (legacy nodes) - nodes := []Node{ - {ID: "n-explicit-root", Content: "Explicit root", Type: NodeTypeDecision, Summary: "Explicit", Workspace: "root"}, - {ID: "n-empty-ws", Content: "Empty workspace", Type: NodeTypeDecision, Summary: "Empty", Workspace: ""}, - {ID: "n-osprey", Content: "Osprey node", Type: NodeTypeDecision, Summary: "Osprey", Workspace: "osprey"}, - } - - for _, n := range nodes { - node := n - if err := repo.CreateNode(&node); err != nil { - t.Fatalf("failed to create node %s: %v", n.ID, err) - } - } - - // IncludeRoot=true should include both "root" and "" (empty) workspaces - t.Run("include root gets explicit root and empty", func(t *testing.T) { - filter := NodeFilter{Workspace: "osprey", IncludeRoot: true} - nodes, err := repo.ListNodesFiltered(filter) - if err != nil { - t.Fatalf("ListNodesFiltered failed: %v", err) - } - - // Should have osprey + root + empty = 3 nodes - if len(nodes) != 3 { - t.Errorf("got %d nodes, want 3 (osprey + root + empty)", len(nodes)) - } - - gotIDs := make(map[string]bool) - for _, n := range nodes { - gotIDs[n.ID] = true - } - - if !gotIDs["n-explicit-root"] { - t.Error("missing n-explicit-root") - } - if !gotIDs["n-empty-ws"] { - t.Error("missing n-empty-ws (empty workspace should be treated as root)") - } - if !gotIDs["n-osprey"] { - t.Error("missing n-osprey") - } - }) - - // IncludeRoot=false should only get osprey - t.Run("exclude root gets only specified workspace", func(t *testing.T) { - filter := NodeFilter{Workspace: "osprey", IncludeRoot: false} - nodes, err := repo.ListNodesFiltered(filter) - if err != nil { - t.Fatalf("ListNodesFiltered failed: %v", err) - } - - if len(nodes) != 1 { - t.Errorf("got %d nodes, want 1 (only osprey)", len(nodes)) - } - - if len(nodes) > 0 && nodes[0].ID != "n-osprey" { - t.Errorf("expected n-osprey, got %s", nodes[0].ID) - } - }) -} - -// TestListNodesWithEmbeddingsFiltered_Workspace tests workspace filtering for nodes with embeddings. -func TestListNodesWithEmbeddingsFiltered_Workspace(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-embeddings-filter-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - repo, err := NewDefaultRepository(tmpDir) - if err != nil { - t.Fatalf("failed to create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - // Create nodes with embeddings in different workspaces - // Note: We need to manually set embeddings since CreateNode doesn't generate them - embedding := make([]float32, 4) // Small test embedding - for i := range embedding { - embedding[i] = float32(i) * 0.1 - } - - nodes := []struct { - node Node - embedding []float32 - }{ - { - node: Node{ID: "n-root-emb", Content: "Root with embedding", Type: NodeTypeDecision, Summary: "Root Emb", Workspace: "root"}, - embedding: embedding, - }, - { - node: Node{ID: "n-osprey-emb", Content: "Osprey with embedding", Type: NodeTypeDecision, Summary: "Osprey Emb", Workspace: "osprey"}, - embedding: embedding, - }, - { - node: Node{ID: "n-studio-emb", Content: "Studio with embedding", Type: NodeTypePattern, Summary: "Studio Emb", Workspace: "studio"}, - embedding: embedding, - }, - { - node: Node{ID: "n-no-emb", Content: "No embedding", Type: NodeTypeDecision, Summary: "No Emb", Workspace: "osprey"}, - embedding: nil, // No embedding - }, - } - - for _, n := range nodes { - node := n.node - node.Embedding = n.embedding - if err := repo.CreateNode(&node); err != nil { - t.Fatalf("failed to create node %s: %v", n.node.ID, err) - } - } - - tests := []struct { - name string - filter NodeFilter - wantCount int - wantIDs []string - notWantIDs []string - }{ - { - name: "no filter returns all with embeddings", - filter: NodeFilter{}, - wantCount: 3, - notWantIDs: []string{"n-no-emb"}, - }, - { - name: "osprey workspace only", - filter: NodeFilter{Workspace: "osprey", IncludeRoot: false}, - wantCount: 1, - wantIDs: []string{"n-osprey-emb"}, - notWantIDs: []string{"n-root-emb", "n-studio-emb", "n-no-emb"}, - }, - { - name: "osprey workspace with root", - filter: NodeFilter{Workspace: "osprey", IncludeRoot: true}, - wantCount: 2, - wantIDs: []string{"n-osprey-emb", "n-root-emb"}, - notWantIDs: []string{"n-studio-emb", "n-no-emb"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - nodes, err := repo.ListNodesWithEmbeddingsFiltered(tt.filter) - if err != nil { - t.Fatalf("ListNodesWithEmbeddingsFiltered failed: %v", err) - } - - if len(nodes) != tt.wantCount { - t.Errorf("got %d nodes, want %d", len(nodes), tt.wantCount) - } - - gotIDs := make(map[string]bool) - for _, n := range nodes { - gotIDs[n.ID] = true - } - - for _, wantID := range tt.wantIDs { - if !gotIDs[wantID] { - t.Errorf("expected node %s not found", wantID) - } - } - - for _, notWantID := range tt.notWantIDs { - if gotIDs[notWantID] { - t.Errorf("unexpected node %s found", notWantID) - } - } - }) - } -} diff --git a/internal/memory/models.go b/internal/memory/models.go index 87c507b..0ba7c16 100644 --- a/internal/memory/models.go +++ b/internal/memory/models.go @@ -42,7 +42,7 @@ type Node struct { // Debt Classification fields (v2.2+) // Distinguishes essential complexity from accidental complexity (technical debt). - // When AI recalls context, high-debt patterns include warnings to prevent propagation. + // When AI retrieves context, high-debt patterns include warnings to prevent propagation. // DebtScore indicates how much this represents technical debt (0.0 = clean, 1.0 = pure debt) DebtScore float64 `json:"debtScore,omitempty"` diff --git a/internal/memory/models_test.go b/internal/memory/models_test.go deleted file mode 100644 index 1b7b318..0000000 --- a/internal/memory/models_test.go +++ /dev/null @@ -1,159 +0,0 @@ -package memory - -import ( - "encoding/json" - "testing" -) - -func TestNode_Text_PlainText(t *testing.T) { - n := Node{Content: "some plain text content"} - if got := n.Text(); got != "some plain text content" { - t.Errorf("Text() = %q, want %q", got, "some plain text content") - } -} - -func TestNode_Text_EmptyContent(t *testing.T) { - n := Node{Content: ""} - if got := n.Text(); got != "" { - t.Errorf("Text() = %q, want empty string", got) - } -} - -func TestNode_Text_StructuredContent(t *testing.T) { - sc := StructuredContent{ - Title: "SQLite as primary store", - Description: "Chose SQLite for local-first persistence", - Why: "Embedded, zero-config, good enough perf", - Tradeoffs: "No concurrent writes", - } - content, _ := json.Marshal(sc) - n := Node{Content: string(content)} - - got := n.Text() - want := "SQLite as primary store\nChose SQLite for local-first persistence\n\nWhy: Embedded, zero-config, good enough perf\nTradeoffs: No concurrent writes" - if got != want { - t.Errorf("Text() =\n%q\nwant\n%q", got, want) - } -} - -func TestNode_Text_StructuredWithSnippets(t *testing.T) { - sc := StructuredContent{ - Title: "Repository pattern", - Description: "All data access goes through Repository interface", - Snippets: []EvidenceSnippet{ - {FilePath: "internal/memory/store.go", Lines: "10-25", Code: "type Repository interface{...}"}, - }, - } - content, _ := json.Marshal(sc) - n := Node{Content: string(content)} - - got := n.Text() - if got == "" { - t.Fatal("Text() returned empty string for structured content with snippets") - } - // Should contain the evidence section - if !contains(got, "Evidence:") { - t.Error("Text() missing Evidence section") - } - if !contains(got, "internal/memory/store.go:10-25") { - t.Error("Text() missing file:lines reference") - } -} - -func TestNode_Text_InvalidJSON(t *testing.T) { - n := Node{Content: "{invalid json}"} - if got := n.Text(); got != "{invalid json}" { - t.Errorf("Text() = %q, want passthrough for invalid JSON", got) - } -} - -func TestNode_Text_JSONWithoutTitle(t *testing.T) { - // Valid JSON but not a StructuredContent (no title) - n := Node{Content: `{"description":"no title here"}`} - if got := n.Text(); got != `{"description":"no title here"}` { - t.Errorf("Text() should return as-is when Title is empty, got %q", got) - } -} - -func TestNode_ParseStructuredContent_Roundtrip(t *testing.T) { - original := StructuredContent{ - Title: "Test title", - Description: "Test desc", - Why: "Test why", - Tradeoffs: "Test tradeoffs", - Snippets: []EvidenceSnippet{ - {FilePath: "foo.go", Lines: "1-10", Code: "func Foo() {}"}, - }, - } - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("Marshal: %v", err) - } - - n := Node{Content: string(data)} - sc := n.ParseStructuredContent() - if sc == nil { - t.Fatal("ParseStructuredContent() returned nil") - } - if sc.Title != original.Title { - t.Errorf("Title = %q, want %q", sc.Title, original.Title) - } - if sc.Description != original.Description { - t.Errorf("Description = %q, want %q", sc.Description, original.Description) - } - if sc.Why != original.Why { - t.Errorf("Why = %q, want %q", sc.Why, original.Why) - } - if sc.Tradeoffs != original.Tradeoffs { - t.Errorf("Tradeoffs = %q, want %q", sc.Tradeoffs, original.Tradeoffs) - } - if len(sc.Snippets) != 1 { - t.Fatalf("Snippets len = %d, want 1", len(sc.Snippets)) - } - if sc.Snippets[0].Code != "func Foo() {}" { - t.Errorf("Snippet code = %q", sc.Snippets[0].Code) - } -} - -func TestNode_ParseStructuredContent_PlainText(t *testing.T) { - n := Node{Content: "just some text"} - if sc := n.ParseStructuredContent(); sc != nil { - t.Errorf("ParseStructuredContent() = %+v, want nil for plain text", sc) - } -} - -func TestNode_ParseStructuredContent_Empty(t *testing.T) { - n := Node{Content: ""} - if sc := n.ParseStructuredContent(); sc != nil { - t.Errorf("ParseStructuredContent() = %+v, want nil for empty", sc) - } -} - -func TestNode_Text_MinimalStructured(t *testing.T) { - // Only title + description, no optional fields - sc := StructuredContent{ - Title: "Minimal", - Description: "Just the basics", - } - content, _ := json.Marshal(sc) - n := Node{Content: string(content)} - - got := n.Text() - want := "Minimal\nJust the basics" - if got != want { - t.Errorf("Text() = %q, want %q", got, want) - } -} - -func contains(s, substr string) bool { - return len(s) >= len(substr) && searchString(s, substr) -} - -func searchString(s, sub string) bool { - for i := 0; i <= len(s)-len(sub); i++ { - if s[i:i+len(sub)] == sub { - return true - } - } - return false -} diff --git a/internal/memory/rows_err_test.go b/internal/memory/rows_err_test.go deleted file mode 100644 index 61d7e83..0000000 --- a/internal/memory/rows_err_test.go +++ /dev/null @@ -1,753 +0,0 @@ -package memory - -import ( - "database/sql" - "os" - "strings" - "testing" - - "github.com/josephgoksu/TaskWing/internal/task" -) - -// TestCheckRowsErr_NilError tests that checkRowsErr returns nil when rows.Err() is nil. -func TestCheckRowsErr_NilError(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-rows-err-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - store, err := NewSQLiteStore(tmpDir) - if err != nil { - t.Fatalf("failed to create store: %v", err) - } - defer func() { _ = store.Close() }() - - // Execute a simple query and iterate fully - rows, err := store.db.Query("SELECT 1") - if err != nil { - t.Fatalf("query failed: %v", err) - } - defer func() { _ = rows.Close() }() - - for rows.Next() { - var v int - if err := rows.Scan(&v); err != nil { - t.Fatalf("scan failed: %v", err) - } - } - - // After successful iteration, rows.Err() should be nil - err = checkRowsErr(rows) - if err != nil { - t.Errorf("checkRowsErr returned error for successful iteration: %v", err) - } -} - -// TestListPlans_ErrorPropagation tests that errors are propagated from ListPlans. -// This tests that a closed database connection causes proper error propagation. -func TestListPlans_ErrorPropagation(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-rows-err-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - store, err := NewSQLiteStore(tmpDir) - if err != nil { - t.Fatalf("failed to create store: %v", err) - } - - // Create some test data - plan := &task.Plan{ - ID: "plan-test-001", - Goal: "Test plan for error propagation", - } - if err := store.CreatePlan(plan); err != nil { - t.Fatalf("failed to create plan: %v", err) - } - - // Close the database to force errors on subsequent queries - if err := store.Close(); err != nil { - t.Fatalf("failed to close store: %v", err) - } - - // Now listing plans should return an error (database closed) - _, err = store.ListPlans() - if err == nil { - t.Error("expected error when listing plans on closed database, got nil") - } -} - -// TestListTasks_ErrorPropagation tests that errors are propagated from ListTasks. -func TestListTasks_ErrorPropagation(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-rows-err-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - store, err := NewSQLiteStore(tmpDir) - if err != nil { - t.Fatalf("failed to create store: %v", err) - } - - // Create a test plan first - plan := &task.Plan{ - ID: "plan-test-002", - Goal: "Test plan", - } - if err := store.CreatePlan(plan); err != nil { - t.Fatalf("failed to create plan: %v", err) - } - - // Create a test task - testTask := &task.Task{ - ID: "task-test-001", - PlanID: "plan-test-002", - Title: "Test task", - } - if err := store.CreateTask(testTask); err != nil { - t.Fatalf("failed to create task: %v", err) - } - - // Close the database - if err := store.Close(); err != nil { - t.Fatalf("failed to close store: %v", err) - } - - // Now listing tasks should return an error - _, err = store.ListTasks("plan-test-002") - if err == nil { - t.Error("expected error when listing tasks on closed database, got nil") - } -} - -// TestListNodes_ErrorPropagation tests that errors are propagated from ListNodes. -func TestListNodes_ErrorPropagation(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-rows-err-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - store, err := NewSQLiteStore(tmpDir) - if err != nil { - t.Fatalf("failed to create store: %v", err) - } - - // Create a test node - node := &Node{ - ID: "node-test-001", - Content: "Test content", - Type: NodeTypeDecision, - Summary: "Test node", - } - if err := store.CreateNode(node); err != nil { - t.Fatalf("failed to create node: %v", err) - } - - // Close the database - if err := store.Close(); err != nil { - t.Fatalf("failed to close store: %v", err) - } - - // Now listing nodes should return an error - _, err = store.ListNodes("") - if err == nil { - t.Error("expected error when listing nodes on closed database, got nil") - } -} - -// TestSearchPlans_ErrorPropagation tests that errors are propagated from SearchPlans. -func TestSearchPlans_ErrorPropagation(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-rows-err-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - store, err := NewSQLiteStore(tmpDir) - if err != nil { - t.Fatalf("failed to create store: %v", err) - } - - // Close the database - if err := store.Close(); err != nil { - t.Fatalf("failed to close store: %v", err) - } - - // Now searching plans should return an error - _, err = store.SearchPlans("test", "") - if err == nil { - t.Error("expected error when searching plans on closed database, got nil") - } -} - -// TestGetNodeEdges_ErrorPropagation tests that errors are propagated from GetNodeEdges. -func TestGetNodeEdges_ErrorPropagation(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-rows-err-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - store, err := NewSQLiteStore(tmpDir) - if err != nil { - t.Fatalf("failed to create store: %v", err) - } - - // Close the database - if err := store.Close(); err != nil { - t.Fatalf("failed to close store: %v", err) - } - - // Now getting node edges should return an error - _, err = store.GetNodeEdges("node-test-001") - if err == nil { - t.Error("expected error when getting node edges on closed database, got nil") - } -} - -// TestRowsErrPropagation_TableDriven uses table-driven tests for multiple functions. -func TestRowsErrPropagation_TableDriven(t *testing.T) { - tests := []struct { - name string - testFunc func(store *SQLiteStore) error - }{ - { - name: "ListPlans", - testFunc: func(store *SQLiteStore) error { - _, err := store.ListPlans() - return err - }, - }, - { - name: "ListNodes", - testFunc: func(store *SQLiteStore) error { - _, err := store.ListNodes("") - return err - }, - }, - { - name: "GetAllNodeEdges", - testFunc: func(store *SQLiteStore) error { - _, err := store.GetAllNodeEdges() - return err - }, - }, - { - name: "ListBootstrapStates", - testFunc: func(store *SQLiteStore) error { - _, err := store.ListBootstrapStates() - return err - }, - }, - { - name: "ListToolVersions", - testFunc: func(store *SQLiteStore) error { - _, err := store.ListToolVersions() - return err - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-rows-err-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - store, err := NewSQLiteStore(tmpDir) - if err != nil { - t.Fatalf("failed to create store: %v", err) - } - - // Close the database to force errors - if err := store.Close(); err != nil { - t.Fatalf("failed to close store: %v", err) - } - - // The function should return an error on closed database - err = tc.testFunc(store) - if err == nil { - t.Errorf("%s should return error on closed database, got nil", tc.name) - } - }) - } -} - -// TestRowsErrPropagation_SuccessPath verifies no errors on successful iteration. -func TestRowsErrPropagation_SuccessPath(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-rows-err-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - store, err := NewSQLiteStore(tmpDir) - if err != nil { - t.Fatalf("failed to create store: %v", err) - } - defer func() { _ = store.Close() }() - - // Create test data - plan := &task.Plan{ - ID: "plan-success-001", - Goal: "Test successful iteration", - } - if err := store.CreatePlan(plan); err != nil { - t.Fatalf("failed to create plan: %v", err) - } - - testTask := &task.Task{ - ID: "task-success-001", - PlanID: "plan-success-001", - Title: "Test task", - } - if err := store.CreateTask(testTask); err != nil { - t.Fatalf("failed to create task: %v", err) - } - - node := &Node{ - ID: "node-success-001", - Content: "Test content", - Type: NodeTypeDecision, - Summary: "Test node", - } - if err := store.CreateNode(node); err != nil { - t.Fatalf("failed to create node: %v", err) - } - - // All these operations should succeed without errors - t.Run("ListPlans", func(t *testing.T) { - plans, err := store.ListPlans() - if err != nil { - t.Errorf("ListPlans failed: %v", err) - } - if len(plans) == 0 { - t.Error("expected at least one plan") - } - }) - - t.Run("ListTasks", func(t *testing.T) { - tasks, err := store.ListTasks("plan-success-001") - if err != nil { - t.Errorf("ListTasks failed: %v", err) - } - if len(tasks) == 0 { - t.Error("expected at least one task") - } - }) - - t.Run("ListNodes", func(t *testing.T) { - nodes, err := store.ListNodes("") - if err != nil { - t.Errorf("ListNodes failed: %v", err) - } - if len(nodes) == 0 { - t.Error("expected at least one node") - } - }) - - t.Run("Check", func(t *testing.T) { - _, err := store.Check() - if err != nil { - t.Errorf("Check failed: %v", err) - } - }) -} - -// TestCheckRowsErr_ReturnsWrappedError tests that checkRowsErr wraps errors properly. -func TestCheckRowsErr_ReturnsWrappedError(t *testing.T) { - // We can't easily trigger a real rows.Err() in SQLite without network issues, - // but we can verify the function signature and behavior with a mock rows. - // This test verifies the helper function exists and returns nil for successful rows. - - tmpDir, err := os.MkdirTemp("", "taskwing-rows-err-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - store, err := NewSQLiteStore(tmpDir) - if err != nil { - t.Fatalf("failed to create store: %v", err) - } - defer func() { _ = store.Close() }() - - // Create and fully iterate rows - rows, err := store.db.Query("SELECT 1 UNION SELECT 2 UNION SELECT 3") - if err != nil { - t.Fatalf("query failed: %v", err) - } - - count := 0 - for rows.Next() { - var v int - if err := rows.Scan(&v); err != nil { - t.Fatalf("scan failed: %v", err) - } - count++ - } - _ = rows.Close() - - if count != 3 { - t.Errorf("expected 3 rows, got %d", count) - } - - // After full iteration and close, Err() should be nil - // Note: calling Err() after Close() is implementation-dependent but safe - if rows.Err() != nil { - t.Errorf("unexpected error after successful iteration: %v", rows.Err()) - } -} - -// TestErrorMessageContainsContext verifies error messages include context. -func TestErrorMessageContainsContext(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-rows-err-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - store, err := NewSQLiteStore(tmpDir) - if err != nil { - t.Fatalf("failed to create store: %v", err) - } - - // Close the database to force errors - _ = store.Close() - - // Test that error messages contain meaningful context - tests := []struct { - name string - testFunc func() error - contains string - }{ - { - name: "ListPlans", - testFunc: func() error { - _, err := store.ListPlans() - return err - }, - contains: "plan", // Should mention "plan" somewhere in error - }, - { - name: "ListNodes", - testFunc: func() error { - _, err := store.ListNodes("") - return err - }, - contains: "node", // Should mention "node" somewhere in error - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - err := tc.testFunc() - if err == nil { - t.Fatal("expected error, got nil") - } - errMsg := strings.ToLower(err.Error()) - if !strings.Contains(errMsg, tc.contains) { - t.Errorf("error message %q should contain %q", err.Error(), tc.contains) - } - }) - } -} - -// TestMockRowsErr demonstrates how rows.Err() works conceptually. -// In a real scenario with network databases, rows.Err() would capture -// errors that occur during iteration (like connection drops). -func TestMockRowsErr(t *testing.T) { - // This test documents the expected behavior of rows.Err(): - // - Returns nil if iteration completed successfully - // - Returns error if iteration was interrupted (e.g., network failure) - - tmpDir, err := os.MkdirTemp("", "taskwing-rows-err-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - store, err := NewSQLiteStore(tmpDir) - if err != nil { - t.Fatalf("failed to create store: %v", err) - } - defer func() { _ = store.Close() }() - - // Test 1: Successful iteration should have nil Err() - rows, err := store.db.Query("SELECT 1") - if err != nil { - t.Fatalf("query failed: %v", err) - } - - for rows.Next() { - var v int - if err := rows.Scan(&v); err != nil { - t.Fatalf("scan failed: %v", err) - } - } - - // Before Close, Err() should be nil for successful iteration - if err := checkRowsErr(rows); err != nil { - t.Errorf("expected nil error after successful iteration, got: %v", err) - } - _ = rows.Close() - - // Test 2: Query with no results should also have nil Err() - rows2, err := store.db.Query("SELECT 1 WHERE 1=0") // returns no rows - if err != nil { - t.Fatalf("query failed: %v", err) - } - - count := 0 - for rows2.Next() { - count++ - } - - if err := checkRowsErr(rows2); err != nil { - t.Errorf("expected nil error for empty result set, got: %v", err) - } - _ = rows2.Close() - - if count != 0 { - t.Errorf("expected 0 rows, got %d", count) - } -} - -// TestCheckRowsErrHelper_FunctionExists verifies the helper function works. -func TestCheckRowsErrHelper_FunctionExists(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-rows-err-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - store, err := NewSQLiteStore(tmpDir) - if err != nil { - t.Fatalf("failed to create store: %v", err) - } - defer func() { _ = store.Close() }() - - // Create rows and verify checkRowsErr exists and works - rows, err := store.db.Query("SELECT 1") - if err != nil { - t.Fatalf("query failed: %v", err) - } - defer func() { _ = rows.Close() }() - - // Iterate - for rows.Next() { - var v int - if err := rows.Scan(&v); err != nil { - t.Fatalf("scan failed: %v", err) - } - } - - // Call checkRowsErr - should return nil - if err := checkRowsErr(rows); err != nil { - t.Errorf("checkRowsErr failed: %v", err) - } -} - -// Verify that sql package is imported and types are correct -var _ *sql.Rows // Ensures sql package is correctly imported - -// === Prefix Finder Tests === - -// TestFindTaskIDsByPrefix tests the task ID prefix finder. -func TestFindTaskIDsByPrefix(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-prefix-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - store, err := NewSQLiteStore(tmpDir) - if err != nil { - t.Fatalf("failed to create store: %v", err) - } - defer func() { _ = store.Close() }() - - // Create a test plan - plan := &task.Plan{ID: "plan-prefix001", Goal: "Test plan"} - if err := store.CreatePlan(plan); err != nil { - t.Fatalf("failed to create plan: %v", err) - } - - // Create test tasks with various prefixes - tasks := []*task.Task{ - {ID: "task-abc11111", PlanID: "plan-prefix001", Title: "Task A1"}, - {ID: "task-abc22222", PlanID: "plan-prefix001", Title: "Task A2"}, - {ID: "task-abc33333", PlanID: "plan-prefix001", Title: "Task A3"}, - {ID: "task-xyz11111", PlanID: "plan-prefix001", Title: "Task X1"}, - } - for _, tsk := range tasks { - if err := store.CreateTask(tsk); err != nil { - t.Fatalf("failed to create task: %v", err) - } - } - - tests := []struct { - name string - prefix string - wantCount int - wantIDs []string - }{ - { - name: "full ID match", - prefix: "task-abc11111", - wantCount: 1, - wantIDs: []string{"task-abc11111"}, - }, - { - name: "prefix matches multiple", - prefix: "task-abc", - wantCount: 3, - }, - { - name: "prefix matches one", - prefix: "task-xyz", - wantCount: 1, - wantIDs: []string{"task-xyz11111"}, - }, - { - name: "prefix matches none", - prefix: "task-zzz", - wantCount: 0, - }, - { - name: "empty prefix matches all", - prefix: "", - wantCount: 4, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ids, err := store.FindTaskIDsByPrefix(tc.prefix) - if err != nil { - t.Fatalf("FindTaskIDsByPrefix failed: %v", err) - } - if len(ids) != tc.wantCount { - t.Errorf("got %d IDs, want %d", len(ids), tc.wantCount) - } - if tc.wantIDs != nil { - for i, wantID := range tc.wantIDs { - if i >= len(ids) || ids[i] != wantID { - t.Errorf("ID[%d] = %q, want %q", i, ids[i], wantID) - } - } - } - }) - } -} - -// TestFindPlanIDsByPrefix tests the plan ID prefix finder. -func TestFindPlanIDsByPrefix(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-prefix-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - store, err := NewSQLiteStore(tmpDir) - if err != nil { - t.Fatalf("failed to create store: %v", err) - } - defer func() { _ = store.Close() }() - - // Create test plans with various prefixes - plans := []*task.Plan{ - {ID: "plan-abc11111", Goal: "Plan A1"}, - {ID: "plan-abc22222", Goal: "Plan A2"}, - {ID: "plan-xyz11111", Goal: "Plan X1"}, - } - for _, p := range plans { - if err := store.CreatePlan(p); err != nil { - t.Fatalf("failed to create plan: %v", err) - } - } - - tests := []struct { - name string - prefix string - wantCount int - wantIDs []string - }{ - { - name: "full ID match", - prefix: "plan-abc11111", - wantCount: 1, - wantIDs: []string{"plan-abc11111"}, - }, - { - name: "prefix matches multiple", - prefix: "plan-abc", - wantCount: 2, - }, - { - name: "prefix matches one", - prefix: "plan-xyz", - wantCount: 1, - wantIDs: []string{"plan-xyz11111"}, - }, - { - name: "prefix matches none", - prefix: "plan-zzz", - wantCount: 0, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ids, err := store.FindPlanIDsByPrefix(tc.prefix) - if err != nil { - t.Fatalf("FindPlanIDsByPrefix failed: %v", err) - } - if len(ids) != tc.wantCount { - t.Errorf("got %d IDs, want %d", len(ids), tc.wantCount) - } - if tc.wantIDs != nil { - for i, wantID := range tc.wantIDs { - if i >= len(ids) || ids[i] != wantID { - t.Errorf("ID[%d] = %q, want %q", i, ids[i], wantID) - } - } - } - }) - } -} - -// TestFindPrefixMethods_ClosedDB tests error propagation for prefix finders. -func TestFindPrefixMethods_ClosedDB(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "taskwing-prefix-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - store, err := NewSQLiteStore(tmpDir) - if err != nil { - t.Fatalf("failed to create store: %v", err) - } - _ = store.Close() // Close immediately - - t.Run("FindTaskIDsByPrefix", func(t *testing.T) { - _, err := store.FindTaskIDsByPrefix("task-") - if err == nil { - t.Error("expected error on closed database, got nil") - } - }) - - t.Run("FindPlanIDsByPrefix", func(t *testing.T) { - _, err := store.FindPlanIDsByPrefix("plan-") - if err == nil { - t.Error("expected error on closed database, got nil") - } - }) -} diff --git a/internal/memory/sqlite.go b/internal/memory/sqlite.go index 674ed1a..7bed8db 100644 --- a/internal/memory/sqlite.go +++ b/internal/memory/sqlite.go @@ -593,7 +593,7 @@ func (s *SQLiteStore) initSchema() error { }{ {"scope", "ALTER TABLE tasks ADD COLUMN scope TEXT"}, // e.g., "auth", "api", "vectorsearch" {"keywords", "ALTER TABLE tasks ADD COLUMN keywords TEXT"}, // JSON array of extracted keywords - {"suggested_recall_queries", "ALTER TABLE tasks ADD COLUMN suggested_recall_queries TEXT"}, // JSON array of pre-computed queries + {"suggested_ask_queries", "ALTER TABLE tasks ADD COLUMN suggested_ask_queries TEXT"}, // JSON array of pre-computed ask queries {"claimed_by", "ALTER TABLE tasks ADD COLUMN claimed_by TEXT"}, // Session ID that claimed this task {"claimed_at", "ALTER TABLE tasks ADD COLUMN claimed_at TEXT"}, // Timestamp when claimed {"completed_at", "ALTER TABLE tasks ADD COLUMN completed_at TEXT"}, // Timestamp when completed @@ -634,6 +634,10 @@ func (s *SQLiteStore) initSchema() error { } } + // Migration: Rename legacy column suggested_recall_queries → suggested_ask_queries. + // SQLite supports RENAME COLUMN since 3.25.0 (2018). Silently ignore if column doesn't exist. + _, _ = s.db.Exec(`ALTER TABLE tasks RENAME COLUMN suggested_recall_queries TO suggested_ask_queries`) + // Ensure index ordering matches task urgency semantics (lower number = higher urgency). // We drop/recreate to correct existing DBs that were created with DESC. _, _ = s.db.Exec(`DROP INDEX IF EXISTS idx_tasks_status_priority`) diff --git a/internal/memory/task_store.go b/internal/memory/task_store.go index 8dde3a8..9d87775 100644 --- a/internal/memory/task_store.go +++ b/internal/memory/task_store.go @@ -56,9 +56,9 @@ func insertTaskTx(tx txExecutor, t *task.Task) error { if err != nil { return fmt.Errorf("marshal keywords for task %s: %w", t.ID, err) } - queriesJSON, err := json.Marshal(t.SuggestedRecallQueries) + queriesJSON, err := json.Marshal(t.SuggestedAskQueries) if err != nil { - return fmt.Errorf("marshal suggested_recall_queries for task %s: %w", t.ID, err) + return fmt.Errorf("marshal suggested ask queries for task %s: %w", t.ID, err) } filesJSON, err := json.Marshal(t.FilesModified) if err != nil { @@ -84,7 +84,7 @@ func insertTaskTx(tx txExecutor, t *task.Task) error { id, plan_id, phase_id, title, description, acceptance_criteria, validation_steps, status, priority, complexity, assigned_agent, parent_task_id, context_summary, - scope, keywords, suggested_recall_queries, + scope, keywords, suggested_ask_queries, claimed_by, claimed_at, completed_at, completion_summary, files_modified, expected_files, created_at, updated_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) @@ -657,7 +657,7 @@ func scanTaskRow(row taskRowScanner) (task.Task, error) { _ = json.Unmarshal([]byte(keywordsJSON.String), &t.Keywords) } if queriesJSON.Valid && queriesJSON.String != "" { - _ = json.Unmarshal([]byte(queriesJSON.String), &t.SuggestedRecallQueries) + _ = json.Unmarshal([]byte(queriesJSON.String), &t.SuggestedAskQueries) } if filesJSON.Valid && filesJSON.String != "" { _ = json.Unmarshal([]byte(filesJSON.String), &t.FilesModified) @@ -674,7 +674,7 @@ func scanTaskRow(row taskRowScanner) (task.Task, error) { const taskSelectColumns = `id, plan_id, phase_id, title, description, acceptance_criteria, validation_steps, status, priority, complexity, assigned_agent, parent_task_id, context_summary, - scope, keywords, suggested_recall_queries, + scope, keywords, suggested_ask_queries, claimed_by, claimed_at, completed_at, completion_summary, files_modified, expected_files, git_baseline, created_at, updated_at` diff --git a/internal/memory/task_store_test.go b/internal/memory/task_store_test.go deleted file mode 100644 index 652632f..0000000 --- a/internal/memory/task_store_test.go +++ /dev/null @@ -1,224 +0,0 @@ -package memory - -import ( - "os" - "path/filepath" - "testing" - - "github.com/josephgoksu/TaskWing/internal/task" -) - -func TestListPlans_TaskCountNotPlaceholderSlice(t *testing.T) { - // Create a temporary database - tmpDir, err := os.MkdirTemp("", "taskwing-test-*") - if err != nil { - t.Fatalf("create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - dbPath := filepath.Join(tmpDir, "memory.db") - store, err := NewSQLiteStore(dbPath) - if err != nil { - t.Fatalf("create store: %v", err) - } - defer func() { _ = store.Close() }() - - // Create a plan - plan := &task.Plan{ - ID: "plan-test-123", - Goal: "Test goal", - EnrichedGoal: "Enriched test goal", - Status: task.PlanStatusActive, - } - if err := store.CreatePlan(plan); err != nil { - t.Fatalf("create plan: %v", err) - } - - // Create some tasks for this plan - tasks := []task.Task{ - {ID: "task-1", PlanID: plan.ID, Title: "Task 1", Description: "Desc 1", Status: task.StatusPending, Priority: 80}, - {ID: "task-2", PlanID: plan.ID, Title: "Task 2", Description: "Desc 2", Status: task.StatusCompleted, Priority: 70}, - {ID: "task-3", PlanID: plan.ID, Title: "Task 3", Description: "Desc 3", Status: task.StatusInProgress, Priority: 60}, - } - for _, tsk := range tasks { - if err := store.CreateTask(&tsk); err != nil { - t.Fatalf("create task %s: %v", tsk.ID, err) - } - } - - // List plans - plans, err := store.ListPlans() - if err != nil { - t.Fatalf("list plans: %v", err) - } - - if len(plans) != 1 { - t.Fatalf("expected 1 plan, got %d", len(plans)) - } - - p := plans[0] - - // Verify TaskCount is set correctly - if p.TaskCount != 3 { - t.Errorf("expected TaskCount=3, got %d", p.TaskCount) - } - - // Verify Tasks slice is nil (not placeholder slice) - if p.Tasks != nil { - t.Errorf("expected Tasks to be nil, got slice of length %d", len(p.Tasks)) - } - - // Verify GetTaskCount() returns correct value - if p.GetTaskCount() != 3 { - t.Errorf("expected GetTaskCount()=3, got %d", p.GetTaskCount()) - } - - // Verify iterating over Tasks doesn't yield misleading zero-value structs - for i, tsk := range p.Tasks { - // This loop should not execute since Tasks is nil - t.Errorf("unexpected task at index %d: %+v", i, tsk) - } -} - -func TestGetTaskCount_FallsBackToTasksLength(t *testing.T) { - // When TaskCount is 0 but Tasks are populated, use len(Tasks) - plan := &task.Plan{ - ID: "test-plan", - TaskCount: 0, // Not set - Tasks: []task.Task{ - {ID: "t1", Title: "Task 1"}, - {ID: "t2", Title: "Task 2"}, - }, - } - - if plan.GetTaskCount() != 2 { - t.Errorf("expected GetTaskCount()=2 (from len(Tasks)), got %d", plan.GetTaskCount()) - } -} - -func TestGetTaskCount_UsesTaskCountIfSet(t *testing.T) { - // When TaskCount is set, use it regardless of Tasks - plan := &task.Plan{ - ID: "test-plan", - TaskCount: 5, - Tasks: nil, // Not loaded - } - - if plan.GetTaskCount() != 5 { - t.Errorf("expected GetTaskCount()=5 (from TaskCount), got %d", plan.GetTaskCount()) - } -} - -func TestGetNextTask_SelectsLowestNumericPriority(t *testing.T) { - tmpDir := t.TempDir() - dbPath := filepath.Join(tmpDir, "memory.db") - - store, err := NewSQLiteStore(dbPath) - if err != nil { - t.Fatalf("create store: %v", err) - } - defer func() { _ = store.Close() }() - - plan := &task.Plan{ - ID: "plan-priority-next", - Goal: "Priority ordering test", - EnrichedGoal: "Priority ordering test", - Status: task.PlanStatusActive, - } - if err := store.CreatePlan(plan); err != nil { - t.Fatalf("create plan: %v", err) - } - - highUrgency := &task.Task{ - ID: "task-pri-10", - PlanID: plan.ID, - Title: "Priority 10", - Description: "Should be selected first", - Status: task.StatusPending, - Priority: 10, - } - lowUrgency := &task.Task{ - ID: "task-pri-90", - PlanID: plan.ID, - Title: "Priority 90", - Description: "Should be selected later", - Status: task.StatusPending, - Priority: 90, - } - - if err := store.CreateTask(lowUrgency); err != nil { - t.Fatalf("create task priority 90: %v", err) - } - if err := store.CreateTask(highUrgency); err != nil { - t.Fatalf("create task priority 10: %v", err) - } - - next, err := store.GetNextTask(plan.ID) - if err != nil { - t.Fatalf("get next task: %v", err) - } - if next == nil { - t.Fatal("expected next task, got nil") - } - if next.ID != highUrgency.ID { - t.Fatalf("expected next task %q, got %q", highUrgency.ID, next.ID) - } -} - -func TestListTasksByPhase_OrdersByAscendingPriority(t *testing.T) { - tmpDir := t.TempDir() - dbPath := filepath.Join(tmpDir, "memory.db") - - store, err := NewSQLiteStore(dbPath) - if err != nil { - t.Fatalf("create store: %v", err) - } - defer func() { _ = store.Close() }() - - plan := &task.Plan{ - ID: "plan-phase-order", - Goal: "Phase order test", - EnrichedGoal: "Phase order test", - Status: task.PlanStatusActive, - } - if err := store.CreatePlan(plan); err != nil { - t.Fatalf("create plan: %v", err) - } - - phase := &task.Phase{ - ID: "phase-1", - PlanID: plan.ID, - Title: "Phase 1", - Status: task.PhaseStatusExpanded, - OrderIndex: 0, - } - if err := store.CreatePhase(phase); err != nil { - t.Fatalf("create phase: %v", err) - } - - tasks := []*task.Task{ - {ID: "task-p90", PlanID: plan.ID, PhaseID: phase.ID, Title: "P90", Description: "P90", Status: task.StatusPending, Priority: 90}, - {ID: "task-p10", PlanID: plan.ID, PhaseID: phase.ID, Title: "P10", Description: "P10", Status: task.StatusPending, Priority: 10}, - {ID: "task-p50", PlanID: plan.ID, PhaseID: phase.ID, Title: "P50", Description: "P50", Status: task.StatusPending, Priority: 50}, - } - for _, tt := range tasks { - if err := store.CreateTask(tt); err != nil { - t.Fatalf("create task %s: %v", tt.ID, err) - } - } - - phaseTasks, err := store.ListTasksByPhase(phase.ID) - if err != nil { - t.Fatalf("list tasks by phase: %v", err) - } - if len(phaseTasks) != 3 { - t.Fatalf("expected 3 tasks, got %d", len(phaseTasks)) - } - - wantOrder := []string{"task-p10", "task-p50", "task-p90"} - for i, wantID := range wantOrder { - if phaseTasks[i].ID != wantID { - t.Fatalf("phase task order mismatch at index %d: got %s, want %s", i, phaseTasks[i].ID, wantID) - } - } -} diff --git a/internal/planner/generator_test.go b/internal/planner/generator_test.go deleted file mode 100644 index f0893eb..0000000 --- a/internal/planner/generator_test.go +++ /dev/null @@ -1,168 +0,0 @@ -package planner - -import ( - "strings" - "testing" -) - -func TestRetryLogic_FormatErrorFeedback(t *testing.T) { - feedback := formatErrorFeedback("JSON Parse Error", "unexpected end of JSON", `{"incomplete":`) - - if !strings.Contains(feedback, "JSON Parse Error") { - t.Error("Expected feedback to contain error type") - } - if !strings.Contains(feedback, "unexpected end of JSON") { - t.Error("Expected feedback to contain error message") - } - if !strings.Contains(feedback, `{"incomplete":`) { - t.Error("Expected feedback to contain raw output") - } -} - -func TestRetryLogic_FormatErrorFeedback_Truncation(t *testing.T) { - // Create a long string that should be truncated - longOutput := strings.Repeat("a", 600) - feedback := formatErrorFeedback("Test Error", "test message", longOutput) - - if !strings.Contains(feedback, "[truncated]") { - t.Error("Expected long output to be truncated") - } - // Should not contain the full 600 character string - if strings.Contains(feedback, longOutput) { - t.Error("Expected output to be truncated, but found full string") - } -} - -func TestRetryLogic_FormatValidationFeedback(t *testing.T) { - result := ValidationResult{ - Valid: false, - Errors: []ValidationError{ - {Field: "Title", Tag: "required", Message: "Title is required"}, - {Field: "Priority", Tag: "priority_range", Value: 150, Message: "Priority must be between 0 and 100"}, - }, - } - - feedback := formatValidationFeedback(result) - - if !strings.Contains(feedback, "SCHEMA VALIDATION ERRORS") { - t.Error("Expected feedback to contain validation errors header") - } - if !strings.Contains(feedback, "Title") { - t.Error("Expected feedback to mention Title field") - } - if !strings.Contains(feedback, "Priority") { - t.Error("Expected feedback to mention Priority field") - } - if !strings.Contains(feedback, "150") { - t.Error("Expected feedback to show current value") - } -} - -func TestRetryLogic_IsTransientError(t *testing.T) { - tests := []struct { - name string - errMsg string - expected bool - }{ - {"nil error", "", false}, - {"rate limit", "rate limit exceeded", true}, - {"http 429", "HTTP 429 Too Many Requests", true}, - {"quota exceeded", "API quota exceeded for today", true}, - {"timeout", "context deadline exceeded: timeout", true}, - {"connection reset", "connection reset by peer", true}, - {"validation error", "validation failed: Title is required", false}, - {"json parse error", "json: cannot unmarshal", false}, - {"generic error", "something went wrong", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var err error - if tt.errMsg != "" { - err = &testError{msg: tt.errMsg} - } - result := isTransientError(err) - if result != tt.expected { - t.Errorf("isTransientError(%q) = %v, want %v", tt.errMsg, result, tt.expected) - } - }) - } -} - -func TestRetryLogic_CopyMap(t *testing.T) { - original := map[string]any{ - "Goal": "test goal", - "Context": "test context", - } - - copied := copyMap(original) - - // Verify copy has same values - if copied["Goal"] != "test goal" { - t.Error("Expected Goal to be copied") - } - if copied["Context"] != "test context" { - t.Error("Expected Context to be copied") - } - - // Verify modification doesn't affect original - copied["NewKey"] = "new value" - if _, exists := original["NewKey"]; exists { - t.Error("Expected copy to be independent of original") - } -} - -func TestRetryLogic_PromptTemplates(t *testing.T) { - // Verify plan template has required placeholders - if !strings.Contains(planPromptTemplate, "{{.Goal}}") { - t.Error("Plan template missing Goal placeholder") - } - if !strings.Contains(planPromptTemplate, "{{.Context}}") { - t.Error("Plan template missing Context placeholder") - } - if !strings.Contains(planPromptTemplate, "{{.ValidationErrors}}") { - t.Error("Plan template missing ValidationErrors placeholder") - } - - // Verify clarification template has required placeholders - if !strings.Contains(clarificationPromptTemplate, "{{.Goal}}") { - t.Error("Clarification template missing Goal placeholder") - } - if !strings.Contains(clarificationPromptTemplate, "{{.History}}") { - t.Error("Clarification template missing History placeholder") - } - if !strings.Contains(clarificationPromptTemplate, "{{.ValidationErrors}}") { - t.Error("Clarification template missing ValidationErrors placeholder") - } -} - -func TestRetryLogic_GeneratorConfig_DefaultTemperature(t *testing.T) { - gen := NewGenerator(GeneratorConfig{}) - - // Temperature should be 0 (deterministic) by default - if gen.cfg.Temperature != 0.0 { - t.Errorf("Expected default temperature 0.0, got %f", gen.cfg.Temperature) - } -} - -func TestRetryLogic_Constants(t *testing.T) { - // Verify retry constants are sensible - if MaxGenerationRetries < 1 { - t.Error("MaxGenerationRetries should be at least 1") - } - if MaxGenerationRetries > 10 { - t.Error("MaxGenerationRetries should not be excessive") - } - if RetryDelay < 100*1000000 { // 100ms in nanoseconds - t.Error("RetryDelay should be at least 100ms") - } -} - -// testError is a simple error implementation for testing -type testError struct { - msg string -} - -func (e *testError) Error() string { - return e.msg -} diff --git a/internal/planner/middleware_test.go b/internal/planner/middleware_test.go deleted file mode 100644 index c60ca17..0000000 --- a/internal/planner/middleware_test.go +++ /dev/null @@ -1,426 +0,0 @@ -package planner - -import ( - "os" - "path/filepath" - "testing" -) - -func TestSemanticMiddleware_ValidPlan(t *testing.T) { - // Create a valid plan with no file references - plan := &LLMPlanResponse{ - GoalSummary: "Test plan", - Rationale: "This is a test plan with valid tasks", - EstimatedComplexity: "low", - Tasks: []LLMTaskSchema{ - { - Title: "Implement feature", - Description: "Add new feature to the system", - Priority: 50, - Complexity: "low", - AssignedAgent: "coder", - AcceptanceCriteria: []string{"Feature works correctly"}, - ValidationSteps: []string{"echo 'test'"}, - }, - }, - } - - middleware := NewSemanticMiddleware(MiddlewareConfig{ - SkipFileValidation: true, // No file paths in this test - }) - - result := middleware.Validate(plan) - - if !result.Valid { - t.Errorf("Expected valid plan, got errors: %s", result.ErrorSummary()) - } -} - -func TestSemanticMiddleware_MissingFile(t *testing.T) { - plan := &LLMPlanResponse{ - GoalSummary: "Test plan", - Rationale: "Plan references non-existent file", - EstimatedComplexity: "low", - Tasks: []LLMTaskSchema{ - { - Title: "Fix bug in handler", - Description: "Update the handler in /nonexistent/path/handler.go", - Priority: 50, - Complexity: "low", - AssignedAgent: "coder", - AcceptanceCriteria: []string{"Bug is fixed"}, - ValidationSteps: []string{"go test ./..."}, - }, - }, - } - - middleware := NewSemanticMiddleware(MiddlewareConfig{}) - - result := middleware.Validate(plan) - - if result.Valid { - t.Error("Expected validation to fail for missing file") - } - - if len(result.Errors) == 0 { - t.Error("Expected at least one error") - } - - foundMissingFile := false - for _, e := range result.Errors { - if e.Type == "missing_file" { - foundMissingFile = true - break - } - } - if !foundMissingFile { - t.Error("Expected missing_file error type") - } -} - -func TestSemanticMiddleware_MissingFileAsWarning(t *testing.T) { - plan := &LLMPlanResponse{ - GoalSummary: "Test plan", - Rationale: "Plan references non-existent file", - EstimatedComplexity: "low", - Tasks: []LLMTaskSchema{ - { - Title: "Fix bug in handler", - Description: "Update the handler in /nonexistent/path/handler.go", - Priority: 50, - Complexity: "low", - AssignedAgent: "coder", - AcceptanceCriteria: []string{"Bug is fixed"}, - ValidationSteps: []string{"echo done"}, - }, - }, - } - - middleware := NewSemanticMiddleware(MiddlewareConfig{ - AllowMissingFiles: true, - }) - - result := middleware.Validate(plan) - - if !result.Valid { - t.Error("Expected valid when AllowMissingFiles is true") - } - - if len(result.Warnings) == 0 { - t.Error("Expected at least one warning") - } -} - -func TestSemanticMiddleware_ExistingFile(t *testing.T) { - // Create a temp file with a known structure - tmpDir := t.TempDir() - subDir := filepath.Join(tmpDir, "internal") - if err := os.MkdirAll(subDir, 0755); err != nil { - t.Fatal(err) - } - testFile := filepath.Join(subDir, "test.go") - if err := os.WriteFile(testFile, []byte("package test"), 0644); err != nil { - t.Fatal(err) - } - - plan := &LLMPlanResponse{ - GoalSummary: "Test plan", - Rationale: "Plan references existing file", - EstimatedComplexity: "low", - Tasks: []LLMTaskSchema{ - { - Title: "Update test file", - Description: "Modify internal/test.go for the fix", - Priority: 50, - Complexity: "low", - AssignedAgent: "coder", - AcceptanceCriteria: []string{"File is updated"}, - ValidationSteps: []string{"echo done"}, - }, - }, - } - - middleware := NewSemanticMiddleware(MiddlewareConfig{ - BasePath: tmpDir, - }) - - result := middleware.Validate(plan) - - if !result.Valid { - t.Errorf("Expected valid plan for existing file, got: %s", result.ErrorSummary()) - } -} - -func TestSemanticMiddleware_InvalidCommand(t *testing.T) { - plan := &LLMPlanResponse{ - GoalSummary: "Test plan", - Rationale: "Plan with invalid shell command", - EstimatedComplexity: "low", - Tasks: []LLMTaskSchema{ - { - Title: "Run tests", - Description: "Execute test suite", - Priority: 50, - Complexity: "low", - AssignedAgent: "qa", - AcceptanceCriteria: []string{"Tests pass"}, - ValidationSteps: []string{"echo 'unclosed quote"}, // Invalid syntax - }, - }, - } - - middleware := NewSemanticMiddleware(MiddlewareConfig{ - SkipFileValidation: true, - }) - - result := middleware.Validate(plan) - - if result.Valid { - t.Error("Expected validation to fail for invalid command") - } - - foundInvalidCommand := false - for _, e := range result.Errors { - if e.Type == "invalid_command" { - foundInvalidCommand = true - break - } - } - if !foundInvalidCommand { - t.Error("Expected invalid_command error type") - } -} - -func TestSemanticMiddleware_ValidCommand(t *testing.T) { - plan := &LLMPlanResponse{ - GoalSummary: "Test plan", - Rationale: "Plan with valid shell commands", - EstimatedComplexity: "low", - Tasks: []LLMTaskSchema{ - { - Title: "Run tests", - Description: "Execute test suite", - Priority: 50, - Complexity: "low", - AssignedAgent: "qa", - AcceptanceCriteria: []string{"Tests pass"}, - ValidationSteps: []string{ - "go test ./...", - "echo 'done'", - "ls -la && pwd", - }, - }, - }, - } - - middleware := NewSemanticMiddleware(MiddlewareConfig{ - SkipFileValidation: true, - }) - - result := middleware.Validate(plan) - - if !result.Valid { - t.Errorf("Expected valid plan, got: %s", result.ErrorSummary()) - } - - if result.Stats.CommandsValidated != 3 { - t.Errorf("Expected 3 commands validated, got %d", result.Stats.CommandsValidated) - } -} - -func TestSemanticMiddleware_CreationContext(t *testing.T) { - plan := &LLMPlanResponse{ - GoalSummary: "Test plan", - Rationale: "Plan that creates a new file", - EstimatedComplexity: "low", - Tasks: []LLMTaskSchema{ - { - Title: "Create new handler", - Description: "Create a new file internal/handlers/new_handler.go", - Priority: 50, - Complexity: "low", - AssignedAgent: "coder", - AcceptanceCriteria: []string{"File is created"}, - ValidationSteps: []string{"echo done"}, - }, - }, - } - - middleware := NewSemanticMiddleware(MiddlewareConfig{}) - - result := middleware.Validate(plan) - - // Should be valid because the file is mentioned in a "create" context - if !result.Valid { - t.Errorf("Expected valid plan for creation context, got: %s", result.ErrorSummary()) - } -} - -func TestSemanticMiddleware_Stats(t *testing.T) { - plan := &LLMPlanResponse{ - GoalSummary: "Test plan", - Rationale: "Multi-task plan", - EstimatedComplexity: "medium", - Tasks: []LLMTaskSchema{ - { - Title: "Task 1", - Description: "First task", - Priority: 50, - Complexity: "low", - AssignedAgent: "coder", - AcceptanceCriteria: []string{"Done"}, - ValidationSteps: []string{"echo 1"}, - }, - { - Title: "Task 2", - Description: "Second task", - Priority: 60, - Complexity: "low", - AssignedAgent: "coder", - AcceptanceCriteria: []string{"Done"}, - ValidationSteps: []string{"echo 2", "echo 3"}, - }, - }, - } - - middleware := NewSemanticMiddleware(MiddlewareConfig{ - SkipFileValidation: true, - }) - - result := middleware.Validate(plan) - - if result.Stats.TotalTasks != 2 { - t.Errorf("Expected 2 total tasks, got %d", result.Stats.TotalTasks) - } - - if result.Stats.CommandsValidated != 3 { - t.Errorf("Expected 3 commands validated, got %d", result.Stats.CommandsValidated) - } -} - -func TestSemanticMiddleware_ExtractFilePaths(t *testing.T) { - tests := []struct { - name string - text string - expected int - }{ - {"absolute path", "Update /path/to/file.go", 1}, - {"relative path", "Check internal/handler.go", 1}, - {"quoted path", "Edit `config/settings.yaml`", 1}, - {"multiple paths", "Update file.go and test.ts", 2}, - {"no paths", "Just some text without files", 0}, - {"url should not match", "Visit http://example.com/test.html", 0}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - paths := extractFilePaths(tt.text) - if len(paths) != tt.expected { - t.Errorf("extractFilePaths(%q) returned %d paths, want %d: %v", - tt.text, len(paths), tt.expected, paths) - } - }) - } -} - -func TestSemanticMiddleware_IsLikelyFilePath(t *testing.T) { - tests := []struct { - path string - expected bool - }{ - {"handler.go", true}, - {"config.yaml", true}, - {"test.ts", true}, - {"noext", false}, - {"http://example.com", false}, - {"file.xyz", false}, // Unknown extension - {"component.tsx", true}, - } - - for _, tt := range tests { - t.Run(tt.path, func(t *testing.T) { - result := isLikelyFilePath(tt.path) - if result != tt.expected { - t.Errorf("isLikelyFilePath(%q) = %v, want %v", tt.path, result, tt.expected) - } - }) - } -} - -func TestSemanticMiddleware_ErrorSummary(t *testing.T) { - result := SemanticValidationResult{ - Valid: false, - Errors: []SemanticError{ - {TaskIndex: 0, TaskTitle: "Task 1", Type: "missing_file", Message: "File not found"}, - {TaskIndex: 1, TaskTitle: "Task 2", Type: "invalid_command", Message: "Syntax error"}, - }, - } - - summary := result.ErrorSummary() - - if summary == "" { - t.Error("Expected non-empty error summary") - } - if !contains(summary, "Task 1") { - t.Error("Expected summary to contain Task 1") - } - if !contains(summary, "Task 2") { - t.Error("Expected summary to contain Task 2") - } -} - -func TestSemanticMiddleware_WarningSummary(t *testing.T) { - result := SemanticValidationResult{ - Valid: true, - Warnings: []SemanticWarning{ - {TaskIndex: 0, TaskTitle: "Task 1", Type: "missing_file", Message: "Optional file not found"}, - }, - } - - summary := result.WarningSummary() - - if summary == "" { - t.Error("Expected non-empty warning summary") - } - if !contains(summary, "Task 1") { - t.Error("Expected summary to contain Task 1") - } -} - -func TestSemanticMiddleware_SkipAllValidation(t *testing.T) { - plan := &LLMPlanResponse{ - GoalSummary: "Test plan", - Rationale: "Plan with issues that should be skipped", - EstimatedComplexity: "low", - Tasks: []LLMTaskSchema{ - { - Title: "Task", - Description: "Reference /nonexistent/file.go", - Priority: 50, - Complexity: "low", - AssignedAgent: "coder", - AcceptanceCriteria: []string{"Done"}, - ValidationSteps: []string{"invalid ( syntax"}, - }, - }, - } - - middleware := NewSemanticMiddleware(MiddlewareConfig{ - SkipFileValidation: true, - SkipCommandValidation: true, - }) - - result := middleware.Validate(plan) - - if !result.Valid { - t.Error("Expected valid when all validation is skipped") - } - - if result.Stats.PathsChecked != 0 { - t.Error("Expected no paths checked when file validation is skipped") - } - - if result.Stats.CommandsValidated != 0 { - t.Error("Expected no commands validated when command validation is skipped") - } -} diff --git a/internal/planner/schema_test.go b/internal/planner/schema_test.go deleted file mode 100644 index 89d4fd9..0000000 --- a/internal/planner/schema_test.go +++ /dev/null @@ -1,279 +0,0 @@ -package planner - -import ( - "testing" -) - -func TestPlanSchema_Valid(t *testing.T) { - plan := LLMPlanResponse{ - GoalSummary: "Implement user authentication with JWT", - Rationale: "JWT provides stateless authentication that scales well with our microservices architecture", - EstimatedComplexity: "medium", - Tasks: []LLMTaskSchema{ - { - Title: "Create JWT token service", - Description: "Implement a service to generate and validate JWT tokens with refresh token support", - Priority: 80, - Complexity: "medium", - AssignedAgent: "coder", - AcceptanceCriteria: []string{"Token generation works", "Token validation works"}, - ValidationSteps: []string{"go test ./internal/auth/..."}, - }, - }, - } - - result := plan.Validate() - if !result.Valid { - t.Errorf("Expected valid plan, got errors: %s", result.ErrorSummary()) - } -} - -func TestPlanSchema_MissingGoalSummary(t *testing.T) { - plan := LLMPlanResponse{ - GoalSummary: "", // Empty - should fail - Rationale: "This is a valid rationale with enough characters", - EstimatedComplexity: "low", - Tasks: []LLMTaskSchema{ - { - Title: "Task 1", - Description: "Valid description", - Priority: 50, - Complexity: "low", - AssignedAgent: "coder", - AcceptanceCriteria: []string{"Done"}, - }, - }, - } - - result := plan.Validate() - if result.Valid { - t.Error("Expected validation to fail for empty GoalSummary") - } - - found := false - for _, e := range result.Errors { - if e.Field == "GoalSummary" { - found = true - break - } - } - if !found { - t.Error("Expected error for GoalSummary field") - } -} - -func TestPlanSchema_MissingTasks(t *testing.T) { - plan := LLMPlanResponse{ - GoalSummary: "Valid goal summary", - Rationale: "This is a valid rationale with enough characters", - EstimatedComplexity: "low", - Tasks: []LLMTaskSchema{}, // Empty - should fail - } - - result := plan.Validate() - if result.Valid { - t.Error("Expected validation to fail for empty Tasks") - } - - found := false - for _, e := range result.Errors { - if e.Field == "Tasks" && e.Tag == "min" { - found = true - break - } - } - if !found { - t.Error("Expected min error for Tasks field") - } -} - -func TestPlanSchema_InvalidComplexity(t *testing.T) { - plan := LLMPlanResponse{ - GoalSummary: "Valid goal summary", - Rationale: "This is a valid rationale with enough characters", - EstimatedComplexity: "very_high", // Invalid - should fail - Tasks: []LLMTaskSchema{ - { - Title: "Task 1", - Description: "Valid description here", - Priority: 50, - Complexity: "low", - AssignedAgent: "coder", - AcceptanceCriteria: []string{"Done"}, - }, - }, - } - - result := plan.Validate() - if result.Valid { - t.Error("Expected validation to fail for invalid EstimatedComplexity") - } -} - -func TestTaskSchema_Valid(t *testing.T) { - task := LLMTaskSchema{ - Title: "Implement login endpoint", - Description: "Create a POST /api/auth/login endpoint that validates credentials", - Priority: 75, - Complexity: "medium", - AssignedAgent: "coder", - AcceptanceCriteria: []string{"Endpoint returns JWT on success", "Returns 401 on invalid credentials"}, - ValidationSteps: []string{"curl -X POST http://localhost:8080/api/auth/login"}, - Scope: "auth", - Keywords: []string{"login", "jwt", "auth"}, - } - - result := task.Validate() - if !result.Valid { - t.Errorf("Expected valid task, got errors: %s", result.ErrorSummary()) - } -} - -func TestTaskSchema_InvalidPriority(t *testing.T) { - task := LLMTaskSchema{ - Title: "Test task", - Description: "Valid description here", - Priority: 150, // Invalid - should be 0-100 - Complexity: "low", - AssignedAgent: "coder", - AcceptanceCriteria: []string{"Done"}, - } - - result := task.Validate() - if result.Valid { - t.Error("Expected validation to fail for priority > 100") - } -} - -func TestTaskSchema_NegativePriority(t *testing.T) { - task := LLMTaskSchema{ - Title: "Test task", - Description: "Valid description here", - Priority: -10, // Invalid - should be >= 0 - Complexity: "low", - AssignedAgent: "coder", - AcceptanceCriteria: []string{"Done"}, - } - - result := task.Validate() - if result.Valid { - t.Error("Expected validation to fail for negative priority") - } -} - -func TestTaskSchema_InvalidAgent(t *testing.T) { - task := LLMTaskSchema{ - Title: "Test task", - Description: "Valid description here", - Priority: 50, - Complexity: "low", - AssignedAgent: "designer", // Invalid - not in allowed list - AcceptanceCriteria: []string{"Done"}, - } - - result := task.Validate() - if result.Valid { - t.Error("Expected validation to fail for invalid AssignedAgent") - } -} - -func TestTaskSchema_MissingAcceptanceCriteria(t *testing.T) { - task := LLMTaskSchema{ - Title: "Test task", - Description: "Valid description here", - Priority: 50, - Complexity: "low", - AssignedAgent: "coder", - AcceptanceCriteria: []string{}, // Empty - should fail - } - - result := task.Validate() - if result.Valid { - t.Error("Expected validation to fail for empty AcceptanceCriteria") - } -} - -func TestTaskSchema_WhitespaceOnlyTitle(t *testing.T) { - task := LLMTaskSchema{ - Title: " ", // Whitespace only - should fail - Description: "Valid description here", - Priority: 50, - Complexity: "low", - AssignedAgent: "coder", - AcceptanceCriteria: []string{"Done"}, - } - - result := task.Validate() - if result.Valid { - t.Error("Expected validation to fail for whitespace-only Title") - } -} - -func TestClarificationSchema_ReadyToPlan(t *testing.T) { - clarification := LLMClarificationResponse{ - IsReadyToPlan: true, - EnrichedGoal: "Implement a comprehensive authentication system using JWT tokens with refresh token rotation", - GoalSummary: "Implement JWT authentication with refresh tokens", - Assumptions: []string{"Using Go standard library for crypto"}, - Constraints: []string{"Must be stateless"}, - } - - result := clarification.Validate() - if !result.Valid { - t.Errorf("Expected valid clarification, got errors: %s", result.ErrorSummary()) - } -} - -func TestClarificationSchema_NeedsQuestions(t *testing.T) { - clarification := LLMClarificationResponse{ - IsReadyToPlan: false, - GoalSummary: "Implement authentication", - Questions: []string{"What authentication method?", "Do you need MFA?"}, - } - - result := clarification.Validate() - if !result.Valid { - t.Errorf("Expected valid clarification with questions, got errors: %s", result.ErrorSummary()) - } -} - -func TestValidationResult_ErrorSummary(t *testing.T) { - result := ValidationResult{ - Valid: false, - Errors: []ValidationError{ - {Field: "Title", Tag: "required", Message: "Title is required"}, - {Field: "Priority", Tag: "priority_range", Message: "Priority must be between 0 and 100"}, - }, - } - - summary := result.ErrorSummary() - if summary == "" { - t.Error("Expected non-empty error summary") - } - if !contains(summary, "Title is required") { - t.Error("Expected error summary to contain 'Title is required'") - } - if !contains(summary, "Priority must be between") { - t.Error("Expected error summary to contain priority error") - } -} - -func TestValidationResult_EmptySummaryWhenValid(t *testing.T) { - result := ValidationResult{Valid: true} - if result.ErrorSummary() != "" { - t.Error("Expected empty error summary for valid result") - } -} - -func contains(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr)) -} - -func containsHelper(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} diff --git a/internal/policy/audit_test.go b/internal/policy/audit_test.go deleted file mode 100644 index 2d08ae3..0000000 --- a/internal/policy/audit_test.go +++ /dev/null @@ -1,522 +0,0 @@ -package policy - -import ( - "database/sql" - "testing" - "time" - - _ "modernc.org/sqlite" -) - -// setupTestDB creates an in-memory SQLite database with the policy_decisions table. -func setupTestDB(t *testing.T) *sql.DB { - t.Helper() - - db, err := sql.Open("sqlite", ":memory:") - if err != nil { - t.Fatalf("open database: %v", err) - } - - // Enable foreign keys - if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { - t.Fatalf("enable foreign keys: %v", err) - } - - // Create the policy_decisions table - schema := ` - CREATE TABLE IF NOT EXISTS policy_decisions ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - decision_id TEXT UNIQUE NOT NULL, - policy_path TEXT NOT NULL, - result TEXT NOT NULL, - violations TEXT, - input_json TEXT NOT NULL, - task_id TEXT, - session_id TEXT, - evaluated_at TEXT NOT NULL - ); - - CREATE INDEX IF NOT EXISTS idx_policy_decisions_task ON policy_decisions(task_id); - CREATE INDEX IF NOT EXISTS idx_policy_decisions_session ON policy_decisions(session_id); - CREATE INDEX IF NOT EXISTS idx_policy_decisions_result ON policy_decisions(result); - CREATE INDEX IF NOT EXISTS idx_policy_decisions_evaluated_at ON policy_decisions(evaluated_at); - ` - - if _, err := db.Exec(schema); err != nil { - t.Fatalf("create schema: %v", err) - } - - t.Cleanup(func() { - _ = db.Close() - }) - - return db -} - -func TestAuditStore_SaveDecision(t *testing.T) { - db := setupTestDB(t) - store := NewAuditStore(db) - - tests := []struct { - name string - decision *PolicyDecision - wantErr bool - }{ - { - name: "save allow decision", - decision: &PolicyDecision{ - DecisionID: "test-allow-1", - PolicyPath: "taskwing.policy", - Result: PolicyResultAllow, - Input: map[string]any{"task": map[string]any{"id": "task-1"}}, - }, - wantErr: false, - }, - { - name: "save deny decision with violations", - decision: &PolicyDecision{ - DecisionID: "test-deny-1", - PolicyPath: "taskwing.policy", - Result: PolicyResultDeny, - Violations: []string{"Cannot modify protected file", "Raw SQL in controller"}, - Input: map[string]any{"task": map[string]any{"id": "task-2"}}, - TaskID: "task-2", - SessionID: "session-1", - }, - wantErr: false, - }, - { - name: "auto-generate decision ID", - decision: &PolicyDecision{ - PolicyPath: "taskwing.policy", - Result: PolicyResultAllow, - Input: map[string]any{}, - }, - wantErr: false, - }, - { - name: "nil decision", - decision: nil, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := store.SaveDecision(tt.decision) - if (err != nil) != tt.wantErr { - t.Errorf("SaveDecision() error = %v, wantErr %v", err, tt.wantErr) - } - - if !tt.wantErr && tt.decision != nil { - // Verify the decision was saved - saved, err := store.GetDecision(tt.decision.DecisionID) - if err != nil { - t.Errorf("GetDecision() error = %v", err) - return - } - if saved.PolicyPath != tt.decision.PolicyPath { - t.Errorf("PolicyPath = %v, want %v", saved.PolicyPath, tt.decision.PolicyPath) - } - if saved.Result != tt.decision.Result { - t.Errorf("Result = %v, want %v", saved.Result, tt.decision.Result) - } - } - }) - } -} - -func TestAuditStore_GetDecision(t *testing.T) { - db := setupTestDB(t) - store := NewAuditStore(db) - - // Save a test decision - decision := &PolicyDecision{ - DecisionID: "test-get-1", - PolicyPath: "taskwing.policy.protected", - Result: PolicyResultDeny, - Violations: []string{"Access denied to core/"}, - Input: map[string]any{"file": "core/router.go"}, - TaskID: "task-123", - SessionID: "session-456", - } - if err := store.SaveDecision(decision); err != nil { - t.Fatalf("SaveDecision() error = %v", err) - } - - // Retrieve it - got, err := store.GetDecision("test-get-1") - if err != nil { - t.Fatalf("GetDecision() error = %v", err) - } - - if got.DecisionID != decision.DecisionID { - t.Errorf("DecisionID = %v, want %v", got.DecisionID, decision.DecisionID) - } - if got.PolicyPath != decision.PolicyPath { - t.Errorf("PolicyPath = %v, want %v", got.PolicyPath, decision.PolicyPath) - } - if got.Result != decision.Result { - t.Errorf("Result = %v, want %v", got.Result, decision.Result) - } - if len(got.Violations) != 1 || got.Violations[0] != "Access denied to core/" { - t.Errorf("Violations = %v, want %v", got.Violations, decision.Violations) - } - if got.TaskID != decision.TaskID { - t.Errorf("TaskID = %v, want %v", got.TaskID, decision.TaskID) - } - if got.SessionID != decision.SessionID { - t.Errorf("SessionID = %v, want %v", got.SessionID, decision.SessionID) - } - - // Test not found - _, err = store.GetDecision("nonexistent") - if err == nil { - t.Error("GetDecision() expected error for nonexistent decision") - } -} - -func TestAuditStore_ListDecisions(t *testing.T) { - db := setupTestDB(t) - store := NewAuditStore(db) - - // Save multiple decisions - decisions := []*PolicyDecision{ - { - DecisionID: "list-1", - PolicyPath: "taskwing.policy", - Result: PolicyResultAllow, - Input: map[string]any{}, - TaskID: "task-1", - SessionID: "session-A", - }, - { - DecisionID: "list-2", - PolicyPath: "taskwing.policy", - Result: PolicyResultDeny, - Violations: []string{"violation 1"}, - Input: map[string]any{}, - TaskID: "task-2", - SessionID: "session-A", - }, - { - DecisionID: "list-3", - PolicyPath: "taskwing.policy", - Result: PolicyResultDeny, - Violations: []string{"violation 2"}, - Input: map[string]any{}, - TaskID: "task-3", - SessionID: "session-B", - }, - } - - for _, d := range decisions { - if err := store.SaveDecision(d); err != nil { - t.Fatalf("SaveDecision() error = %v", err) - } - } - - tests := []struct { - name string - opts ListDecisionsOptions - wantCount int - }{ - { - name: "list all", - opts: ListDecisionsOptions{}, - wantCount: 3, - }, - { - name: "filter by session A", - opts: ListDecisionsOptions{SessionID: "session-A"}, - wantCount: 2, - }, - { - name: "filter by session B", - opts: ListDecisionsOptions{SessionID: "session-B"}, - wantCount: 1, - }, - { - name: "filter by deny result", - opts: ListDecisionsOptions{Result: PolicyResultDeny}, - wantCount: 2, - }, - { - name: "filter by allow result", - opts: ListDecisionsOptions{Result: PolicyResultAllow}, - wantCount: 1, - }, - { - name: "filter by task", - opts: ListDecisionsOptions{TaskID: "task-2"}, - wantCount: 1, - }, - { - name: "limit results", - opts: ListDecisionsOptions{Limit: 2}, - wantCount: 2, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := store.ListDecisions(tt.opts) - if err != nil { - t.Errorf("ListDecisions() error = %v", err) - return - } - if len(got) != tt.wantCount { - t.Errorf("ListDecisions() returned %d decisions, want %d", len(got), tt.wantCount) - } - }) - } -} - -func TestAuditStore_CountViolations(t *testing.T) { - db := setupTestDB(t) - store := NewAuditStore(db) - - // Save some decisions - now := time.Now().UTC() - decisions := []*PolicyDecision{ - { - DecisionID: "count-1", - PolicyPath: "taskwing.policy", - Result: PolicyResultDeny, - Input: map[string]any{}, - EvaluatedAt: now.Add(-1 * time.Hour), - }, - { - DecisionID: "count-2", - PolicyPath: "taskwing.policy", - Result: PolicyResultDeny, - Input: map[string]any{}, - EvaluatedAt: now.Add(-30 * time.Minute), - }, - { - DecisionID: "count-3", - PolicyPath: "taskwing.policy", - Result: PolicyResultAllow, - Input: map[string]any{}, - EvaluatedAt: now.Add(-15 * time.Minute), - }, - } - - for _, d := range decisions { - if err := store.SaveDecision(d); err != nil { - t.Fatalf("SaveDecision() error = %v", err) - } - } - - count, err := store.CountViolations(now.Add(-2 * time.Hour)) - if err != nil { - t.Fatalf("CountViolations() error = %v", err) - } - if count != 2 { - t.Errorf("CountViolations() = %d, want 2", count) - } -} - -func TestAuditStore_DeleteDecision(t *testing.T) { - db := setupTestDB(t) - store := NewAuditStore(db) - - // Save a decision - decision := &PolicyDecision{ - DecisionID: "delete-1", - PolicyPath: "taskwing.policy", - Result: PolicyResultAllow, - Input: map[string]any{}, - } - if err := store.SaveDecision(decision); err != nil { - t.Fatalf("SaveDecision() error = %v", err) - } - - // Delete it - if err := store.DeleteDecision("delete-1"); err != nil { - t.Errorf("DeleteDecision() error = %v", err) - } - - // Verify it's gone - _, err := store.GetDecision("delete-1") - if err == nil { - t.Error("GetDecision() expected error after deletion") - } - - // Delete nonexistent - err = store.DeleteDecision("nonexistent") - if err == nil { - t.Error("DeleteDecision() expected error for nonexistent decision") - } -} - -func TestAuditStore_PruneOldDecisions(t *testing.T) { - db := setupTestDB(t) - store := NewAuditStore(db) - - // Save decisions with different timestamps - now := time.Now().UTC() - decisions := []*PolicyDecision{ - { - DecisionID: "prune-1", - PolicyPath: "taskwing.policy", - Result: PolicyResultAllow, - Input: map[string]any{}, - EvaluatedAt: now.Add(-48 * time.Hour), // 2 days old - }, - { - DecisionID: "prune-2", - PolicyPath: "taskwing.policy", - Result: PolicyResultAllow, - Input: map[string]any{}, - EvaluatedAt: now.Add(-1 * time.Hour), // 1 hour old - }, - } - - for _, d := range decisions { - if err := store.SaveDecision(d); err != nil { - t.Fatalf("SaveDecision() error = %v", err) - } - } - - // Prune decisions older than 24 hours - pruned, err := store.PruneOldDecisions(24 * time.Hour) - if err != nil { - t.Fatalf("PruneOldDecisions() error = %v", err) - } - if pruned != 1 { - t.Errorf("PruneOldDecisions() pruned %d, want 1", pruned) - } - - // Verify only one remains - remaining, err := store.ListDecisions(ListDecisionsOptions{}) - if err != nil { - t.Fatalf("ListDecisions() error = %v", err) - } - if len(remaining) != 1 { - t.Errorf("Expected 1 remaining decision, got %d", len(remaining)) - } -} - -func TestPolicyDecision_Methods(t *testing.T) { - t.Run("IsAllowed", func(t *testing.T) { - allow := &PolicyDecision{Result: PolicyResultAllow} - deny := &PolicyDecision{Result: PolicyResultDeny} - - if !allow.IsAllowed() { - t.Error("IsAllowed() = false for allow result") - } - if allow.IsDenied() { - t.Error("IsDenied() = true for allow result") - } - if deny.IsAllowed() { - t.Error("IsAllowed() = true for deny result") - } - if !deny.IsDenied() { - t.Error("IsDenied() = false for deny result") - } - }) - - t.Run("ViolationsJSON", func(t *testing.T) { - d := &PolicyDecision{Violations: []string{"a", "b"}} - got := d.ViolationsJSON() - want := `["a","b"]` - if got != want { - t.Errorf("ViolationsJSON() = %v, want %v", got, want) - } - - empty := &PolicyDecision{} - if empty.ViolationsJSON() != "[]" { - t.Errorf("ViolationsJSON() for empty = %v, want []", empty.ViolationsJSON()) - } - }) - - t.Run("InputJSON", func(t *testing.T) { - d := &PolicyDecision{Input: map[string]any{"key": "value"}} - got := d.InputJSON() - want := `{"key":"value"}` - if got != want { - t.Errorf("InputJSON() = %v, want %v", got, want) - } - - empty := &PolicyDecision{} - if empty.InputJSON() != "{}" { - t.Errorf("InputJSON() for nil = %v, want {}", empty.InputJSON()) - } - }) -} - -func TestParseViolations(t *testing.T) { - tests := []struct { - name string - input string - want []string - }{ - { - name: "valid JSON array", - input: `["a","b","c"]`, - want: []string{"a", "b", "c"}, - }, - { - name: "empty string", - input: "", - want: nil, - }, - { - name: "empty array", - input: "[]", - want: nil, - }, - { - name: "invalid JSON", - input: "not json", - want: nil, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := ParseViolations(tt.input) - if len(got) != len(tt.want) { - t.Errorf("ParseViolations() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestEvaluationResult_IsBlocked(t *testing.T) { - tests := []struct { - name string - result EvaluationResult - want bool - }{ - { - name: "allowed", - result: EvaluationResult{Allowed: true, Denied: false}, - want: false, - }, - { - name: "denied", - result: EvaluationResult{Allowed: false, Denied: true}, - want: true, - }, - { - name: "violations present", - result: EvaluationResult{Allowed: true, Violations: []string{"violation"}}, - want: true, - }, - { - name: "warnings only", - result: EvaluationResult{Allowed: true, Warnings: []string{"warning"}}, - want: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := tt.result.IsBlocked(); got != tt.want { - t.Errorf("IsBlocked() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/internal/policy/builtins_test.go b/internal/policy/builtins_test.go deleted file mode 100644 index 570a34f..0000000 --- a/internal/policy/builtins_test.go +++ /dev/null @@ -1,488 +0,0 @@ -package policy - -import ( - "testing" - - "github.com/spf13/afero" -) - -func TestFileLineCountImpl(t *testing.T) { - fs := afero.NewMemMapFs() - ctx := &BuiltinContext{ - WorkDir: "/project", - Fs: fs, - } - - // Create a test file with 5 lines - _ = fs.MkdirAll("/project", 0755) - content := "line1\nline2\nline3\nline4\nline5" - _ = afero.WriteFile(fs, "/project/test.go", []byte(content), 0644) - - tests := []struct { - name string - path string - want int - }{ - { - name: "existing file", - path: "test.go", - want: 5, - }, - { - name: "absolute path", - path: "/project/test.go", - want: 5, - }, - { - name: "non-existent file", - path: "nonexistent.go", - want: -1, - }, - { - name: "empty path resolves to workdir", - path: "", - want: 0, // Empty path resolves to directory which has 0 scannable lines - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := fileLineCountImpl(ctx, tt.path) - if got != tt.want { - t.Errorf("fileLineCountImpl() = %d, want %d", got, tt.want) - } - }) - } -} - -func TestFileLineCountImpl_EmptyFile(t *testing.T) { - fs := afero.NewMemMapFs() - ctx := &BuiltinContext{ - WorkDir: "/project", - Fs: fs, - } - - _ = fs.MkdirAll("/project", 0755) - _ = afero.WriteFile(fs, "/project/empty.go", []byte(""), 0644) - - got := fileLineCountImpl(ctx, "empty.go") - if got != 0 { - t.Errorf("fileLineCountImpl() for empty file = %d, want 0", got) - } -} - -func TestHasPatternImpl(t *testing.T) { - fs := afero.NewMemMapFs() - ctx := &BuiltinContext{ - WorkDir: "/project", - Fs: fs, - } - - _ = fs.MkdirAll("/project", 0755) - content := `package main - -import "fmt" - -func main() { - password := "secret123" - apiKey := "sk-abc123" - db.Query("SELECT * FROM users") -} -` - _ = afero.WriteFile(fs, "/project/main.go", []byte(content), 0644) - - tests := []struct { - name string - path string - pattern string - want bool - }{ - { - name: "find hardcoded password", - path: "main.go", - pattern: `password\s*:?=\s*"[^"]+"`, - want: true, - }, - { - name: "find api key pattern", - path: "main.go", - pattern: `apiKey\s*:?=\s*"sk-`, - want: true, - }, - { - name: "find raw SQL", - path: "main.go", - pattern: `db\.Query\(`, - want: true, - }, - { - name: "pattern not found", - path: "main.go", - pattern: `DOESNOTEXIST`, - want: false, - }, - { - name: "file not found", - path: "nonexistent.go", - pattern: `.*`, - want: false, - }, - { - name: "invalid regex", - path: "main.go", - pattern: `[invalid`, - want: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := hasPatternImpl(ctx, tt.path, tt.pattern) - if got != tt.want { - t.Errorf("hasPatternImpl() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestFileImportsImpl(t *testing.T) { - fs := afero.NewMemMapFs() - ctx := &BuiltinContext{ - WorkDir: "/project", - Fs: fs, - } - - _ = fs.MkdirAll("/project", 0755) - - // Single import - singleImport := `package main - -import "fmt" - -func main() {} -` - _ = afero.WriteFile(fs, "/project/single.go", []byte(singleImport), 0644) - - // Block import - blockImport := `package main - -import ( - "context" - "fmt" - "strings" - - "github.com/example/pkg" -) - -func main() {} -` - _ = afero.WriteFile(fs, "/project/block.go", []byte(blockImport), 0644) - - // No imports - noImport := `package main - -func main() {} -` - _ = afero.WriteFile(fs, "/project/none.go", []byte(noImport), 0644) - - tests := []struct { - name string - path string - wantCount int - }{ - { - name: "single import", - path: "single.go", - wantCount: 1, - }, - { - name: "block imports", - path: "block.go", - wantCount: 4, // context, fmt, strings, github.com/example/pkg - }, - { - name: "no imports", - path: "none.go", - wantCount: 0, - }, - { - name: "file not found", - path: "nonexistent.go", - wantCount: 0, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := fileImportsImpl(ctx, tt.path) - if len(got) != tt.wantCount { - t.Errorf("fileImportsImpl() returned %d imports, want %d: %v", len(got), tt.wantCount, got) - } - }) - } -} - -func TestSymbolExistsImpl(t *testing.T) { - fs := afero.NewMemMapFs() - ctx := &BuiltinContext{ - WorkDir: "/project", - Fs: fs, - } - - _ = fs.MkdirAll("/project", 0755) - - goCode := `package main - -func MyFunction() {} - -func (s *Server) HandleRequest() {} - -type Config struct { - Port int -} - -var globalVar = "test" - -const MaxRetries = 3 -` - _ = afero.WriteFile(fs, "/project/code.go", []byte(goCode), 0644) - - jsCode := `class UserService { - constructor() {} -} - -function processData(data) { - return data; -} -` - _ = afero.WriteFile(fs, "/project/code.js", []byte(jsCode), 0644) - - tests := []struct { - name string - path string - symbolName string - want bool - }{ - { - name: "Go function", - path: "code.go", - symbolName: "MyFunction", - want: true, - }, - { - name: "Go method", - path: "code.go", - symbolName: "HandleRequest", - want: true, - }, - { - name: "Go type", - path: "code.go", - symbolName: "Config", - want: true, - }, - { - name: "Go var", - path: "code.go", - symbolName: "globalVar", - want: true, - }, - { - name: "Go const", - path: "code.go", - symbolName: "MaxRetries", - want: true, - }, - { - name: "JS class", - path: "code.js", - symbolName: "UserService", - want: true, - }, - { - name: "JS function", - path: "code.js", - symbolName: "processData", - want: true, - }, - { - name: "symbol not found", - path: "code.go", - symbolName: "NonExistentSymbol", - want: false, - }, - { - name: "file not found", - path: "nonexistent.go", - symbolName: "Anything", - want: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := symbolExistsImpl(ctx, tt.path, tt.symbolName) - if got != tt.want { - t.Errorf("symbolExistsImpl(%q, %q) = %v, want %v", tt.path, tt.symbolName, got, tt.want) - } - }) - } -} - -func TestFileExistsImpl(t *testing.T) { - fs := afero.NewMemMapFs() - ctx := &BuiltinContext{ - WorkDir: "/project", - Fs: fs, - } - - _ = fs.MkdirAll("/project", 0755) - _ = afero.WriteFile(fs, "/project/exists.txt", []byte("content"), 0644) - - tests := []struct { - name string - path string - want bool - }{ - { - name: "existing file", - path: "exists.txt", - want: true, - }, - { - name: "absolute existing file", - path: "/project/exists.txt", - want: true, - }, - { - name: "non-existent file", - path: "nonexistent.txt", - want: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := fileExistsImpl(ctx, tt.path) - if got != tt.want { - t.Errorf("fileExistsImpl() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestParseGoImports(t *testing.T) { - tests := []struct { - name string - content string - want []string - }{ - { - name: "single import", - content: `package main -import "fmt" -`, - want: []string{"fmt"}, - }, - { - name: "import block", - content: `package main -import ( - "context" - "fmt" -) -`, - want: []string{"context", "fmt"}, - }, - { - name: "named imports", - content: `package main -import ( - ctx "context" - . "fmt" - _ "net/http/pprof" -) -`, - want: []string{"context", "fmt", "net/http/pprof"}, - }, - { - name: "no imports", - content: `package main`, - want: []string{}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := parseGoImports(tt.content) - if len(got) != len(tt.want) { - t.Errorf("parseGoImports() returned %d imports, want %d: %v", len(got), len(tt.want), got) - } - }) - } -} - -func TestGetBuiltinNames(t *testing.T) { - names := GetBuiltinNames() - if len(names) < 5 { - t.Errorf("GetBuiltinNames() returned %d names, want at least 5", len(names)) - } - - expected := []string{ - "taskwing.file_line_count", - "taskwing.has_pattern", - "taskwing.file_imports", - "taskwing.symbol_exists", - "taskwing.file_exists", - } - - for _, e := range expected { - found := false - for _, n := range names { - if n == e { - found = true - break - } - } - if !found { - t.Errorf("GetBuiltinNames() missing %q", e) - } - } -} - -func TestIsBuiltin(t *testing.T) { - tests := []struct { - name string - want bool - }{ - {"taskwing.file_line_count", true}, - {"taskwing.has_pattern", true}, - {"file_line_count", true}, // Short form should match - {"unknown_builtin", false}, - {"rego.parse_json", false}, // Not a TaskWing builtin - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := IsBuiltin(tt.name); got != tt.want { - t.Errorf("IsBuiltin(%q) = %v, want %v", tt.name, got, tt.want) - } - }) - } -} - -func TestNewBuiltinContext(t *testing.T) { - ctx := NewBuiltinContext("/test/dir") - if ctx == nil { - t.Fatal("NewBuiltinContext() returned nil") - } - if ctx.WorkDir != "/test/dir" { - t.Errorf("WorkDir = %q, want %q", ctx.WorkDir, "/test/dir") - } - if ctx.Fs == nil { - t.Error("Fs is nil") - } - if ctx.CodeIntel != nil { - t.Error("CodeIntel should be nil for basic context") - } -} diff --git a/internal/policy/engine_test.go b/internal/policy/engine_test.go deleted file mode 100644 index d8b3d2a..0000000 --- a/internal/policy/engine_test.go +++ /dev/null @@ -1,493 +0,0 @@ -package policy - -import ( - "context" - "testing" - - "github.com/spf13/afero" -) - -func TestEngine_Evaluate_NoPolicies(t *testing.T) { - // When no policies are loaded, everything should be allowed - engine := &Engine{ - policies: nil, - policyPackage: DefaultPolicyPackage, - compiled: true, - } - - input := map[string]any{ - "task": map[string]any{ - "files_modified": []string{"anything.go"}, - }, - } - - decision, err := engine.Evaluate(context.Background(), input) - if err != nil { - t.Fatalf("Evaluate() error = %v", err) - } - - if decision.Result != PolicyResultAllow { - t.Errorf("Result = %v, want %v", decision.Result, PolicyResultAllow) - } - - if len(decision.Violations) != 0 { - t.Errorf("Violations = %v, want empty", decision.Violations) - } -} - -func TestEngine_Evaluate_DenyRule(t *testing.T) { - policy := &PolicyFile{ - Name: "test_deny", - Path: "test_deny.rego", - Content: `package taskwing.policy - -import rego.v1 - -deny contains msg if { - some file in input.task.files_modified - startswith(file, "core/") - msg := sprintf("Cannot modify protected file: %s", [file]) -} -`, - } - - engine := NewEngineWithPolicies("/project", []*PolicyFile{policy}) - - tests := []struct { - name string - input map[string]any - wantResult string - wantViolate bool - }{ - { - name: "allow non-core file", - input: map[string]any{ - "task": map[string]any{ - "files_modified": []string{"internal/app/main.go"}, - }, - }, - wantResult: PolicyResultAllow, - wantViolate: false, - }, - { - name: "deny core file", - input: map[string]any{ - "task": map[string]any{ - "files_modified": []string{"core/router.go"}, - }, - }, - wantResult: PolicyResultDeny, - wantViolate: true, - }, - { - name: "deny multiple core files", - input: map[string]any{ - "task": map[string]any{ - "files_modified": []string{"core/a.go", "core/b.go"}, - }, - }, - wantResult: PolicyResultDeny, - wantViolate: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - decision, err := engine.Evaluate(context.Background(), tt.input) - if err != nil { - t.Fatalf("Evaluate() error = %v", err) - } - - if decision.Result != tt.wantResult { - t.Errorf("Result = %v, want %v", decision.Result, tt.wantResult) - } - - hasViolations := len(decision.Violations) > 0 - if hasViolations != tt.wantViolate { - t.Errorf("Has violations = %v, want %v. Violations: %v", hasViolations, tt.wantViolate, decision.Violations) - } - }) - } -} - -func TestEngine_Evaluate_MultiplePolicies(t *testing.T) { - protectedZones := &PolicyFile{ - Name: "protected_zones", - Path: "protected_zones.rego", - Content: `package taskwing.policy - -import rego.v1 - -deny contains msg if { - some file in input.task.files_modified - startswith(file, "vendor/") - msg := "Cannot modify vendor directory" -} -`, - } - - secrets := &PolicyFile{ - Name: "secrets", - Path: "secrets.rego", - Content: `package taskwing.policy - -import rego.v1 - -deny contains msg if { - some file in input.task.files_modified - endswith(file, ".env") - msg := "Cannot modify .env files" -} -`, - } - - engine := NewEngineWithPolicies("/project", []*PolicyFile{protectedZones, secrets}) - - // Should deny both violations - input := map[string]any{ - "task": map[string]any{ - "files_modified": []string{"vendor/lib.go", "config/.env"}, - }, - } - - decision, err := engine.Evaluate(context.Background(), input) - if err != nil { - t.Fatalf("Evaluate() error = %v", err) - } - - if decision.Result != PolicyResultDeny { - t.Errorf("Result = %v, want %v", decision.Result, PolicyResultDeny) - } - - // Should have 2 violations - if len(decision.Violations) != 2 { - t.Errorf("Violations count = %d, want 2. Got: %v", len(decision.Violations), decision.Violations) - } -} - -func TestEngine_EvaluateTask(t *testing.T) { - policy := &PolicyFile{ - Name: "task_policy", - Path: "task_policy.rego", - Content: `package taskwing.policy - -import rego.v1 - -deny contains msg if { - input.task.id == "" - msg := "Task ID is required" -} -`, - } - - engine := NewEngineWithPolicies("/project", []*PolicyFile{policy}) - - // Task without ID should be denied - decision, err := engine.EvaluateTask(context.Background(), &TaskInput{ - ID: "", - Title: "Test task", - }, nil, nil) - if err != nil { - t.Fatalf("EvaluateTask() error = %v", err) - } - - if decision.Result != PolicyResultDeny { - t.Errorf("Result = %v, want deny", decision.Result) - } - - // Task with ID should be allowed - decision, err = engine.EvaluateTask(context.Background(), &TaskInput{ - ID: "task-123", - Title: "Test task", - }, nil, nil) - if err != nil { - t.Fatalf("EvaluateTask() error = %v", err) - } - - if decision.Result != PolicyResultAllow { - t.Errorf("Result = %v, want allow", decision.Result) - } -} - -func TestEngine_EvaluateFiles(t *testing.T) { - policy := &PolicyFile{ - Name: "file_policy", - Path: "file_policy.rego", - Content: `package taskwing.policy - -import rego.v1 - -deny contains msg if { - some file in input.task.files_created - endswith(file, "_test.go") - msg := "Cannot create test files in this context" -} -`, - } - - engine := NewEngineWithPolicies("/project", []*PolicyFile{policy}) - - // Creating a non-test file should be allowed - decision, err := engine.EvaluateFiles(context.Background(), nil, []string{"main.go"}) - if err != nil { - t.Fatalf("EvaluateFiles() error = %v", err) - } - if decision.Result != PolicyResultAllow { - t.Errorf("Result = %v, want allow", decision.Result) - } - - // Creating a test file should be denied - decision, err = engine.EvaluateFiles(context.Background(), nil, []string{"main_test.go"}) - if err != nil { - t.Fatalf("EvaluateFiles() error = %v", err) - } - if decision.Result != PolicyResultDeny { - t.Errorf("Result = %v, want deny", decision.Result) - } -} - -func TestEngine_PolicyManagement(t *testing.T) { - engine := &Engine{ - policies: nil, - policyPackage: DefaultPolicyPackage, - compiled: true, - } - - // Initially no policies - if engine.PolicyCount() != 0 { - t.Errorf("PolicyCount() = %d, want 0", engine.PolicyCount()) - } - - // Add a policy - engine.AddPolicy("test", `package taskwing.policy`) - if engine.PolicyCount() != 1 { - t.Errorf("PolicyCount() after add = %d, want 1", engine.PolicyCount()) - } - - names := engine.PolicyNames() - if len(names) != 1 || names[0] != "test" { - t.Errorf("PolicyNames() = %v, want [test]", names) - } - - // Clear policies - engine.ClearPolicies() - if engine.PolicyCount() != 0 { - t.Errorf("PolicyCount() after clear = %d, want 0", engine.PolicyCount()) - } -} - -func TestNewEngine(t *testing.T) { - fs := afero.NewMemMapFs() - - // Create policies directory with a policy file - _ = fs.MkdirAll("/project/.taskwing/policies", 0755) - policyContent := `package taskwing.policy - -import rego.v1 - -deny contains msg if { - input.blocked == true - msg := "Input is blocked" -} -` - _ = afero.WriteFile(fs, "/project/.taskwing/policies/test.rego", []byte(policyContent), 0644) - - engine, err := NewEngine(EngineConfig{ - WorkDir: "/project", - Fs: fs, - }) - if err != nil { - t.Fatalf("NewEngine() error = %v", err) - } - - if engine.PolicyCount() != 1 { - t.Errorf("PolicyCount() = %d, want 1", engine.PolicyCount()) - } - - // Test evaluation - decision, err := engine.Evaluate(context.Background(), map[string]any{"blocked": true}) - if err != nil { - t.Fatalf("Evaluate() error = %v", err) - } - if decision.Result != PolicyResultDeny { - t.Errorf("Result = %v, want deny", decision.Result) - } - - decision, err = engine.Evaluate(context.Background(), map[string]any{"blocked": false}) - if err != nil { - t.Fatalf("Evaluate() error = %v", err) - } - if decision.Result != PolicyResultAllow { - t.Errorf("Result = %v, want allow", decision.Result) - } -} - -func TestNewEngine_NoPoliciesDir(t *testing.T) { - fs := afero.NewMemMapFs() - - // Don't create policies directory - engine, err := NewEngine(EngineConfig{ - WorkDir: "/project", - Fs: fs, - }) - if err != nil { - t.Fatalf("NewEngine() error = %v (should succeed with no policies)", err) - } - - if engine.PolicyCount() != 0 { - t.Errorf("PolicyCount() = %d, want 0", engine.PolicyCount()) - } -} - -func TestValidatePolicy(t *testing.T) { - tests := []struct { - name string - content string - wantErr bool - }{ - { - name: "valid policy", - content: `package test - -import rego.v1 - -deny contains "blocked" if { - input.x == 1 -} -`, - wantErr: false, - }, - { - name: "invalid syntax", - content: `package test { invalid syntax here`, - wantErr: true, - }, - { - name: "empty content", - content: "", - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := ValidatePolicy(tt.content) - if (err != nil) != tt.wantErr { - t.Errorf("ValidatePolicy() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func TestEngine_Evaluate_DecisionFields(t *testing.T) { - policy := &PolicyFile{ - Name: "test", - Path: "test.rego", - Content: `package taskwing.policy`, - } - - engine := NewEngineWithPolicies("/project", []*PolicyFile{policy}) - - input := map[string]any{"test": "data"} - decision, err := engine.Evaluate(context.Background(), input) - if err != nil { - t.Fatalf("Evaluate() error = %v", err) - } - - // Check all decision fields are populated - if decision.DecisionID == "" { - t.Error("DecisionID is empty") - } - - if decision.PolicyPath != DefaultPolicyPackage { - t.Errorf("PolicyPath = %v, want %v", decision.PolicyPath, DefaultPolicyPackage) - } - - if decision.EvaluatedAt.IsZero() { - t.Error("EvaluatedAt is zero") - } - - if decision.Input == nil { - t.Error("Input is nil") - } -} - -func TestEngine_MustEvaluate(t *testing.T) { - engine := &Engine{ - policies: nil, - policyPackage: DefaultPolicyPackage, - compiled: true, - } - - // Should not panic with valid input - decision := engine.MustEvaluate(context.Background(), map[string]any{}) - if decision == nil { - t.Error("MustEvaluate() returned nil") - } -} - -func TestEngine_ReloadPolicies(t *testing.T) { - fs := afero.NewMemMapFs() - - // Start with one policy - _ = fs.MkdirAll("/project/.taskwing/policies", 0755) - _ = afero.WriteFile(fs, "/project/.taskwing/policies/a.rego", []byte(`package taskwing.policy`), 0644) - - engine, _ := NewEngine(EngineConfig{ - WorkDir: "/project", - Fs: fs, - }) - - if engine.PolicyCount() != 1 { - t.Errorf("Initial PolicyCount() = %d, want 1", engine.PolicyCount()) - } - - // Add another policy file - _ = afero.WriteFile(fs, "/project/.taskwing/policies/b.rego", []byte(`package taskwing.policy`), 0644) - - // Reload - err := engine.ReloadPolicies(fs, "/project/.taskwing/policies") - if err != nil { - t.Fatalf("ReloadPolicies() error = %v", err) - } - - if engine.PolicyCount() != 2 { - t.Errorf("PolicyCount() after reload = %d, want 2", engine.PolicyCount()) - } -} - -func TestEngine_Evaluate_NoNetworkCalls(t *testing.T) { - // This test verifies that evaluation happens entirely in-process. - // We use a policy that only uses local operations. - policy := &PolicyFile{ - Name: "local_only", - Path: "local_only.rego", - Content: `package taskwing.policy - -import rego.v1 - -# This policy only uses local string operations -deny contains msg if { - some file in input.task.files_modified - contains(file, "secret") - msg := "Cannot modify files containing 'secret' in path" -} -`, - } - - engine := NewEngineWithPolicies("/project", []*PolicyFile{policy}) - - // Run multiple evaluations - all should be fast and local - for i := 0; i < 100; i++ { - input := map[string]any{ - "task": map[string]any{ - "files_modified": []string{"config.go"}, - }, - } - _, err := engine.Evaluate(context.Background(), input) - if err != nil { - t.Fatalf("Evaluate() iteration %d error = %v", i, err) - } - } - // If we got here without timeout, evaluation is local -} diff --git a/internal/policy/loader_test.go b/internal/policy/loader_test.go deleted file mode 100644 index 98f188d..0000000 --- a/internal/policy/loader_test.go +++ /dev/null @@ -1,265 +0,0 @@ -package policy - -import ( - "testing" - - "github.com/spf13/afero" -) - -func TestLoader_LoadAll(t *testing.T) { - fs := afero.NewMemMapFs() - - // Create policies directory structure - _ = fs.MkdirAll("/project/.taskwing/policies", 0755) - - // Create some .rego files - protectedZonesRego := `package taskwing.policy - -import rego.v1 - -deny contains msg if { - some file in input.task.files_modified - startswith(file, "core/") - msg := "Cannot modify core files" -} -` - secretsRego := `package taskwing.policy.secrets - -import rego.v1 - -deny contains msg if { - some file in input.task.files_modified - endswith(file, ".env") - msg := "Cannot modify .env files" -} -` - - _ = afero.WriteFile(fs, "/project/.taskwing/policies/protected_zones.rego", []byte(protectedZonesRego), 0644) - _ = afero.WriteFile(fs, "/project/.taskwing/policies/secrets.rego", []byte(secretsRego), 0644) - // Add a non-rego file that should be ignored - _ = afero.WriteFile(fs, "/project/.taskwing/policies/README.md", []byte("# Policies"), 0644) - - loader := NewLoader(fs, "/project/.taskwing/policies") - - policies, err := loader.LoadAll() - if err != nil { - t.Fatalf("LoadAll() error = %v", err) - } - - if len(policies) != 2 { - t.Errorf("LoadAll() returned %d policies, want 2", len(policies)) - } - - // Verify policy names - names := make(map[string]bool) - for _, p := range policies { - names[p.Name] = true - if p.Content == "" { - t.Errorf("Policy %s has empty content", p.Name) - } - } - - if !names["protected_zones"] { - t.Error("Expected protected_zones policy to be loaded") - } - if !names["secrets"] { - t.Error("Expected secrets policy to be loaded") - } -} - -func TestLoader_LoadAll_Subdirectories(t *testing.T) { - fs := afero.NewMemMapFs() - - // Create nested directory structure - _ = fs.MkdirAll("/project/.taskwing/policies/security", 0755) - _ = fs.MkdirAll("/project/.taskwing/policies/architecture", 0755) - - _ = afero.WriteFile(fs, "/project/.taskwing/policies/defaults.rego", []byte("package defaults"), 0644) - _ = afero.WriteFile(fs, "/project/.taskwing/policies/security/hardcoded_secrets.rego", []byte("package security"), 0644) - _ = afero.WriteFile(fs, "/project/.taskwing/policies/architecture/layers.rego", []byte("package architecture"), 0644) - - loader := NewLoader(fs, "/project/.taskwing/policies") - - policies, err := loader.LoadAll() - if err != nil { - t.Fatalf("LoadAll() error = %v", err) - } - - if len(policies) != 3 { - t.Errorf("LoadAll() returned %d policies, want 3", len(policies)) - } -} - -func TestLoader_LoadAll_EmptyDirectory(t *testing.T) { - fs := afero.NewMemMapFs() - - // Create empty policies directory - _ = fs.MkdirAll("/project/.taskwing/policies", 0755) - - loader := NewLoader(fs, "/project/.taskwing/policies") - - policies, err := loader.LoadAll() - if err != nil { - t.Fatalf("LoadAll() error = %v", err) - } - - if len(policies) != 0 { - t.Errorf("LoadAll() returned %d policies, want 0", len(policies)) - } -} - -func TestLoader_LoadAll_NonExistentDirectory(t *testing.T) { - fs := afero.NewMemMapFs() - - // Don't create the directory - it doesn't exist - loader := NewLoader(fs, "/project/.taskwing/policies") - - policies, err := loader.LoadAll() - if err != nil { - t.Fatalf("LoadAll() error = %v (should return empty slice)", err) - } - - if len(policies) != 0 { - t.Errorf("LoadAll() returned %d policies for non-existent dir, want 0", len(policies)) - } -} - -func TestLoader_LoadFile(t *testing.T) { - fs := afero.NewMemMapFs() - - content := `package test.policy - -deny contains msg if { - msg := "test violation" -} -` - _ = fs.MkdirAll("/project/.taskwing/policies", 0755) - _ = afero.WriteFile(fs, "/project/.taskwing/policies/test.rego", []byte(content), 0644) - - loader := NewLoader(fs, "/project/.taskwing/policies") - - policy, err := loader.LoadFile("/project/.taskwing/policies/test.rego") - if err != nil { - t.Fatalf("LoadFile() error = %v", err) - } - - if policy.Name != "test" { - t.Errorf("Name = %v, want test", policy.Name) - } - - if policy.Content != content { - t.Errorf("Content mismatch") - } - - if policy.Path != "/project/.taskwing/policies/test.rego" { - t.Errorf("Path = %v, want /project/.taskwing/policies/test.rego", policy.Path) - } -} - -func TestLoader_LoadFile_NotFound(t *testing.T) { - fs := afero.NewMemMapFs() - - loader := NewLoader(fs, "/project/.taskwing/policies") - - _, err := loader.LoadFile("/project/.taskwing/policies/nonexistent.rego") - if err == nil { - t.Error("LoadFile() expected error for non-existent file") - } -} - -func TestLoader_Exists(t *testing.T) { - fs := afero.NewMemMapFs() - - loader := NewLoader(fs, "/project/.taskwing/policies") - - // Directory doesn't exist - exists, err := loader.Exists() - if err != nil { - t.Fatalf("Exists() error = %v", err) - } - if exists { - t.Error("Exists() = true for non-existent directory") - } - - // Create directory - _ = fs.MkdirAll("/project/.taskwing/policies", 0755) - - exists, err = loader.Exists() - if err != nil { - t.Fatalf("Exists() error = %v", err) - } - if !exists { - t.Error("Exists() = false for existing directory") - } -} - -func TestLoader_ListFiles(t *testing.T) { - fs := afero.NewMemMapFs() - - _ = fs.MkdirAll("/project/.taskwing/policies", 0755) - _ = afero.WriteFile(fs, "/project/.taskwing/policies/a.rego", []byte("package a"), 0644) - _ = afero.WriteFile(fs, "/project/.taskwing/policies/b.rego", []byte("package b"), 0644) - _ = afero.WriteFile(fs, "/project/.taskwing/policies/readme.md", []byte("# README"), 0644) - - loader := NewLoader(fs, "/project/.taskwing/policies") - - paths, err := loader.ListFiles() - if err != nil { - t.Fatalf("ListFiles() error = %v", err) - } - - if len(paths) != 2 { - t.Errorf("ListFiles() returned %d paths, want 2", len(paths)) - } -} - -func TestLoader_ListFiles_NonExistentDirectory(t *testing.T) { - fs := afero.NewMemMapFs() - - loader := NewLoader(fs, "/project/.taskwing/policies") - - paths, err := loader.ListFiles() - if err != nil { - t.Fatalf("ListFiles() error = %v (should return empty slice)", err) - } - - if len(paths) != 0 { - t.Errorf("ListFiles() returned %d paths for non-existent dir, want 0", len(paths)) - } -} - -func TestGetPoliciesPath(t *testing.T) { - tests := []struct { - projectRoot string - want string - }{ - { - projectRoot: "/home/user/project", - want: "/home/user/project/.taskwing/policies", - }, - { - projectRoot: "/project", - want: "/project/.taskwing/policies", - }, - } - - for _, tt := range tests { - t.Run(tt.projectRoot, func(t *testing.T) { - got := GetPoliciesPath(tt.projectRoot) - if got != tt.want { - t.Errorf("GetPoliciesPath() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestNewOsLoader(t *testing.T) { - // Just verify it can be created without panicking - loader := NewOsLoader("/tmp/test-policies") - if loader == nil { - t.Fatal("NewOsLoader() returned nil") - } - if loader.baseDir != "/tmp/test-policies" { - t.Errorf("baseDir = %v, want /tmp/test-policies", loader.baseDir) - } -} diff --git a/internal/policy/test_runner_test.go b/internal/policy/test_runner_test.go deleted file mode 100644 index a9ac964..0000000 --- a/internal/policy/test_runner_test.go +++ /dev/null @@ -1,202 +0,0 @@ -package policy - -import ( - "context" - "testing" - - "github.com/spf13/afero" -) - -func TestTestRunner_Run(t *testing.T) { - // Create test filesystem - fs := afero.NewMemMapFs() - - // Create a simple policy - policy := `package taskwing.policy - -import rego.v1 - -is_env_file(file) if startswith(file, ".env") - -deny contains msg if { - some file in input.task.files_modified - is_env_file(file) - msg := sprintf("BLOCKED: Environment file '%s' is protected", [file]) -} -` - - // Create a test file - testFile := `package taskwing.policy - -import rego.v1 - -test_deny_env_file if { - result := deny with input as {"task": {"files_modified": [".env"]}} - count(result) > 0 -} - -test_allow_regular_file if { - result := deny with input as {"task": {"files_modified": ["main.go"]}} - count(result) == 0 -} -` - - // Write files to test filesystem - policiesDir := "/test/policies" - _ = afero.WriteFile(fs, policiesDir+"/default.rego", []byte(policy), 0644) - _ = afero.WriteFile(fs, policiesDir+"/default_test.rego", []byte(testFile), 0644) - - // Create runner and run tests - runner := NewTestRunner(fs, policiesDir, "/test") - ctx := context.Background() - - summary, err := runner.Run(ctx) - if err != nil { - t.Fatalf("Run() failed: %v", err) - } - - // Verify results - if summary.Total != 2 { - t.Errorf("expected 2 tests, got %d", summary.Total) - } - if summary.Passed != 2 { - t.Errorf("expected 2 passed, got %d passed (failed: %d, errored: %d)", - summary.Passed, summary.Failed, summary.Errored) - } - if !summary.AllPassed() { - t.Error("expected AllPassed() to return true") - } -} - -func TestTestRunner_Run_NoTests(t *testing.T) { - // Create test filesystem with no test files - fs := afero.NewMemMapFs() - policiesDir := "/test/policies" - - // Only create a policy file (no test file) - policy := `package taskwing.policy - -import rego.v1 - -deny contains msg if { - false - msg := "never" -} -` - _ = afero.WriteFile(fs, policiesDir+"/default.rego", []byte(policy), 0644) - - runner := NewTestRunner(fs, policiesDir, "/test") - ctx := context.Background() - - summary, err := runner.Run(ctx) - if err != nil { - t.Fatalf("Run() failed: %v", err) - } - - // Should have no tests - if summary.Total != 0 { - t.Errorf("expected 0 tests, got %d", summary.Total) - } -} - -func TestTestRunner_HasTests(t *testing.T) { - fs := afero.NewMemMapFs() - policiesDir := "/test/policies" - - // Initially no test files - _ = afero.WriteFile(fs, policiesDir+"/default.rego", []byte("package test"), 0644) - - runner := NewTestRunner(fs, policiesDir, "/test") - - hasTests, err := runner.HasTests() - if err != nil { - t.Fatalf("HasTests() failed: %v", err) - } - if hasTests { - t.Error("expected HasTests() to return false when no test files") - } - - // Add a test file - _ = afero.WriteFile(fs, policiesDir+"/default_test.rego", []byte("package test"), 0644) - - hasTests, err = runner.HasTests() - if err != nil { - t.Fatalf("HasTests() failed: %v", err) - } - if !hasTests { - t.Error("expected HasTests() to return true after adding test file") - } -} - -func TestTestRunner_Run_FailingTest(t *testing.T) { - fs := afero.NewMemMapFs() - policiesDir := "/test/policies" - - // Create a policy - policy := `package taskwing.policy -import rego.v1 - -deny := false -` - - // Create a failing test - testFile := `package taskwing.policy -import rego.v1 - -test_should_fail if { - 1 == 2 # This will always fail -} -` - - _ = afero.WriteFile(fs, policiesDir+"/default.rego", []byte(policy), 0644) - _ = afero.WriteFile(fs, policiesDir+"/default_test.rego", []byte(testFile), 0644) - - runner := NewTestRunner(fs, policiesDir, "/test") - ctx := context.Background() - - summary, err := runner.Run(ctx) - if err != nil { - t.Fatalf("Run() failed: %v", err) - } - - if summary.Failed != 1 { - t.Errorf("expected 1 failed test, got %d", summary.Failed) - } - if summary.AllPassed() { - t.Error("expected AllPassed() to return false for failing test") - } -} - -func TestTestSummary_FormatSummary(t *testing.T) { - summary := &TestSummary{ - Total: 5, - Passed: 3, - Failed: 1, - Errored: 1, - } - - output := summary.FormatSummary() - if output == "" { - t.Error("FormatSummary() returned empty string") - } - // Should contain the counts - if !contains(output, "5 tests") { - t.Errorf("expected output to contain '5 tests', got: %s", output) - } - if !contains(output, "3 passed") { - t.Errorf("expected output to contain '3 passed', got: %s", output) - } -} - -func contains(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr)) -} - -func containsHelper(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} diff --git a/internal/project/project_test.go b/internal/project/project_test.go deleted file mode 100644 index 2df83d8..0000000 --- a/internal/project/project_test.go +++ /dev/null @@ -1,305 +0,0 @@ -package project - -import ( - "testing" - - "github.com/spf13/afero" -) - -func TestMarkerTypeString(t *testing.T) { - tests := []struct { - marker MarkerType - expected string - }{ - {MarkerNone, "none"}, - {MarkerTaskWing, ".taskwing"}, - {MarkerGoMod, "go.mod"}, - {MarkerPackageJSON, "package.json"}, - {MarkerCargoToml, "Cargo.toml"}, - {MarkerPomXML, "pom.xml"}, - {MarkerPyProjectToml, "pyproject.toml"}, - {MarkerGit, ".git"}, - } - - for _, tt := range tests { - t.Run(tt.expected, func(t *testing.T) { - if got := tt.marker.String(); got != tt.expected { - t.Errorf("MarkerType.String() = %v, want %v", got, tt.expected) - } - }) - } -} - -func TestMarkerTypePriority(t *testing.T) { - // TaskWing should have highest priority - if MarkerTaskWing.Priority() <= MarkerGoMod.Priority() { - t.Error("MarkerTaskWing should have higher priority than MarkerGoMod") - } - - // Language manifests should have higher priority than Git - if MarkerGoMod.Priority() <= MarkerGit.Priority() { - t.Error("MarkerGoMod should have higher priority than MarkerGit") - } - - // Git should have higher priority than None - if MarkerGit.Priority() <= MarkerNone.Priority() { - t.Error("MarkerGit should have higher priority than MarkerNone") - } -} - -func TestMarkerTypeIsLanguageManifest(t *testing.T) { - languageManifests := []MarkerType{ - MarkerGoMod, - MarkerPackageJSON, - MarkerCargoToml, - MarkerPomXML, - MarkerPyProjectToml, - } - - for _, m := range languageManifests { - if !m.IsLanguageManifest() { - t.Errorf("%s should be a language manifest", m.String()) - } - } - - nonManifests := []MarkerType{ - MarkerNone, - MarkerTaskWing, - MarkerGit, - } - - for _, m := range nonManifests { - if m.IsLanguageManifest() { - t.Errorf("%s should not be a language manifest", m.String()) - } - } -} - -func TestDetectWithGoMod(t *testing.T) { - // Create in-memory filesystem - fs := afero.NewMemMapFs() - - // Create a directory structure with go.mod - _ = fs.MkdirAll("/project/subdir", 0755) - _ = afero.WriteFile(fs, "/project/go.mod", []byte("module test"), 0644) - - detector := NewDetector(fs) - - // Detect from subdir should find go.mod in parent - ctx, err := detector.Detect("/project/subdir") - if err != nil { - t.Fatalf("Detect() error = %v", err) - } - - if ctx.RootPath != "/project" { - t.Errorf("RootPath = %v, want /project", ctx.RootPath) - } - - if ctx.MarkerType != MarkerGoMod { - t.Errorf("MarkerType = %v, want MarkerGoMod", ctx.MarkerType) - } -} - -func TestDetectWithTaskWing(t *testing.T) { - // Create in-memory filesystem - fs := afero.NewMemMapFs() - - // Create a directory structure with both .taskwing and go.mod - // .taskwing should take precedence - _ = fs.MkdirAll("/project/.taskwing", 0755) - _ = afero.WriteFile(fs, "/project/go.mod", []byte("module test"), 0644) - - detector := NewDetector(fs) - - ctx, err := detector.Detect("/project") - if err != nil { - t.Fatalf("Detect() error = %v", err) - } - - if ctx.MarkerType != MarkerTaskWing { - t.Errorf("MarkerType = %v, want MarkerTaskWing (should have highest priority)", ctx.MarkerType) - } -} - -func TestDetectWithGit(t *testing.T) { - // Create in-memory filesystem - fs := afero.NewMemMapFs() - - // Create a directory structure with only .git - _ = fs.MkdirAll("/project/.git", 0755) - - detector := NewDetector(fs) - - ctx, err := detector.Detect("/project") - if err != nil { - t.Fatalf("Detect() error = %v", err) - } - - if ctx.MarkerType != MarkerGit { - t.Errorf("MarkerType = %v, want MarkerGit", ctx.MarkerType) - } - - if ctx.GitRoot != "/project" { - t.Errorf("GitRoot = %v, want /project", ctx.GitRoot) - } -} - -func TestContextRelativeGitPath(t *testing.T) { - tests := []struct { - name string - ctx Context - expected string - }{ - { - name: "same path", - ctx: Context{RootPath: "/project", GitRoot: "/project"}, - expected: ".", - }, - { - name: "subdir of git root", - ctx: Context{RootPath: "/project/packages/api", GitRoot: "/project"}, - expected: "packages/api", - }, - { - name: "empty git root", - ctx: Context{RootPath: "/project", GitRoot: ""}, - expected: ".", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := tt.ctx.RelativeGitPath(); got != tt.expected { - t.Errorf("RelativeGitPath() = %v, want %v", got, tt.expected) - } - }) - } -} - -func TestWorkspaceTypeString(t *testing.T) { - tests := []struct { - wsType WorkspaceType - expected string - }{ - {WorkspaceTypeSingle, "single"}, - {WorkspaceTypeMonorepo, "monorepo"}, - {WorkspaceTypeMultiRepo, "multi-repo"}, - } - - for _, tt := range tests { - t.Run(tt.expected, func(t *testing.T) { - if got := tt.wsType.String(); got != tt.expected { - t.Errorf("WorkspaceType.String() = %v, want %v", got, tt.expected) - } - }) - } -} - -func TestWorkspaceInfoMethods(t *testing.T) { - info := &WorkspaceInfo{ - Type: WorkspaceTypeMonorepo, - RootPath: "/project", - Services: []string{"api", "web", "common"}, - Name: "myproject", - } - - if info.IsMultiRepo() { - t.Error("Monorepo should not be reported as multi-repo") - } - - if info.ServiceCount() != 3 { - t.Errorf("ServiceCount() = %d, want 3", info.ServiceCount()) - } - - expectedPath := "/project/api" - if got := info.GetServicePath("api"); got != expectedPath { - t.Errorf("GetServicePath() = %v, want %v", got, expectedPath) - } -} - -// === Workspace Auto-Detection Tests === - -func TestExtractWorkspaceName(t *testing.T) { - tests := []struct { - name string - relPath string - expected string - }{ - {"simple name", "osprey", "osprey"}, - {"nested path", "services/osprey", "osprey"}, - {"deeply nested", "apps/frontend/web", "web"}, - {"empty path", "", ""}, - {"dot path", ".", ""}, - {"root slash", "/", ""}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := extractWorkspaceName(tt.relPath) - if result != tt.expected { - t.Errorf("extractWorkspaceName(%q) = %q, want %q", tt.relPath, result, tt.expected) - } - }) - } -} - -func TestDetectWorkspaceFromPath_SingleRepo(t *testing.T) { - // For a single repo (non-monorepo), should always return "root" - // Use a path that we know won't be detected as a monorepo - workspace, err := DetectWorkspaceFromPath("/nonexistent/path") - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if workspace != "root" { - t.Errorf("DetectWorkspaceFromPath for nonexistent path = %q, want 'root'", workspace) - } -} - -func TestAutoDetectWorkspace_FallbackToRoot(t *testing.T) { - // AutoDetectWorkspace should never return empty string - workspace, _ := DetectWorkspaceFromCwd() - // We can't predict the cwd, but it should either return "root" or a valid workspace name - if workspace == "" { - t.Error("DetectWorkspaceFromCwd returned empty string, expected 'root' or workspace name") - } -} - -func TestDetectWorkspaceFromPath_WithMonorepoContext(t *testing.T) { - // This test documents the expected behavior for monorepo detection - // The actual detection depends on the file system structure - // We test the edge cases here - - tests := []struct { - name string - relPath string - isMonorepo bool - expectedDefault string - }{ - {"at root", ".", false, "root"}, - {"empty", "", false, "root"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // For non-monorepo or at root, should return "root" - ctx := &Context{ - RootPath: "/project", - GitRoot: "/project", - IsMonorepo: tt.isMonorepo, - } - - relPath := ctx.RelativeGitPath() - workspace := extractWorkspaceName(relPath) - - // At root, RelativeGitPath returns ".", extractWorkspaceName returns "" - // which means we should fallback to "root" - if relPath == "." && workspace == "" { - workspace = "root" - } - - if workspace != tt.expectedDefault { - t.Errorf("workspace = %q, want %q", workspace, tt.expectedDefault) - } - }) - } -} diff --git a/internal/server/handlers.go b/internal/server/handlers.go index dd12ee9..e6b4c46 100644 --- a/internal/server/handlers.go +++ b/internal/server/handlers.go @@ -381,7 +381,7 @@ func (s *Server) handlePromoteToTask(w http.ResponseWriter, r *http.Request) { Status: task.StatusPending, Priority: 50, } - // Populate AI integration fields (scope, keywords, suggested_recall_queries) + // Populate AI integration fields (scope, keywords, suggested_ask_queries) newTask.EnrichAIFields() if err := s.repo.CreateTask(newTask); err != nil { diff --git a/internal/task/dag_test.go b/internal/task/dag_test.go deleted file mode 100644 index 95f7259..0000000 --- a/internal/task/dag_test.go +++ /dev/null @@ -1,126 +0,0 @@ -package task - -import ( - "testing" -) - -func TestVerifyDAG_NoCycle(t *testing.T) { - // A -> B -> C (linear, no cycle) - tasks := []Task{ - {ID: "task-A", Title: "Task A", Dependencies: nil}, - {ID: "task-B", Title: "Task B", Dependencies: []string{"task-A"}}, - {ID: "task-C", Title: "Task C", Dependencies: []string{"task-B"}}, - } - - if err := VerifyDAG(tasks); err != nil { - t.Errorf("VerifyDAG() returned error for valid DAG: %v", err) - } -} - -func TestVerifyDAG_WithCycle(t *testing.T) { - // A -> B -> C -> A (cycle) - tasks := []Task{ - {ID: "task-A", Title: "Task A", Dependencies: []string{"task-C"}}, - {ID: "task-B", Title: "Task B", Dependencies: []string{"task-A"}}, - {ID: "task-C", Title: "Task C", Dependencies: []string{"task-B"}}, - } - - err := VerifyDAG(tasks) - if err == nil { - t.Error("VerifyDAG() should return error for cycle, got nil") - } -} - -func TestVerifyDAG_EmptyID(t *testing.T) { - tasks := []Task{ - {ID: "", Title: "Task with no ID"}, - } - - err := VerifyDAG(tasks) - if err == nil { - t.Error("VerifyDAG() should return error for empty ID, got nil") - } -} - -func TestTopologicalSort_LinearDependencies(t *testing.T) { - // C depends on B, B depends on A - // Expected order: A, B, C - tasks := []Task{ - {ID: "task-C", Title: "Task C", Dependencies: []string{"task-B"}}, - {ID: "task-A", Title: "Task A", Dependencies: nil}, - {ID: "task-B", Title: "Task B", Dependencies: []string{"task-A"}}, - } - - sorted, err := TopologicalSort(tasks) - if err != nil { - t.Fatalf("TopologicalSort() error: %v", err) - } - - if len(sorted) != 3 { - t.Fatalf("Expected 3 tasks, got %d", len(sorted)) - } - - // A must come before B, B must come before C - posA, posB, posC := -1, -1, -1 - for i, task := range sorted { - switch task.ID { - case "task-A": - posA = i - case "task-B": - posB = i - case "task-C": - posC = i - } - } - - if posA >= posB { - t.Errorf("Task A (pos %d) should come before Task B (pos %d)", posA, posB) - } - if posB >= posC { - t.Errorf("Task B (pos %d) should come before Task C (pos %d)", posB, posC) - } -} - -func TestTopologicalSort_DiamondDependencies(t *testing.T) { - // Diamond: D depends on B and C, B and C both depend on A - // A - // / \ - // B C - // \ / - // D - tasks := []Task{ - {ID: "task-D", Title: "Task D", Dependencies: []string{"task-B", "task-C"}}, - {ID: "task-B", Title: "Task B", Dependencies: []string{"task-A"}}, - {ID: "task-C", Title: "Task C", Dependencies: []string{"task-A"}}, - {ID: "task-A", Title: "Task A", Dependencies: nil}, - } - - sorted, err := TopologicalSort(tasks) - if err != nil { - t.Fatalf("TopologicalSort() error: %v", err) - } - - if len(sorted) != 4 { - t.Fatalf("Expected 4 tasks, got %d", len(sorted)) - } - - // A must come first, D must come last - if sorted[0].ID != "task-A" { - t.Errorf("Expected first task to be A, got %s", sorted[0].ID) - } - if sorted[3].ID != "task-D" { - t.Errorf("Expected last task to be D, got %s", sorted[3].ID) - } -} - -func TestTopologicalSort_WithCycle(t *testing.T) { - tasks := []Task{ - {ID: "task-A", Title: "Task A", Dependencies: []string{"task-B"}}, - {ID: "task-B", Title: "Task B", Dependencies: []string{"task-A"}}, - } - - _, err := TopologicalSort(tasks) - if err == nil { - t.Error("TopologicalSort() should return error for cycle, got nil") - } -} diff --git a/internal/task/git_verifier_test.go b/internal/task/git_verifier_test.go deleted file mode 100644 index 41f0472..0000000 --- a/internal/task/git_verifier_test.go +++ /dev/null @@ -1,304 +0,0 @@ -package task - -import ( - "context" - "os" - "os/exec" - "path/filepath" - "testing" -) - -// TestGitVerifier_Verify_NoDiscrepancy tests when reported matches actual. -func TestGitVerifier_Verify_NoDiscrepancy(t *testing.T) { - // Create a temp git repo - dir := setupTestGitRepo(t) - - // Create and modify a file - testFile := filepath.Join(dir, "test.go") - if err := os.WriteFile(testFile, []byte("package main"), 0644); err != nil { - t.Fatalf("failed to create test file: %v", err) - } - - verifier := NewGitVerifier(dir) - result := verifier.Verify(context.Background(), []string{"test.go"}) - - if !result.IsVerified { - t.Errorf("expected verification to succeed, got error: %s", result.VerifyError) - } - - if result.HasDiscrepancy() { - t.Errorf("expected no discrepancy, got unreported=%v, over_reported=%v", - result.UnreportedFiles, result.OverReported) - } -} - -// TestGitVerifier_Verify_UnreportedFiles tests detection of files agent didn't report. -func TestGitVerifier_Verify_UnreportedFiles(t *testing.T) { - dir := setupTestGitRepo(t) - - // Create two files - for _, name := range []string{"reported.go", "unreported.go"} { - if err := os.WriteFile(filepath.Join(dir, name), []byte("package main"), 0644); err != nil { - t.Fatalf("failed to create %s: %v", name, err) - } - } - - verifier := NewGitVerifier(dir) - // Agent only reports one file - result := verifier.Verify(context.Background(), []string{"reported.go"}) - - if !result.IsVerified { - t.Fatalf("expected verification to succeed, got error: %s", result.VerifyError) - } - - if len(result.UnreportedFiles) != 1 || result.UnreportedFiles[0] != "unreported.go" { - t.Errorf("expected unreported.go in unreported files, got %v", result.UnreportedFiles) - } - - if !result.HasDiscrepancy() { - t.Error("expected HasDiscrepancy() to return true") - } -} - -// TestGitVerifier_Verify_OverReported tests detection of files agent claimed but didn't change. -func TestGitVerifier_Verify_OverReported(t *testing.T) { - dir := setupTestGitRepo(t) - - // Create only one file - if err := os.WriteFile(filepath.Join(dir, "actual.go"), []byte("package main"), 0644); err != nil { - t.Fatalf("failed to create actual.go: %v", err) - } - - verifier := NewGitVerifier(dir) - // Agent claims to have modified more files than actually changed - result := verifier.Verify(context.Background(), []string{"actual.go", "hallucinated.go"}) - - if !result.IsVerified { - t.Fatalf("expected verification to succeed, got error: %s", result.VerifyError) - } - - if len(result.OverReported) != 1 || result.OverReported[0] != "hallucinated.go" { - t.Errorf("expected hallucinated.go in over-reported files, got %v", result.OverReported) - } - - if !result.HasDiscrepancy() { - t.Error("expected HasDiscrepancy() to return true") - } -} - -// TestGitVerifier_Verify_NotGitRepo tests graceful handling of non-git directories. -func TestGitVerifier_Verify_NotGitRepo(t *testing.T) { - // Create temp dir without git init - dir, err := os.MkdirTemp("", "test-no-git-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(dir) }() - - verifier := NewGitVerifier(dir) - result := verifier.Verify(context.Background(), []string{"test.go"}) - - if result.IsVerified { - t.Error("expected verification to fail for non-git directory") - } - - if result.VerifyError == "" { - t.Error("expected error message for non-git directory") - } -} - -// TestGitVerifier_HasUnreportedHighRisk tests high-risk file detection. -func TestGitVerifier_HasUnreportedHighRisk(t *testing.T) { - tests := []struct { - name string - unreportedFiles []string - expectHighRisk bool - }{ - { - name: "no unreported files", - unreportedFiles: nil, - expectHighRisk: false, - }, - { - name: "normal file", - unreportedFiles: []string{"handler.go"}, - expectHighRisk: false, - }, - { - name: "config file", - unreportedFiles: []string{"config/database.yaml"}, - expectHighRisk: true, - }, - { - name: "secrets file", - unreportedFiles: []string{"internal/auth/secrets.go"}, - expectHighRisk: true, - }, - { - name: "env file", - unreportedFiles: []string{".env.production"}, - expectHighRisk: true, - }, - { - name: "migration file", - unreportedFiles: []string{"db/migrations/001_create_users.sql"}, - expectHighRisk: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := &VerificationResult{ - UnreportedFiles: tt.unreportedFiles, - IsVerified: true, - } - - got := result.HasUnreportedHighRisk() - if got != tt.expectHighRisk { - t.Errorf("HasUnreportedHighRisk() = %v, want %v", got, tt.expectHighRisk) - } - }) - } -} - -// TestGitVerifier_VerifyWithBaseline_ExcludesBaselineFiles tests that pre-existing -// modified files (baseline) are not flagged as unreported. -func TestGitVerifier_VerifyWithBaseline_ExcludesBaselineFiles(t *testing.T) { - dir := setupTestGitRepo(t) - - // Create two files: one in baseline, one new - for _, name := range []string{"baseline.go", "new.go"} { - if err := os.WriteFile(filepath.Join(dir, name), []byte("package main"), 0644); err != nil { - t.Fatalf("failed to create %s: %v", name, err) - } - } - - verifier := NewGitVerifier(dir) - - // Simulate: baseline.go was modified before task started - // Agent only reports new.go (correctly) - baseline := []string{"baseline.go"} - reported := []string{"new.go"} - - result := verifier.VerifyWithBaseline(context.Background(), reported, baseline) - - if !result.IsVerified { - t.Fatalf("expected verification to succeed, got error: %s", result.VerifyError) - } - - // baseline.go should NOT be in unreported (it was in baseline) - for _, f := range result.UnreportedFiles { - if f == "baseline.go" { - t.Error("baseline.go should not be flagged as unreported - it was in baseline") - } - } - - // Result should show no discrepancy (agent correctly reported new.go, baseline.go excluded) - if result.HasDiscrepancy() { - t.Errorf("expected no discrepancy when baseline is properly excluded, got unreported=%v, over_reported=%v", - result.UnreportedFiles, result.OverReported) - } -} - -// TestGitVerifier_VerifyWithBaseline_StillCatchesUnreported tests that files not in -// baseline AND not reported are still flagged. -func TestGitVerifier_VerifyWithBaseline_StillCatchesUnreported(t *testing.T) { - dir := setupTestGitRepo(t) - - // Create three files - for _, name := range []string{"baseline.go", "reported.go", "sneaky.go"} { - if err := os.WriteFile(filepath.Join(dir, name), []byte("package main"), 0644); err != nil { - t.Fatalf("failed to create %s: %v", name, err) - } - } - - verifier := NewGitVerifier(dir) - - // baseline.go was pre-existing, agent reports reported.go, but sneaky.go is hidden - baseline := []string{"baseline.go"} - reported := []string{"reported.go"} - - result := verifier.VerifyWithBaseline(context.Background(), reported, baseline) - - if !result.IsVerified { - t.Fatalf("expected verification to succeed, got error: %s", result.VerifyError) - } - - // sneaky.go should be flagged as unreported - foundSneaky := false - for _, f := range result.UnreportedFiles { - if f == "sneaky.go" { - foundSneaky = true - } - } - - if !foundSneaky { - t.Errorf("sneaky.go should be flagged as unreported, got unreported=%v", result.UnreportedFiles) - } -} - -// TestIsGitRepo tests git repository detection. -func TestIsGitRepo(t *testing.T) { - // Test with a real git repo - gitDir := setupTestGitRepo(t) - if !IsGitRepo(gitDir) { - t.Error("expected IsGitRepo to return true for git repository") - } - - // Test with a non-git directory - nonGitDir, err := os.MkdirTemp("", "test-no-git-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(nonGitDir) }() - - if IsGitRepo(nonGitDir) { - t.Error("expected IsGitRepo to return false for non-git directory") - } -} - -// setupTestGitRepo creates a temporary git repository for testing. -func setupTestGitRepo(t *testing.T) string { - t.Helper() - - dir, err := os.MkdirTemp("", "test-git-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - - t.Cleanup(func() { - _ = os.RemoveAll(dir) - }) - - // Initialize git repo - cmd := exec.Command("git", "init") - cmd.Dir = dir - if err := cmd.Run(); err != nil { - t.Fatalf("failed to init git repo: %v", err) - } - - // Configure git user for commits - cmd = exec.Command("git", "config", "user.email", "test@test.com") - cmd.Dir = dir - _ = cmd.Run() - - cmd = exec.Command("git", "config", "user.name", "Test") - cmd.Dir = dir - _ = cmd.Run() - - // Create initial commit so HEAD exists - readmeFile := filepath.Join(dir, "README.md") - if err := os.WriteFile(readmeFile, []byte("# Test"), 0644); err != nil { - t.Fatalf("failed to create README: %v", err) - } - - cmd = exec.Command("git", "add", ".") - cmd.Dir = dir - _ = cmd.Run() - - cmd = exec.Command("git", "commit", "-m", "initial") - cmd.Dir = dir - _ = cmd.Run() - - return dir -} diff --git a/internal/task/models.go b/internal/task/models.go index 64d04fe..3e53cc5 100644 --- a/internal/task/models.go +++ b/internal/task/models.go @@ -159,7 +159,7 @@ type Task struct { // AI integration fields - for MCP tool context fetching Scope string `json:"scope,omitempty"` // e.g., "auth", "api", "vectorsearch" Keywords []string `json:"keywords,omitempty"` // Extracted from title/description - SuggestedRecallQueries []string `json:"suggestedRecallQueries,omitempty"` // Pre-computed queries for recall tool + SuggestedAskQueries []string `json:"suggestedAskQueries,omitempty"` // Pre-computed queries for ask tool // Session tracking - for AI tool state management ClaimedBy string `json:"claimedBy,omitempty"` // Session ID that claimed this task @@ -292,7 +292,7 @@ var stopWords = map[string]bool{ "so": true, "than": true, "too": true, "very": true, "just": true, "also": true, } -// EnrichAIFields populates Scope, Keywords, and SuggestedRecallQueries from title/description. +// EnrichAIFields populates Scope, Keywords, and SuggestedAskQueries from title/description. // Call this before CreateTask to ensure AI integration fields are set. // // This is part of the early binding context strategy - see docs/architecture/ADR_CONTEXT_BINDING.md @@ -311,7 +311,7 @@ var stopWords = map[string]bool{ // - Highest-scoring scope wins; defaults to "general" if no matches // - Scopes are configurable via task.scopes in .taskwing.yaml // -// 3. RECALL QUERY GENERATION: +// 3. ASK QUERY GENERATION: // - Query 1: " patterns constraints decisions" - domain-specific architecture // - Query 2: Top 5 keywords joined - content-specific search // - Query 3: Simplified title words - intent-focused search @@ -382,7 +382,7 @@ func (t *Task) EnrichAIFields() { t.Scope = effectiveScope } - // Generate suggested recall queries + // Generate suggested ask queries var queries []string // Query 1: Scope-based patterns and constraints @@ -413,5 +413,5 @@ func (t *Task) EnrichAIFields() { queries = append(queries, strings.Join(titleKw, " ")) } - t.SuggestedRecallQueries = queries + t.SuggestedAskQueries = queries } diff --git a/internal/task/models_test.go b/internal/task/models_test.go deleted file mode 100644 index 55dd7a0..0000000 --- a/internal/task/models_test.go +++ /dev/null @@ -1,223 +0,0 @@ -package task - -import ( - "encoding/json" - "strings" - "testing" -) - -func TestTask_Validate(t *testing.T) { - tests := []struct { - name string - task Task - wantErr bool - }{ - { - name: "valid task", - task: Task{ - Title: "Valid Task", - Description: "Valid Description", - Priority: 50, - }, - wantErr: false, - }, - { - name: "empty title", - task: Task{ - Title: "", - Description: "Valid Description", - Priority: 50, - }, - wantErr: true, // title required - }, - { - name: "long title", - task: Task{ - Title: strings.Repeat("a", 201), - Description: "Valid Description", - Priority: 50, - }, - wantErr: true, // max 200 - }, - { - name: "empty description", - task: Task{ - Title: "Valid Task", - Description: "", - Priority: 50, - }, - wantErr: true, // description required - }, - { - name: "priority too low", - task: Task{ - Title: "Valid Task", - Description: "Valid Description", - Priority: -1, - }, - wantErr: true, // 0-100 - }, - { - name: "priority too high", - task: Task{ - Title: "Valid Task", - Description: "Valid Description", - Priority: 101, - }, - wantErr: true, // 0-100 - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.task.Validate(); (err != nil) != tt.wantErr { - t.Errorf("Task.Validate() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -// TestPlanIDJSONSchema_SnakeCase tests that plan_id is used in JSON output. -func TestPlanIDJSONSchema_SnakeCase(t *testing.T) { - task := Task{ - ID: "task-123", - PlanID: "plan-456", - Title: "Test Task", - Description: "Test Description", - } - - data, err := json.Marshal(task) - if err != nil { - t.Fatalf("Failed to marshal task: %v", err) - } - - jsonStr := string(data) - - // Should contain plan_id (snake_case) - if !strings.Contains(jsonStr, `"plan_id"`) { - t.Errorf("JSON output should use 'plan_id', got: %s", jsonStr) - } - - // Should NOT contain planId (camelCase) in output - if strings.Contains(jsonStr, `"planId"`) { - t.Errorf("JSON output should NOT use 'planId', got: %s", jsonStr) - } -} - -// TestPlanIDJSONSchema_AcceptSnakeCase tests that plan_id is correctly unmarshaled. -func TestPlanIDJSONSchema_AcceptSnakeCase(t *testing.T) { - jsonData := `{"id":"task-123","plan_id":"plan-456","title":"Test","description":"Test"}` - - var task Task - if err := json.Unmarshal([]byte(jsonData), &task); err != nil { - t.Fatalf("Failed to unmarshal task: %v", err) - } - - if task.PlanID != "plan-456" { - t.Errorf("PlanID = %q, want %q", task.PlanID, "plan-456") - } -} - -// TestPlanIDJSONSchema_RejectCamelCaseAlias tests that planId is rejected. -func TestPlanIDJSONSchema_RejectCamelCaseAlias(t *testing.T) { - jsonData := `{"id":"task-123","planId":"plan-789","title":"Test","description":"Test"}` - - var task Task - err := json.Unmarshal([]byte(jsonData), &task) - if err == nil { - t.Fatal("expected unmarshal error for legacy planId") - } - if !strings.Contains(err.Error(), "planId") { - t.Fatalf("unexpected error: %v", err) - } -} - -// TestPlanIDJSONSchema_RejectWhenBothKeysPresent tests strict rejection when legacy key exists. -func TestPlanIDJSONSchema_RejectWhenBothKeysPresent(t *testing.T) { - jsonData := `{"id":"task-123","plan_id":"plan-primary","planId":"plan-alias","title":"Test","description":"Test"}` - - var task Task - err := json.Unmarshal([]byte(jsonData), &task) - if err == nil { - t.Fatal("expected unmarshal error for legacy planId") - } - if !strings.Contains(err.Error(), "planId") { - t.Fatalf("unexpected error: %v", err) - } -} - -// TestPlanIDJSONSchema_RoundTrip tests that marshal -> unmarshal preserves PlanID. -func TestPlanIDJSONSchema_RoundTrip(t *testing.T) { - original := Task{ - ID: "task-123", - PlanID: "plan-456", - Title: "Test Task", - Description: "Test Description", - Priority: 50, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("Failed to marshal task: %v", err) - } - - var decoded Task - if err := json.Unmarshal(data, &decoded); err != nil { - t.Fatalf("Failed to unmarshal task: %v", err) - } - - if decoded.PlanID != original.PlanID { - t.Errorf("PlanID after round-trip = %q, want %q", decoded.PlanID, original.PlanID) - } -} - -// TestPlanIDJSONSchema_EmptyValues tests edge cases with empty values. -func TestPlanIDJSONSchema_EmptyValues(t *testing.T) { - tests := []struct { - name string - jsonData string - wantPlanID string - }{ - { - name: "empty plan_id", - jsonData: `{"id":"task-123","plan_id":"","title":"Test","description":"Test"}`, - wantPlanID: "", - }, - { - name: "null plan_id", - jsonData: `{"id":"task-123","plan_id":null,"title":"Test","description":"Test"}`, - wantPlanID: "", - }, - { - name: "both missing", - jsonData: `{"id":"task-123","title":"Test","description":"Test"}`, - wantPlanID: "", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - var task Task - if err := json.Unmarshal([]byte(tc.jsonData), &task); err != nil { - t.Fatalf("Failed to unmarshal: %v", err) - } - - if task.PlanID != tc.wantPlanID { - t.Errorf("PlanID = %q, want %q", task.PlanID, tc.wantPlanID) - } - }) - } -} - -func TestPlanIDJSONSchema_LegacyAliasRejected(t *testing.T) { - jsonData := `{"id":"task-123","planId":"plan-fallback","title":"Test","description":"Test"}` - - var task Task - err := json.Unmarshal([]byte(jsonData), &task) - if err == nil { - t.Fatal("expected unmarshal error for legacy planId") - } - if !strings.Contains(err.Error(), "planId") { - t.Fatalf("unexpected error: %v", err) - } -} diff --git a/internal/task/presentation.go b/internal/task/presentation.go index 87b61d0..17c876f 100644 --- a/internal/task/presentation.go +++ b/internal/task/presentation.go @@ -6,12 +6,12 @@ import ( "strings" ) -// RecallSearchFunc is the signature for a context/recall search function. +// AskSearchFunc is the signature for a context/ask search function. // This breaks the import cycle by avoiding direct dependency on knowledge.Service. -type RecallSearchFunc func(ctx context.Context, query string, limit int) ([]RecallResult, error) +type AskSearchFunc func(ctx context.Context, query string, limit int) ([]AskResult, error) -// RecallResult is a minimal struct for context search results. -type RecallResult struct { +// AskResult is a minimal struct for context search results. +type AskResult struct { Summary string Type string Content string @@ -23,16 +23,16 @@ type RecallResult struct { // Context Binding Strategy (see docs/architecture/ADR_CONTEXT_BINDING.md): // - Early binding: Uses Task.ContextSummary if available (populated at creation) // - Late binding: Falls back to searchFn if ContextSummary is empty (backward compatibility) -func FormatRichContext(ctx context.Context, t *Task, p *Plan, searchFn RecallSearchFunc) string { - var recallContext string +func FormatRichContext(ctx context.Context, t *Task, p *Plan, searchFn AskSearchFunc) string { + var askContext string // Early binding: Use pre-computed ContextSummary if available if t.ContextSummary != "" { - recallContext = "\n" + t.ContextSummary - } else if len(t.SuggestedRecallQueries) > 0 && searchFn != nil { + askContext = "\n" + t.ContextSummary + } else if len(t.SuggestedAskQueries) > 0 && searchFn != nil { // Late binding fallback: Fetch context dynamically using ALL queries - var allResults []RecallResult - for _, query := range t.SuggestedRecallQueries { + var allResults []AskResult + for _, query := range t.SuggestedAskQueries { results, err := searchFn(ctx, query, 3) if err == nil { allResults = append(allResults, results...) @@ -56,7 +56,7 @@ func FormatRichContext(ctx context.Context, t *Task, p *Plan, searchFn RecallSea } sb.WriteString(fmt.Sprintf("- **%s** (%s): %s\n", r.Summary, r.Type, content)) } - recallContext = sb.String() + askContext = sb.String() } } @@ -102,8 +102,8 @@ Plan Progress: %d%% (%d/%d tasks completed) } } - if recallContext != "" { - contextStr += recallContext + if askContext != "" { + contextStr += askContext } contextStr += ` diff --git a/internal/task/scope_config_test.go b/internal/task/scope_config_test.go deleted file mode 100644 index 35350e2..0000000 --- a/internal/task/scope_config_test.go +++ /dev/null @@ -1,199 +0,0 @@ -package task - -import ( - "testing" - - "github.com/spf13/viper" -) - -func TestGetScopeConfig_Defaults(t *testing.T) { - // Reset for clean test state - ResetScopeConfig() - viper.Reset() - t.Cleanup(func() { - ResetScopeConfig() - viper.Reset() - }) - - cfg := GetScopeConfig() - - // Verify defaults are loaded - scopes := cfg.GetScopes() - if len(scopes) == 0 { - t.Error("Expected default scopes to be loaded") - } - - // Check a known default scope - authKeywords, ok := scopes["auth"] - if !ok { - t.Error("Expected 'auth' scope in defaults") - } - if len(authKeywords) == 0 { - t.Error("Expected auth scope to have keywords") - } - - // Verify default limits - if cfg.MaxKeywords() != defaultMaxKeywords { - t.Errorf("Expected maxKeywords=%d, got %d", defaultMaxKeywords, cfg.MaxKeywords()) - } - if cfg.MinWordLength() != defaultMinWordLen { - t.Errorf("Expected minWordLength=%d, got %d", defaultMinWordLen, cfg.MinWordLength()) - } -} - -func TestGetScopeConfig_CustomScopes(t *testing.T) { - // Reset for clean test state - ResetScopeConfig() - viper.Reset() - t.Cleanup(func() { - ResetScopeConfig() - viper.Reset() - }) - - // Configure custom scopes - viper.Set("task.scopes", map[string][]string{ - "custom_domain": {"keyword1", "keyword2", "keyword3"}, - "auth": {"custom_auth_keyword"}, // Override default - }) - - cfg := GetScopeConfig() - scopes := cfg.GetScopes() - - // Custom scope should exist - customKw, ok := scopes["custom_domain"] - if !ok { - t.Error("Expected 'custom_domain' scope to be loaded from config") - } - if len(customKw) != 3 { - t.Errorf("Expected 3 keywords in custom_domain, got %d", len(customKw)) - } - - // Auth scope should be overridden - authKw := scopes["auth"] - if len(authKw) != 1 || authKw[0] != "custom_auth_keyword" { - t.Errorf("Expected auth scope to be overridden, got %v", authKw) - } - - // Other default scopes should still exist (merged) - if _, ok := scopes["api"]; !ok { - t.Error("Expected default 'api' scope to still exist after merge") - } -} - -func TestGetScopeConfig_CustomLimits(t *testing.T) { - // Reset for clean test state - ResetScopeConfig() - viper.Reset() - t.Cleanup(func() { - ResetScopeConfig() - viper.Reset() - }) - - // Configure custom limits - viper.Set("task.maxKeywords", 20) - viper.Set("task.minWordLength", 4) - - cfg := GetScopeConfig() - - if cfg.MaxKeywords() != 20 { - t.Errorf("Expected maxKeywords=20, got %d", cfg.MaxKeywords()) - } - if cfg.MinWordLength() != 4 { - t.Errorf("Expected minWordLength=4, got %d", cfg.MinWordLength()) - } -} - -func TestScopeConfig_InferScope(t *testing.T) { - // Reset for clean test state - ResetScopeConfig() - viper.Reset() - t.Cleanup(func() { - ResetScopeConfig() - viper.Reset() - }) - - cfg := GetScopeConfig() - - tests := []struct { - name string - words map[string]bool - expected string - }{ - { - name: "auth scope detection", - words: map[string]bool{"login": true, "password": true, "jwt": true}, - expected: "auth", - }, - { - name: "database scope detection", - words: map[string]bool{"db": true, "sql": true, "migration": true}, - expected: "database", - }, - { - name: "api scope detection", - words: map[string]bool{"endpoint": true, "handler": true, "rest": true}, - expected: "api", - }, - { - name: "general fallback", - words: map[string]bool{"random": true, "words": true, "here": true}, - expected: "general", - }, - { - name: "short abbreviation matching", - words: map[string]bool{"ui": true, "tui": true}, - expected: "ui", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result := cfg.InferScope(tc.words) - if result != tc.expected { - t.Errorf("Expected scope %q, got %q", tc.expected, result) - } - }) - } -} - -func TestEnrichAIFields_UsesConfigurableScopes(t *testing.T) { - // Reset for clean test state - ResetScopeConfig() - viper.Reset() - t.Cleanup(func() { - ResetScopeConfig() - viper.Reset() - }) - - // Add custom scope - viper.Set("task.scopes", map[string][]string{ - "payments": {"payment", "stripe", "checkout", "invoice"}, - }) - - // Force reload - ResetScopeConfig() - - task := &Task{ - Title: "Implement Stripe payment integration", - Description: "Add checkout flow with invoice generation using the Stripe API", - } - task.EnrichAIFields() - - if task.Scope != "payments" { - t.Errorf("Expected scope 'payments' for payment-related task, got %q", task.Scope) - } - - // Verify keywords were extracted - if len(task.Keywords) == 0 { - t.Error("Expected keywords to be extracted") - } - - // Verify recall queries were generated - if len(task.SuggestedRecallQueries) == 0 { - t.Error("Expected recall queries to be generated") - } - // First query should include the inferred scope - if task.SuggestedRecallQueries[0] != "payments patterns constraints decisions" { - t.Errorf("Expected first query to include 'payments' scope, got %q", task.SuggestedRecallQueries[0]) - } -} diff --git a/internal/task/sentinel_policy_test.go b/internal/task/sentinel_policy_test.go deleted file mode 100644 index 9220ea1..0000000 --- a/internal/task/sentinel_policy_test.go +++ /dev/null @@ -1,285 +0,0 @@ -package task - -import ( - "context" - "testing" -) - -// mockPolicyEvaluator implements PolicyEvaluator for testing. -type mockPolicyEvaluator struct { - allowAll bool - violations []string - decisionID string - err error - policyCount int -} - -func (m *mockPolicyEvaluator) EvaluateTaskPolicy(ctx context.Context, taskID, taskTitle, taskDescription string, filesModified, filesCreated []string, planID, planGoal string) (bool, []string, string, error) { - if m.err != nil { - return false, nil, "", m.err - } - if m.allowAll { - return true, nil, m.decisionID, nil - } - // Check for protected files - for _, f := range filesModified { - if isProtectedFile(f) { - return false, m.violations, m.decisionID, nil - } - } - return true, nil, m.decisionID, nil -} - -func (m *mockPolicyEvaluator) EvaluateFilesPolicy(ctx context.Context, filesModified, filesCreated []string) (bool, []string, string, error) { - if m.err != nil { - return false, nil, "", m.err - } - if m.allowAll { - return true, nil, m.decisionID, nil - } - for _, f := range filesModified { - if isProtectedFile(f) { - return false, m.violations, m.decisionID, nil - } - } - return true, nil, m.decisionID, nil -} - -func (m *mockPolicyEvaluator) PolicyCount() int { - return m.policyCount -} - -// isProtectedFile checks if a file matches protected patterns. -func isProtectedFile(file string) bool { - protectedPatterns := []string{".env", ".env.local", ".env.production", "secrets/"} - for _, pattern := range protectedPatterns { - if len(file) >= len(pattern) { - // Simple contains check for testing - for i := 0; i <= len(file)-len(pattern); i++ { - if file[i:i+len(pattern)] == pattern { - return true - } - } - } - } - return false -} - -func TestPolicyEnforcer_NoPolicies(t *testing.T) { - enforcer := NewPolicyEnforcer(nil, "session-123") - - task := &Task{ - ID: "task-123", - Title: "Test task", - FilesModified: []string{".env"}, - } - - result := enforcer.Enforce(context.Background(), task, "Test plan goal") - - if !result.Allowed { - t.Error("Expected task to be allowed when no policies are configured") - } - if enforcer.HasPolicies() { - t.Error("HasPolicies should return false when evaluator is nil") - } -} - -func TestPolicyEnforcer_AllowNonProtectedFile(t *testing.T) { - evaluator := &mockPolicyEvaluator{ - allowAll: false, - violations: []string{"Cannot modify .env files"}, - decisionID: "decision-456", - policyCount: 1, - } - - enforcer := NewPolicyEnforcer(evaluator, "session-123") - - task := &Task{ - ID: "task-123", - Title: "Add feature", - FilesModified: []string{"internal/app/main.go", "internal/app/handler.go"}, - } - - result := enforcer.Enforce(context.Background(), task, "Implement feature") - - if !result.Allowed { - t.Error("Expected task to be allowed for non-protected files") - } -} - -func TestPolicyEnforcer_DenyEnvFile(t *testing.T) { - evaluator := &mockPolicyEvaluator{ - allowAll: false, - violations: []string{"Cannot modify .env files"}, - decisionID: "decision-789", - policyCount: 1, - } - - enforcer := NewPolicyEnforcer(evaluator, "session-123") - - task := &Task{ - ID: "task-123", - Title: "Update config", - FilesModified: []string{".env"}, - } - - result := enforcer.Enforce(context.Background(), task, "Update configuration") - - if result.Allowed { - t.Error("Expected task to be denied for .env file modification") - } - if len(result.Violations) == 0 { - t.Error("Expected violations to be set") - } - if result.DecisionID != "decision-789" { - t.Errorf("Expected decision ID 'decision-789', got '%s'", result.DecisionID) - } -} - -func TestPolicyEnforcer_DenyEnvLocalFile(t *testing.T) { - evaluator := &mockPolicyEvaluator{ - allowAll: false, - violations: []string{"Cannot modify environment files"}, - decisionID: "decision-001", - policyCount: 1, - } - - enforcer := NewPolicyEnforcer(evaluator, "session-123") - - task := &Task{ - ID: "task-123", - Title: "Update local config", - FilesModified: []string{".env.local"}, - } - - result := enforcer.Enforce(context.Background(), task, "Update local configuration") - - if result.Allowed { - t.Error("Expected task to be denied for .env.local file modification") - } -} - -func TestPolicyEnforcer_DenySecretsDirectory(t *testing.T) { - evaluator := &mockPolicyEvaluator{ - allowAll: false, - violations: []string{"Cannot modify secrets directory"}, - decisionID: "decision-002", - policyCount: 1, - } - - enforcer := NewPolicyEnforcer(evaluator, "session-123") - - task := &Task{ - ID: "task-123", - Title: "Update secrets", - FilesModified: []string{"secrets/api_key.json"}, - } - - result := enforcer.Enforce(context.Background(), task, "Update secrets") - - if result.Allowed { - t.Error("Expected task to be denied for secrets/ directory modification") - } -} - -func TestPolicyEnforcer_EnforceFiles(t *testing.T) { - evaluator := &mockPolicyEvaluator{ - allowAll: false, - violations: []string{"Cannot modify .env files"}, - decisionID: "decision-003", - policyCount: 1, - } - - enforcer := NewPolicyEnforcer(evaluator, "session-123") - - // Test allowed files - result := enforcer.EnforceFiles(context.Background(), []string{"main.go"}, nil) - if !result.Allowed { - t.Error("Expected main.go to be allowed") - } - - // Test denied files - result = enforcer.EnforceFiles(context.Background(), []string{".env.production"}, nil) - if result.Allowed { - t.Error("Expected .env.production to be denied") - } -} - -func TestPolicyEnforcer_HasPolicies(t *testing.T) { - // No evaluator - enforcer := NewPolicyEnforcer(nil, "session-123") - if enforcer.HasPolicies() { - t.Error("HasPolicies should return false with nil evaluator") - } - - // Evaluator with no policies - evaluator := &mockPolicyEvaluator{policyCount: 0} - enforcer = NewPolicyEnforcer(evaluator, "session-123") - if enforcer.HasPolicies() { - t.Error("HasPolicies should return false with 0 policies") - } - - // Evaluator with policies - evaluator = &mockPolicyEvaluator{policyCount: 2} - enforcer = NewPolicyEnforcer(evaluator, "session-123") - if !enforcer.HasPolicies() { - t.Error("HasPolicies should return true with policies loaded") - } -} - -func TestPolicyEnforcer_PolicyCount(t *testing.T) { - enforcer := NewPolicyEnforcer(nil, "session-123") - if enforcer.PolicyCount() != 0 { - t.Errorf("PolicyCount should return 0 with nil evaluator, got %d", enforcer.PolicyCount()) - } - - evaluator := &mockPolicyEvaluator{policyCount: 5} - enforcer = NewPolicyEnforcer(evaluator, "session-123") - if enforcer.PolicyCount() != 5 { - t.Errorf("PolicyCount should return 5, got %d", enforcer.PolicyCount()) - } -} - -func TestPolicyEnforcer_EvaluationError(t *testing.T) { - evaluator := &mockPolicyEvaluator{ - err: context.DeadlineExceeded, - policyCount: 1, - } - - enforcer := NewPolicyEnforcer(evaluator, "session-123") - - task := &Task{ - ID: "task-123", - Title: "Test task", - FilesModified: []string{"main.go"}, - } - - result := enforcer.Enforce(context.Background(), task, "Test goal") - - if result.Allowed { - t.Error("Expected task to be denied on evaluation error") - } - if result.Error == nil { - t.Error("Expected error to be set on evaluation failure") - } -} - -func TestPolicyEnforcementResult_AllowedByDefault(t *testing.T) { - // When no evaluator is configured, tasks should be allowed by default - enforcer := NewPolicyEnforcer(nil, "test-session") - - task := &Task{ - ID: "task-001", - Title: "Dangerous task", - FilesModified: []string{".env", "secrets/key.json", "config/credentials.yaml"}, - } - - result := enforcer.Enforce(context.Background(), task, "Some goal") - - if !result.Allowed { - t.Error("With no policy evaluator, all tasks should be allowed") - } - if result.Error != nil { - t.Errorf("Expected no error, got: %v", result.Error) - } -} diff --git a/internal/task/sentinel_test.go b/internal/task/sentinel_test.go deleted file mode 100644 index 7906592..0000000 --- a/internal/task/sentinel_test.go +++ /dev/null @@ -1,147 +0,0 @@ -package task - -import ( - "testing" -) - -func TestSentinelAnalyze_PerfectMatch(t *testing.T) { - s := NewSentinel() - task := &Task{ - ID: "task-123", - Title: "Add authentication", - ExpectedFiles: []string{"internal/auth/handler.go", "internal/auth/middleware.go"}, - FilesModified: []string{"internal/auth/handler.go", "internal/auth/middleware.go"}, - } - - report := s.Analyze(task) - - if len(report.Deviations) != 0 { - t.Errorf("Expected 0 deviations, got %d", len(report.Deviations)) - } - if report.DeviationRate != 0.0 { - t.Errorf("Expected deviation rate 0.0, got %f", report.DeviationRate) - } - if report.HasDeviations() { - t.Error("Expected HasDeviations() to return false") - } -} - -func TestSentinelAnalyze_DriftDetection(t *testing.T) { - s := NewSentinel() - task := &Task{ - ID: "task-123", - Title: "Add authentication", - ExpectedFiles: []string{"internal/auth/handler.go"}, - FilesModified: []string{"internal/auth/handler.go", "internal/auth/extra.go", "internal/db/schema.go"}, - } - - report := s.Analyze(task) - - driftDeviations := report.GetDeviationsByType(DeviationDrift) - if len(driftDeviations) != 2 { - t.Errorf("Expected 2 drift deviations, got %d", len(driftDeviations)) - } - - if !report.HasDeviations() { - t.Error("Expected HasDeviations() to return true") - } -} - -func TestSentinelAnalyze_MissingDetection(t *testing.T) { - s := NewSentinel() - task := &Task{ - ID: "task-123", - Title: "Add authentication", - ExpectedFiles: []string{"internal/auth/handler.go", "internal/auth/middleware.go"}, - FilesModified: []string{"internal/auth/handler.go"}, - } - - report := s.Analyze(task) - - missingDeviations := report.GetDeviationsByType(DeviationMissing) - if len(missingDeviations) != 1 { - t.Errorf("Expected 1 missing deviation, got %d", len(missingDeviations)) - } - if missingDeviations[0].File != "internal/auth/middleware.go" { - t.Errorf("Expected missing file to be middleware.go, got %s", missingDeviations[0].File) - } -} - -func TestSentinelAnalyze_HighRiskFile(t *testing.T) { - s := NewSentinel() - task := &Task{ - ID: "task-123", - Title: "Update config", - ExpectedFiles: []string{}, - FilesModified: []string{"config/secrets.yaml"}, - } - - report := s.Analyze(task) - - if !report.HasCriticalDeviations() { - t.Error("Expected HasCriticalDeviations() to return true for secrets file") - } - - driftDeviations := report.GetDeviationsByType(DeviationDrift) - if len(driftDeviations) != 1 { - t.Fatalf("Expected 1 drift deviation, got %d", len(driftDeviations)) - } - if driftDeviations[0].Severity != SeverityError { - t.Errorf("Expected severity Error for secrets file, got %s", driftDeviations[0].Severity) - } -} - -func TestSentinelAnalyze_EmptyExpected(t *testing.T) { - s := NewSentinel() - task := &Task{ - ID: "task-123", - Title: "Research task", - ExpectedFiles: []string{}, - FilesModified: []string{"notes.md"}, - } - - report := s.Analyze(task) - - // Files modified with no expected = deviation rate of 1.0 - if report.DeviationRate != 1.0 { - t.Errorf("Expected deviation rate 1.0, got %f", report.DeviationRate) - } -} - -func TestSentinelAnalyze_NoFiles(t *testing.T) { - s := NewSentinel() - task := &Task{ - ID: "task-123", - Title: "Review task", - ExpectedFiles: []string{}, - FilesModified: []string{}, - } - - report := s.Analyze(task) - - if report.DeviationRate != 0.0 { - t.Errorf("Expected deviation rate 0.0, got %f", report.DeviationRate) - } - if report.HasDeviations() { - t.Error("Expected no deviations for empty task") - } -} - -func TestNormalizePath(t *testing.T) { - tests := []struct { - input string - expected string - }{ - {"./internal/auth/handler.go", "internal/auth/handler.go"}, - {"internal/auth/handler.go", "internal/auth/handler.go"}, - {"internal//auth//handler.go", "internal/auth/handler.go"}, - {"./internal/../internal/auth/handler.go", "internal/auth/handler.go"}, - } - - for _, tt := range tests { - result := normalizePath(tt.input) - if result != tt.expected { - t.Errorf("normalizePath(%q) = %q, expected %q", tt.input, result, tt.expected) - } - } -} diff --git a/internal/telemetry/client.go b/internal/telemetry/client.go index d1e76b9..64cff45 100644 --- a/internal/telemetry/client.go +++ b/internal/telemetry/client.go @@ -94,16 +94,6 @@ func NewPostHogClient(cfg ClientConfig) (*PostHogClient, error) { }, nil } -// newPostHogClientWithEnqueuer creates a client with a custom enqueuer (for testing). -func newPostHogClientWithEnqueuer(enq enqueuer, cfg *Config, version string) *PostHogClient { - return &PostHogClient{ - client: enq, - config: cfg, - version: version, - initialized: true, - } -} - // Track sends an event asynchronously. // Returns immediately without blocking the CLI. // No-op if telemetry is disabled or client is not initialized. diff --git a/internal/telemetry/client_test.go b/internal/telemetry/client_test.go deleted file mode 100644 index e1e9cc8..0000000 --- a/internal/telemetry/client_test.go +++ /dev/null @@ -1,312 +0,0 @@ -package telemetry - -import ( - "runtime" - "sync" - "testing" - "time" - - "github.com/posthog/posthog-go" -) - -// mockEnqueuer captures events for testing. -type mockEnqueuer struct { - mu sync.Mutex - events []posthog.Capture - closed bool -} - -func (m *mockEnqueuer) Enqueue(msg posthog.Message) error { - m.mu.Lock() - defer m.mu.Unlock() - - if capture, ok := msg.(posthog.Capture); ok { - m.events = append(m.events, capture) - } - return nil -} - -func (m *mockEnqueuer) Close() error { - m.mu.Lock() - defer m.mu.Unlock() - m.closed = true - return nil -} - -func (m *mockEnqueuer) getEvents() []posthog.Capture { - m.mu.Lock() - defer m.mu.Unlock() - result := make([]posthog.Capture, len(m.events)) - copy(result, m.events) - return result -} - -func (m *mockEnqueuer) isClosed() bool { - m.mu.Lock() - defer m.mu.Unlock() - return m.closed -} - -// newTestClient creates a PostHogClient with a mock enqueuer for testing. -func newTestClient(cfg *Config, version string) (*PostHogClient, *mockEnqueuer) { - mock := &mockEnqueuer{} - client := newPostHogClientWithEnqueuer(mock, cfg, version) - return client, mock -} - -func TestPostHogClient_Track_WhenEnabled(t *testing.T) { - cfg := &Config{ - Enabled: true, - ConsentAsked: true, - AnonymousID: "test-anon-id-123", - } - - client, mock := newTestClient(cfg, "1.2.3") - - // Track an event - client.Track("command_executed", Properties{ - "command": "bootstrap", - "success": true, - "duration": 1500, - }) - - // Verify event was captured - events := mock.getEvents() - if len(events) != 1 { - t.Fatalf("expected 1 event, got %d", len(events)) - } - - event := events[0] - - // Check event name - if event.Event != "command_executed" { - t.Errorf("event name = %q, want %q", event.Event, "command_executed") - } - - // Check distinct ID is anonymous ID - if event.DistinctId != "test-anon-id-123" { - t.Errorf("distinct_id = %q, want %q", event.DistinctId, "test-anon-id-123") - } - - // Check custom properties - if event.Properties["command"] != "bootstrap" { - t.Errorf("command = %v, want %q", event.Properties["command"], "bootstrap") - } - if event.Properties["success"] != true { - t.Errorf("success = %v, want true", event.Properties["success"]) - } - if event.Properties["duration"] != 1500 { - t.Errorf("duration = %v, want 1500", event.Properties["duration"]) - } - - // Check standard properties are added - if event.Properties["os"] != runtime.GOOS { - t.Errorf("os = %v, want %q", event.Properties["os"], runtime.GOOS) - } - if event.Properties["arch"] != runtime.GOARCH { - t.Errorf("arch = %v, want %q", event.Properties["arch"], runtime.GOARCH) - } - if event.Properties["cli_version"] != "1.2.3" { - t.Errorf("cli_version = %v, want %q", event.Properties["cli_version"], "1.2.3") - } -} - -func TestPostHogClient_Track_WhenDisabled(t *testing.T) { - cfg := &Config{ - Enabled: false, // Disabled - ConsentAsked: true, - AnonymousID: "test-anon-id-123", - } - - client, mock := newTestClient(cfg, "1.2.3") - - // Track an event - client.Track("command_executed", Properties{ - "command": "bootstrap", - }) - - // Verify no events were captured - events := mock.getEvents() - if len(events) != 0 { - t.Errorf("expected 0 events when disabled, got %d", len(events)) - } -} - -func TestPostHogClient_Track_NotInitialized(t *testing.T) { - client := &PostHogClient{ - config: &Config{Enabled: true}, - initialized: false, // Not initialized - } - - // This should not panic - client.Track("test_event", nil) -} - -func TestPostHogClient_Track_NilConfig(t *testing.T) { - mock := &mockEnqueuer{} - client := &PostHogClient{ - client: mock, - config: nil, // Nil config - initialized: true, - } - - // This should not panic and should be a no-op - client.Track("test_event", nil) - - events := mock.getEvents() - if len(events) != 0 { - t.Errorf("expected 0 events with nil config, got %d", len(events)) - } -} - -func TestPostHogClient_Track_NilProperties(t *testing.T) { - cfg := &Config{ - Enabled: true, - ConsentAsked: true, - AnonymousID: "test-anon-id", - } - - client, mock := newTestClient(cfg, "1.0.0") - - // Track with nil properties - client.Track("simple_event", nil) - - events := mock.getEvents() - if len(events) != 1 { - t.Fatalf("expected 1 event, got %d", len(events)) - } - - // Standard properties should still be added - event := events[0] - if event.Properties["os"] != runtime.GOOS { - t.Errorf("os should be set even with nil properties") - } -} - -func TestPostHogClient_Close(t *testing.T) { - cfg := &Config{ - Enabled: true, - ConsentAsked: true, - AnonymousID: "test-anon-id", - } - - client, mock := newTestClient(cfg, "1.0.0") - - if err := client.Close(); err != nil { - t.Errorf("Close() error = %v", err) - } - - if !mock.isClosed() { - t.Error("underlying client should be closed") - } -} - -func TestPostHogClient_Close_NotInitialized(t *testing.T) { - client := &PostHogClient{ - initialized: false, - } - - // Should not error - if err := client.Close(); err != nil { - t.Errorf("Close() error = %v", err) - } -} - -func TestNoopClient(t *testing.T) { - client := NewNoopClient() - - // Track should not panic - client.Track("event", Properties{"key": "value"}) - - // Close should not error - if err := client.Close(); err != nil { - t.Errorf("NoopClient.Close() error = %v", err) - } -} - -func TestNewPostHogClient_EmptyAPIKey(t *testing.T) { - client, err := NewPostHogClient(ClientConfig{ - APIKey: "", // Empty - Version: "1.0.0", - Config: &Config{Enabled: true}, - }) - - if err != nil { - t.Errorf("should not error with empty API key, got %v", err) - } - - if client.initialized { - t.Error("should not be initialized with empty API key") - } - - // Track should be a no-op, not panic - client.Track("event", nil) -} - -func TestNewPostHogClient_NilConfig(t *testing.T) { - client, err := NewPostHogClient(ClientConfig{ - APIKey: "test-key", - Version: "1.0.0", - Config: nil, // Nil config - }) - - if err != nil { - t.Errorf("should not error with nil config, got %v", err) - } - - if client.initialized { - t.Error("should not be initialized with nil config") - } -} - -func TestPostHogClient_Track_Concurrent(t *testing.T) { - cfg := &Config{ - Enabled: true, - ConsentAsked: true, - AnonymousID: "test-anon-id", - } - - client, mock := newTestClient(cfg, "1.0.0") - - // Track concurrently - var wg sync.WaitGroup - for i := 0; i < 100; i++ { - wg.Add(1) - go func(n int) { - defer wg.Done() - client.Track("concurrent_event", Properties{"iteration": n}) - }(i) - } - wg.Wait() - - events := mock.getEvents() - if len(events) != 100 { - t.Errorf("expected 100 events, got %d", len(events)) - } -} - -func TestPostHogClient_Track_ReturnsImmediately(t *testing.T) { - cfg := &Config{ - Enabled: true, - ConsentAsked: true, - AnonymousID: "test-anon-id", - } - - client, _ := newTestClient(cfg, "1.0.0") - - // Track should return immediately (non-blocking) - // This is a basic smoke test - the actual async behavior is handled by PostHog SDK - done := make(chan bool, 1) - go func() { - client.Track("test_event", nil) - done <- true - }() - - // Give goroutine time to complete - Track should be nearly instant - select { - case <-done: - // Success - returned quickly - case <-time.After(100 * time.Millisecond): - t.Error("Track() should return immediately (within 100ms)") - } -} diff --git a/internal/telemetry/config_test.go b/internal/telemetry/config_test.go deleted file mode 100644 index 39dcac7..0000000 --- a/internal/telemetry/config_test.go +++ /dev/null @@ -1,300 +0,0 @@ -package telemetry - -import ( - "encoding/json" - "os" - "path/filepath" - "testing" -) - -func TestLoad_NewConfig(t *testing.T) { - // Use temp directory for test isolation - tmpDir := t.TempDir() - SetConfigDir(tmpDir) - defer SetConfigDir("") - - cfg, err := Load() - if err != nil { - t.Fatalf("Load() error = %v", err) - } - - // Should return defaults - if cfg.Enabled { - t.Error("new config should have Enabled = false") - } - if cfg.ConsentAsked { - t.Error("new config should have ConsentAsked = false") - } - if cfg.AnonymousID == "" { - t.Error("new config should have generated AnonymousID") - } - - // UUID should be valid format (36 chars with hyphens) - if len(cfg.AnonymousID) != 36 { - t.Errorf("AnonymousID should be UUID format, got length %d", len(cfg.AnonymousID)) - } -} - -func TestSave_CreatesFile(t *testing.T) { - tmpDir := t.TempDir() - SetConfigDir(tmpDir) - defer SetConfigDir("") - - cfg := &Config{ - Enabled: true, - ConsentAsked: true, - AnonymousID: "test-uuid-1234", - } - - if err := cfg.Save(); err != nil { - t.Fatalf("Save() error = %v", err) - } - - // Verify file exists - configPath := filepath.Join(tmpDir, ConfigFileName) - if _, err := os.Stat(configPath); os.IsNotExist(err) { - t.Error("config file was not created") - } - - // Verify file permissions - info, err := os.Stat(configPath) - if err != nil { - t.Fatalf("Stat() error = %v", err) - } - // On Unix, check that permissions are 0600 - // Note: Windows doesn't support Unix permissions the same way - if info.Mode().Perm() != 0600 { - t.Errorf("file permissions = %o, want 0600", info.Mode().Perm()) - } - - // Verify content - data, err := os.ReadFile(configPath) - if err != nil { - t.Fatalf("ReadFile() error = %v", err) - } - - var loaded Config - if err := json.Unmarshal(data, &loaded); err != nil { - t.Fatalf("Unmarshal() error = %v", err) - } - - if loaded.Enabled != cfg.Enabled { - t.Errorf("Enabled = %v, want %v", loaded.Enabled, cfg.Enabled) - } - if loaded.ConsentAsked != cfg.ConsentAsked { - t.Errorf("ConsentAsked = %v, want %v", loaded.ConsentAsked, cfg.ConsentAsked) - } - if loaded.AnonymousID != cfg.AnonymousID { - t.Errorf("AnonymousID = %v, want %v", loaded.AnonymousID, cfg.AnonymousID) - } -} - -func TestLoad_ExistingConfig(t *testing.T) { - tmpDir := t.TempDir() - SetConfigDir(tmpDir) - defer SetConfigDir("") - - // Create existing config - existing := Config{ - Enabled: true, - ConsentAsked: true, - AnonymousID: "existing-uuid-5678", - } - data, _ := json.Marshal(existing) - configPath := filepath.Join(tmpDir, ConfigFileName) - if err := os.WriteFile(configPath, data, 0600); err != nil { - t.Fatalf("WriteFile() error = %v", err) - } - - // Load should read existing - cfg, err := Load() - if err != nil { - t.Fatalf("Load() error = %v", err) - } - - if cfg.Enabled != existing.Enabled { - t.Errorf("Enabled = %v, want %v", cfg.Enabled, existing.Enabled) - } - if cfg.ConsentAsked != existing.ConsentAsked { - t.Errorf("ConsentAsked = %v, want %v", cfg.ConsentAsked, existing.ConsentAsked) - } - if cfg.AnonymousID != existing.AnonymousID { - t.Errorf("AnonymousID = %v, want %v", cfg.AnonymousID, existing.AnonymousID) - } -} - -func TestLoad_GeneratesUUID_WhenMissing(t *testing.T) { - tmpDir := t.TempDir() - SetConfigDir(tmpDir) - defer SetConfigDir("") - - // Create config without anonymous ID - existing := Config{ - Enabled: true, - ConsentAsked: true, - AnonymousID: "", // Missing - } - data, _ := json.Marshal(existing) - configPath := filepath.Join(tmpDir, ConfigFileName) - if err := os.WriteFile(configPath, data, 0600); err != nil { - t.Fatalf("WriteFile() error = %v", err) - } - - cfg, err := Load() - if err != nil { - t.Fatalf("Load() error = %v", err) - } - - // Should have generated a UUID - if cfg.AnonymousID == "" { - t.Error("should have generated AnonymousID when missing") - } - if len(cfg.AnonymousID) != 36 { - t.Errorf("AnonymousID should be UUID format, got length %d", len(cfg.AnonymousID)) - } -} - -func TestConfig_Enable(t *testing.T) { - cfg := &Config{ - Enabled: false, - ConsentAsked: false, - } - - cfg.Enable() - - if !cfg.Enabled { - t.Error("Enable() should set Enabled = true") - } - if !cfg.ConsentAsked { - t.Error("Enable() should set ConsentAsked = true") - } -} - -func TestConfig_Disable(t *testing.T) { - cfg := &Config{ - Enabled: true, - ConsentAsked: false, - } - - cfg.Disable() - - if cfg.Enabled { - t.Error("Disable() should set Enabled = false") - } - if !cfg.ConsentAsked { - t.Error("Disable() should set ConsentAsked = true") - } -} - -func TestConfig_NeedsConsent(t *testing.T) { - tests := []struct { - name string - consentAsked bool - want bool - }{ - {"needs consent when not asked", false, true}, - {"no consent needed when already asked", true, false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cfg := &Config{ConsentAsked: tt.consentAsked} - if got := cfg.NeedsConsent(); got != tt.want { - t.Errorf("NeedsConsent() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestConfig_IsEnabled(t *testing.T) { - tests := []struct { - name string - enabled bool - want bool - }{ - {"returns true when enabled", true, true}, - {"returns false when disabled", false, false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cfg := &Config{Enabled: tt.enabled} - if got := cfg.IsEnabled(); got != tt.want { - t.Errorf("IsEnabled() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestSave_CreatesDirectory(t *testing.T) { - tmpDir := t.TempDir() - // Use a subdirectory that doesn't exist yet - nestedDir := filepath.Join(tmpDir, "nested", "config") - SetConfigDir(nestedDir) - defer SetConfigDir("") - - cfg := &Config{ - Enabled: true, - ConsentAsked: true, - AnonymousID: "test-uuid", - } - - if err := cfg.Save(); err != nil { - t.Fatalf("Save() error = %v", err) - } - - // Directory should have been created - if _, err := os.Stat(nestedDir); os.IsNotExist(err) { - t.Error("Save() should create nested directories") - } -} - -func TestGetConfigPath(t *testing.T) { - tmpDir := t.TempDir() - SetConfigDir(tmpDir) - defer SetConfigDir("") - - path, err := GetConfigPath() - if err != nil { - t.Fatalf("GetConfigPath() error = %v", err) - } - - expected := filepath.Join(tmpDir, ConfigFileName) - if path != expected { - t.Errorf("GetConfigPath() = %v, want %v", path, expected) - } -} - -func TestRoundTrip(t *testing.T) { - tmpDir := t.TempDir() - SetConfigDir(tmpDir) - defer SetConfigDir("") - - // Create and save - original := &Config{ - Enabled: true, - ConsentAsked: true, - AnonymousID: "roundtrip-uuid-9999", - } - - if err := original.Save(); err != nil { - t.Fatalf("Save() error = %v", err) - } - - // Load back - loaded, err := Load() - if err != nil { - t.Fatalf("Load() error = %v", err) - } - - // Verify all fields match - if loaded.Enabled != original.Enabled { - t.Errorf("Enabled = %v, want %v", loaded.Enabled, original.Enabled) - } - if loaded.ConsentAsked != original.ConsentAsked { - t.Errorf("ConsentAsked = %v, want %v", loaded.ConsentAsked, original.ConsentAsked) - } - if loaded.AnonymousID != original.AnonymousID { - t.Errorf("AnonymousID = %v, want %v", loaded.AnonymousID, original.AnonymousID) - } -} diff --git a/internal/ui/context_view.go b/internal/ui/context_view.go index b5e4aed..e45eacf 100644 --- a/internal/ui/context_view.go +++ b/internal/ui/context_view.go @@ -7,6 +7,7 @@ import ( "github.com/charmbracelet/lipgloss" "github.com/josephgoksu/TaskWing/internal/app" "github.com/josephgoksu/TaskWing/internal/knowledge" + "github.com/josephgoksu/TaskWing/internal/memory" ) const ( @@ -304,6 +305,101 @@ func getContentWithoutSummary(content, summary string) string { return content } +// RenderAskResult displays a complete AskResult from the ask pipeline. +// This is the primary rendering function for the `taskwing ask` command. +func RenderAskResult(result *app.AskResult, verbose bool) { + titleStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("205")).Bold(true) + sectionStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("141")).Bold(true) + metaStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("241")) + + // Title + if result.Answer != "" { + fmt.Println() + fmt.Println(titleStyle.Render(fmt.Sprintf("📖 %s", result.Query))) + } else { + fmt.Println(titleStyle.Render(fmt.Sprintf("🔍 Results for: \"%s\"", result.Query))) + } + + // Pipeline info + fmt.Println(metaStyle.Render(fmt.Sprintf(" Pipeline: %s", result.Pipeline))) + if result.RewrittenQuery != "" { + fmt.Println(metaStyle.Render(fmt.Sprintf(" Rewritten: %s", result.RewrittenQuery))) + } + + // Warning + if result.Warning != "" { + fmt.Println() + fmt.Println(RenderWarningPanel("Warning", result.Warning)) + } + + // Answer (only render if not already streamed — streaming writes directly to stdout) + if result.Answer != "" { + fmt.Println() + fmt.Println(RenderInfoPanel("Answer", result.Answer)) + } + + // Knowledge results + if len(result.Results) > 0 { + fmt.Println() + fmt.Println(sectionStyle.Render("📚 Knowledge")) + + // Convert NodeResponse to ScoredNode for the existing panel renderer + scored := nodeResponsesToScoredNodes(result.Results) + + var maxScore float32 = 0.01 + for _, s := range scored { + if s.Score > maxScore { + maxScore = s.Score + } + } + + for i, s := range scored { + renderScoredNodePanel(i+1, s, maxScore, verbose) + } + } + + // Code symbols + if len(result.Symbols) > 0 { + fmt.Println() + fmt.Println(sectionStyle.Render("💻 Code Symbols")) + + for i, sym := range result.Symbols { + renderSymbolPanel(i+1, sym, verbose) + } + } + + // No results + if len(result.Results) == 0 && len(result.Symbols) == 0 && result.Answer == "" { + fmt.Println() + fmt.Println(metaStyle.Render(" No results found. Try a different query or run 'taskwing bootstrap' to populate memory.")) + } + + // Summary line + if result.Total > 0 || result.TotalSymbols > 0 { + fmt.Println() + fmt.Println(metaStyle.Render(fmt.Sprintf(" %d knowledge result(s), %d symbol(s)", result.Total, result.TotalSymbols))) + } +} + +// nodeResponsesToScoredNodes converts NodeResponse slice to ScoredNode slice +// for reuse with the existing renderScoredNodePanel renderer. +func nodeResponsesToScoredNodes(responses []knowledge.NodeResponse) []knowledge.ScoredNode { + scored := make([]knowledge.ScoredNode, len(responses)) + for i, r := range responses { + scored[i] = knowledge.ScoredNode{ + Node: &memory.Node{ + ID: r.ID, + Type: r.Type, + Summary: r.Summary, + Content: r.Content, + SourceAgent: "", // Not available in NodeResponse + }, + Score: r.MatchScore, + } + } + return scored +} + // symbolKindIcon returns an icon for a symbol kind. func symbolKindIcon(kind string) string { switch kind { diff --git a/internal/ui/utils_test.go b/internal/ui/utils_test.go deleted file mode 100644 index 728f174..0000000 --- a/internal/ui/utils_test.go +++ /dev/null @@ -1,153 +0,0 @@ -package ui - -import ( - "strings" - "testing" -) - -func TestTruncate(t *testing.T) { - tests := []struct { - name string - input string - maxLen int - expected string - }{ - {"empty", "", 10, ""}, - {"short string", "hello", 10, "hello"}, - {"exact length", "hello", 5, "hello"}, - {"needs truncation", "hello world", 8, "hello..."}, - {"very short max", "hello", 3, "hel"}, - {"zero max", "hello", 0, "hello"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := Truncate(tt.input, tt.maxLen) - if result != tt.expected { - t.Errorf("Truncate(%q, %d) = %q, want %q", tt.input, tt.maxLen, result, tt.expected) - } - }) - } -} - -func TestWrapText(t *testing.T) { - tests := []struct { - name string - input string - width int - contains []string - }{ - {"short text", "hello world", 20, []string{"hello world"}}, - {"needs wrap", "hello world foo bar", 10, []string{"hello", "world", "foo", "bar"}}, - {"zero width", "hello", 0, []string{"hello"}}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := WrapText(tt.input, tt.width) - for _, substr := range tt.contains { - if !strings.Contains(result, substr) { - t.Errorf("WrapText(%q, %d) = %q, expected to contain %q", tt.input, tt.width, result, substr) - } - } - }) - } -} - -func TestPanel(t *testing.T) { - t.Run("basic panel", func(t *testing.T) { - panel := NewPanel("Title", "Content") - result := panel.Render() - - if !strings.Contains(result, "Title") { - t.Error("Panel should contain title") - } - if !strings.Contains(result, "Content") { - t.Error("Panel should contain content") - } - }) - - t.Run("panel without title", func(t *testing.T) { - panel := NewPanel("", "Content only") - result := panel.Render() - - if !strings.Contains(result, "Content only") { - t.Error("Panel should contain content") - } - }) - - t.Run("panel with custom color", func(t *testing.T) { - panel := NewPanel("Info", "Details").WithBorderColor(ColorCyan) - result := panel.Render() - - if !strings.Contains(result, "Info") { - t.Error("Panel should contain title") - } - }) - - t.Run("convenience functions", func(t *testing.T) { - info := RenderInfoPanel("Info", "content") - success := RenderSuccessPanel("Success", "content") - errPanel := RenderErrorPanel("Error", "content") - warning := RenderWarningPanel("Warning", "content") - - if !strings.Contains(info, "Info") { - t.Error("Info panel should contain title") - } - if !strings.Contains(success, "Success") { - t.Error("Success panel should contain title") - } - if !strings.Contains(errPanel, "Error") { - t.Error("Error panel should contain title") - } - if !strings.Contains(warning, "Warning") { - t.Error("Warning panel should contain title") - } - }) -} - -func TestTable(t *testing.T) { - t.Run("basic table", func(t *testing.T) { - table := Table{ - Headers: []string{"ID", "Name", "Status"}, - Rows: [][]string{ - {"1", "Task A", "pending"}, - {"2", "Task B", "done"}, - }, - } - - result := table.Render() - - if !strings.Contains(result, "ID") { - t.Error("Table should contain ID header") - } - if !strings.Contains(result, "Task A") { - t.Error("Table should contain Task A") - } - if !strings.Contains(result, "done") { - t.Error("Table should contain done status") - } - }) - - t.Run("empty table", func(t *testing.T) { - table := Table{} - result := table.Render() - - if result != "" { - t.Error("Empty table should render empty string") - } - }) - - t.Run("with max width", func(t *testing.T) { - table := Table{ - Headers: []string{"Description"}, - Rows: [][]string{{"This is a very long description that should be truncated"}}, - MaxWidth: 20, - } - - widths := table.ColumnWidths() - if widths[0] > 20 { - t.Errorf("Column width should be <= 20, got %d", widths[0]) - } - }) -} diff --git a/internal/util/id_test.go b/internal/util/id_test.go deleted file mode 100644 index 6316abb..0000000 --- a/internal/util/id_test.go +++ /dev/null @@ -1,336 +0,0 @@ -package util - -import ( - "context" - "errors" - "testing" -) - -func TestShortID(t *testing.T) { - tests := []struct { - name string - id string - n int - want string - }{ - { - name: "default length truncates", - id: "task-abcdef12", - n: 0, - want: "task-abc", - }, - { - name: "negative uses default", - id: "task-abcdef12", - n: -1, - want: "task-abc", - }, - { - name: "explicit length 10", - id: "task-abcdef12", - n: 10, - want: "task-abcde", - }, - { - name: "length equals ID", - id: "task-abc", - n: 8, - want: "task-abc", - }, - { - name: "length longer than ID", - id: "task-abc", - n: 20, - want: "task-abc", - }, - { - name: "plan ID", - id: "plan-xyz12345", - n: 8, - want: "plan-xyz", - }, - { - name: "empty ID", - id: "", - n: 8, - want: "", - }, - { - name: "very short", - id: "ab", - n: 8, - want: "ab", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - got := ShortID(tc.id, tc.n) - if got != tc.want { - t.Errorf("ShortID(%q, %d) = %q, want %q", tc.id, tc.n, got, tc.want) - } - }) - } -} - -// mockResolver implements IDPrefixResolver for testing. -type mockResolver struct { - taskIDs []string - planIDs []string - err error -} - -func (m *mockResolver) FindTaskIDsByPrefix(_ context.Context, prefix string) ([]string, error) { - if m.err != nil { - return nil, m.err - } - var matches []string - for _, id := range m.taskIDs { - if len(id) >= len(prefix) && id[:len(prefix)] == prefix { - matches = append(matches, id) - } - } - return matches, nil -} - -func (m *mockResolver) FindPlanIDsByPrefix(_ context.Context, prefix string) ([]string, error) { - if m.err != nil { - return nil, m.err - } - var matches []string - for _, id := range m.planIDs { - if len(id) >= len(prefix) && id[:len(prefix)] == prefix { - matches = append(matches, id) - } - } - return matches, nil -} - -func TestResolveTaskID(t *testing.T) { - ctx := context.Background() - - tests := []struct { - name string - resolver *mockResolver - idOrPrefix string - want string - wantErr error - }{ - { - name: "full ID exact match", - resolver: &mockResolver{ - taskIDs: []string{"task-abcdef12", "task-xyz12345"}, - }, - idOrPrefix: "task-abcdef12", - want: "task-abcdef12", - }, - { - name: "prefix matches one", - resolver: &mockResolver{ - taskIDs: []string{"task-abcdef12", "task-xyz12345"}, - }, - idOrPrefix: "task-abc", - want: "task-abcdef12", - }, - { - name: "prefix without task- prepended", - resolver: &mockResolver{ - taskIDs: []string{"task-abcdef12", "task-xyz12345"}, - }, - idOrPrefix: "abc", - want: "task-abcdef12", - }, - { - name: "prefix matches multiple - ambiguous", - resolver: &mockResolver{ - taskIDs: []string{"task-abc11111", "task-abc22222", "task-abc33333"}, - }, - idOrPrefix: "task-abc", - wantErr: ErrAmbiguousID, - }, - { - name: "prefix matches none - not found", - resolver: &mockResolver{ - taskIDs: []string{"task-abcdef12"}, - }, - idOrPrefix: "task-xyz", - wantErr: ErrNotFound, - }, - { - name: "empty ID", - resolver: &mockResolver{}, - idOrPrefix: "", - wantErr: ErrNotFound, - }, - { - name: "resolver error", - resolver: &mockResolver{ - err: errors.New("database error"), - }, - idOrPrefix: "task-abc", - wantErr: errors.New("database error"), - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - got, err := ResolveTaskID(ctx, tc.resolver, tc.idOrPrefix) - - if tc.wantErr != nil { - if err == nil { - t.Fatalf("expected error containing %v, got nil", tc.wantErr) - } - if !errors.Is(err, tc.wantErr) && !containsError(err, tc.wantErr) { - t.Errorf("error = %v, want %v", err, tc.wantErr) - } - return - } - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got != tc.want { - t.Errorf("ResolveTaskID() = %q, want %q", got, tc.want) - } - }) - } -} - -func TestResolvePlanID(t *testing.T) { - ctx := context.Background() - - tests := []struct { - name string - resolver *mockResolver - idOrPrefix string - want string - wantErr error - }{ - { - name: "full ID exact match", - resolver: &mockResolver{ - planIDs: []string{"plan-abcdef12", "plan-xyz12345"}, - }, - idOrPrefix: "plan-abcdef12", - want: "plan-abcdef12", - }, - { - name: "prefix matches one", - resolver: &mockResolver{ - planIDs: []string{"plan-abcdef12", "plan-xyz12345"}, - }, - idOrPrefix: "plan-abc", - want: "plan-abcdef12", - }, - { - name: "prefix without plan- prepended", - resolver: &mockResolver{ - planIDs: []string{"plan-abcdef12", "plan-xyz12345"}, - }, - idOrPrefix: "abc", - want: "plan-abcdef12", - }, - { - name: "prefix matches multiple - ambiguous", - resolver: &mockResolver{ - planIDs: []string{"plan-abc11111", "plan-abc22222"}, - }, - idOrPrefix: "plan-abc", - wantErr: ErrAmbiguousID, - }, - { - name: "prefix matches none - not found", - resolver: &mockResolver{ - planIDs: []string{"plan-abcdef12"}, - }, - idOrPrefix: "plan-xyz", - wantErr: ErrNotFound, - }, - { - name: "empty ID", - resolver: &mockResolver{}, - idOrPrefix: "", - wantErr: ErrNotFound, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - got, err := ResolvePlanID(ctx, tc.resolver, tc.idOrPrefix) - - if tc.wantErr != nil { - if err == nil { - t.Fatalf("expected error containing %v, got nil", tc.wantErr) - } - if !errors.Is(err, tc.wantErr) && !containsError(err, tc.wantErr) { - t.Errorf("error = %v, want %v", err, tc.wantErr) - } - return - } - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got != tc.want { - t.Errorf("ResolvePlanID() = %q, want %q", got, tc.want) - } - }) - } -} - -// containsError checks if err contains the target error message. -func containsError(err, target error) bool { - if err == nil || target == nil { - return false - } - return err.Error() == target.Error() || - len(err.Error()) > len(target.Error()) && - err.Error()[len(err.Error())-len(target.Error()):] == target.Error() -} - -func TestAmbiguousErrorMessage(t *testing.T) { - ctx := context.Background() - resolver := &mockResolver{ - taskIDs: []string{ - "task-aaa11111", - "task-aaa22222", - "task-aaa33333", - "task-aaa44444", - "task-aaa55555", - "task-aaa66666", // 6th one, should be truncated - }, - } - - _, err := ResolveTaskID(ctx, resolver, "task-aaa") - if err == nil { - t.Fatal("expected error") - } - - if !errors.Is(err, ErrAmbiguousID) { - t.Errorf("expected ErrAmbiguousID, got: %v", err) - } - - // Should mention 6 matches - errStr := err.Error() - if !contains(errStr, "6 tasks") { - t.Errorf("error should mention 6 matches: %s", errStr) - } - - // Should only show first 5 candidates (MaxAmbiguousCandidates) - if contains(errStr, "task-aaa66666") { - t.Errorf("error should not show 6th candidate: %s", errStr) - } -} - -func contains(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && findSubstr(s, substr)) -} - -func findSubstr(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} diff --git a/internal/utils/json_test.go b/internal/utils/json_test.go deleted file mode 100644 index 96493b0..0000000 --- a/internal/utils/json_test.go +++ /dev/null @@ -1,469 +0,0 @@ -package utils - -import ( - "strings" - "testing" -) - -// TestExtractAndParseJSON_InvalidEscapeSequences tests JSON parsing with invalid -// escape sequences that LLMs commonly produce (e.g., \c, \s, \d from regex patterns). -// This is a regression test for the "invalid character 'c' in string escape code" error. -func TestExtractAndParseJSON_InvalidEscapeSequences(t *testing.T) { - type TestResult struct { - Name string `json:"name"` - Pattern string `json:"pattern"` - Description string `json:"description"` - } - - tests := []struct { - name string - input string - wantErr bool - errContains string - }{ - { - name: "valid JSON", - input: `{"name": "test", "pattern": "foo", "description": "bar"}`, - wantErr: false, - }, - { - name: "regex pattern with backslash-s", - input: `{"name": "regex", "pattern": "^\s+match\s*$", "description": "whitespace"}`, - wantErr: false, - }, - { - name: "regex pattern with backslash-d", - input: `{"name": "digits", "pattern": "\d+", "description": "numbers"}`, - wantErr: false, - }, - { - name: "regex pattern with backslash-w", - input: `{"name": "word", "pattern": "\w+", "description": "word chars"}`, - wantErr: false, - }, - { - name: "regex pattern with backslash-c (the specific failing case)", - input: `{"name": "ctrl", "pattern": "\c", "description": "control char"}`, - wantErr: false, - }, - { - name: "multiple invalid escapes", - input: `{"name": "complex", "pattern": "\s\d\w\c\x", "description": "mixed"}`, - wantErr: false, - }, - { - name: "Windows path with backslash-C (common in file paths)", - input: `{"name": "path", "pattern": "C:\code\project", "description": "Windows path"}`, - wantErr: false, - }, - { - name: "Windows path with lowercase", - input: `{"name": "path", "pattern": "c:\code\project", "description": "lowercase drive"}`, - wantErr: false, - }, - { - name: "JSON embedded in markdown code block", - input: "```json\n{\"name\": \"test\", \"pattern\": \"\\s+\", \"description\": \"wrapped\"}\n```", - wantErr: false, - }, - { - name: "LLM response with explanation before JSON", - input: "Here's the analysis:\n\n{\"name\": \"test\", \"pattern\": \"\\d+\", \"description\": \"with prefix\"}", - wantErr: false, - }, - { - name: "nested invalid escapes in code snippet", - input: `{"name": "code", "pattern": "func match(s string) bool {\n\treturn regexp.MustCompile(` + "`" + `\s+` + "`" + `).MatchString(s)\n}", "description": "code with regex"}`, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := ExtractAndParseJSON[TestResult](tt.input) - - if tt.wantErr { - if err == nil { - t.Errorf("ExtractAndParseJSON() expected error containing %q, got nil", tt.errContains) - return - } - if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { - t.Errorf("ExtractAndParseJSON() error = %v, want error containing %q", err, tt.errContains) - } - return - } - - if err != nil { - t.Errorf("ExtractAndParseJSON() unexpected error: %v", err) - return - } - - // Basic validation that parsing worked - if result.Name == "" { - t.Error("ExtractAndParseJSON() result.Name is empty, expected non-empty") - } - }) - } -} - -// TestSanitizeControlChars_InvalidEscapes specifically tests the sanitization -// of invalid JSON escape sequences. -func TestSanitizeControlChars_InvalidEscapes(t *testing.T) { - tests := []struct { - name string - input string - want string - }{ - { - name: "backslash-c inside string", - input: `{"key": "value\c"}`, - want: `{"key": "value\\c"}`, - }, - { - name: "backslash-s inside string", - input: `{"key": "\s+"}`, - want: `{"key": "\\s+"}`, - }, - { - name: "backslash-d inside string", - input: `{"key": "\d{3}"}`, - want: `{"key": "\\d{3}"}`, - }, - { - name: "backslash-w inside string", - input: `{"key": "\w*"}`, - want: `{"key": "\\w*"}`, - }, - { - name: "valid escapes preserved", - input: `{"key": "line1\nline2\ttab"}`, - want: `{"key": "line1\nline2\ttab"}`, - }, - { - name: "mixed valid and invalid", - input: `{"key": "\n\s\t\d"}`, - want: `{"key": "\n\\s\t\\d"}`, - }, - { - name: "backslash outside string unchanged", - input: `{"key": "value"}\extra`, - want: `{"key": "value"}\extra`, - }, - { - // Note: \t is a valid JSON escape for tab, so it's preserved. - // The actual fix for Windows paths happens in the full repair pipeline. - name: "Windows path - partial (t is valid escape)", - input: `{"path": "C:\code\test"}`, - want: `{"path": "C:\\code\test"}`, - }, - { - name: "escaped backslash preserved", - input: `{"key": "path\\to\\file"}`, - want: `{"key": "path\\to\\file"}`, - }, - { - name: "escaped quote preserved", - input: `{"key": "say \"hello\""}`, - want: `{"key": "say \"hello\""}`, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := sanitizeControlChars(tt.input) - if got != tt.want { - t.Errorf("sanitizeControlChars() = %q, want %q", got, tt.want) - } - }) - } -} - -// TestRepairJSON_InvalidEscapes tests the full repair pipeline including -// control character sanitization. -func TestRepairJSON_InvalidEscapes(t *testing.T) { - tests := []struct { - name string - input string - wantValid bool // should the repaired JSON be valid? - }{ - { - name: "regex pattern with backslash-s", - input: `{"pattern": "\s+"}`, - wantValid: true, - }, - { - name: "Windows path", - input: `{"path": "C:\code\project\file.go"}`, - wantValid: true, - }, - { - name: "multiple regex escapes", - input: `{"regex": "^\s*\d+\w+\c$"}`, - wantValid: true, - }, - { - name: "complex code snippet with escapes", - input: `{"code": "if match, _ := regexp.MatchString(` + "`" + `\s+` + "`" + `, s); match {\n\tfmt.Println(\"found\")\n}"}`, - wantValid: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - repaired := repairJSON(tt.input) - - // Try to parse the repaired JSON - var result map[string]any - _, err := ExtractAndParseJSON[map[string]any](repaired) - - if tt.wantValid && err != nil { - t.Errorf("repairJSON() produced invalid JSON: %v\nInput: %s\nRepaired: %s", err, tt.input, repaired) - } - if !tt.wantValid && err == nil { - t.Errorf("repairJSON() unexpectedly produced valid JSON: %v", result) - } - }) - } -} - -// TestExtractAndParseJSON_WindowsPaths tests handling of Windows-style paths -// which commonly cause "invalid character 'c'" errors due to \c sequences. -func TestExtractAndParseJSON_WindowsPaths(t *testing.T) { - type PathResult struct { - FilePath string `json:"file_path"` - Content string `json:"content"` - } - - tests := []struct { - name string - input string - wantErr bool - }{ - { - name: "C drive path with Users", - input: `{"file_path": "C:\Users\dev\project\main.go", "content": "package main"}`, - wantErr: false, - }, - { - name: "path with backslash-t (taskwing) - valid escape in wrong context", - input: `{"file_path": "c:\code\taskwing\file.go", "content": "package utils"}`, - wantErr: false, - }, - { - name: "path with backslash-n (new folder) - valid escape in wrong context", - input: `{"file_path": "D:\projects\new\app\main.go", "content": "package main"}`, - wantErr: false, - }, - { - name: "path with backslash-u (utils) - unicode prefix in wrong context", - input: `{"file_path": "C:\code\utils\helper.go", "content": "package utils"}`, - wantErr: false, - }, - { - name: "path with backslash-r (release) - valid escape in wrong context", - input: `{"file_path": "C:\build\release\app.exe", "content": "binary"}`, - wantErr: false, - }, - { - name: "simple invalid escape", - input: `{"file_path": "C:\code\file.go", "content": "test"}`, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := ExtractAndParseJSON[PathResult](tt.input) - - if tt.wantErr { - if err == nil { - t.Error("ExtractAndParseJSON() expected error, got nil") - } - return - } - - if err != nil { - t.Errorf("ExtractAndParseJSON() error = %v", err) - return - } - - if result.FilePath == "" { - t.Error("ExtractAndParseJSON() result.FilePath is empty") - } - }) - } -} - -// TestRepairJSON_MalformedNumericLiterals tests the fix for malformed numeric -// literals where LLMs emit numbers like "0. 9" with a space after the decimal point. -// This is a regression test for the "invalid character ' ' after decimal point" error. -func TestRepairJSON_MalformedNumericLiterals(t *testing.T) { - tests := []struct { - name string - input string - wantRepair string // expected repaired string - wantValid bool // should parse as valid JSON after repair - checkValues map[string]float64 - }{ - { - name: "single digit after decimal: 0. 9 -> 0.9", - input: `{"confidence": 0. 9}`, - wantRepair: `{"confidence": 0.9}`, - wantValid: true, - checkValues: map[string]float64{ - "confidence": 0.9, - }, - }, - { - name: "multiple digits after decimal: 0. 85 -> 0.85", - input: `{"score": 0. 85}`, - wantRepair: `{"score": 0.85}`, - wantValid: true, - checkValues: map[string]float64{ - "score": 0.85, - }, - }, - { - name: "integer part: 1. 5 -> 1.5", - input: `{"value": 1. 5}`, - wantRepair: `{"value": 1.5}`, - wantValid: true, - checkValues: map[string]float64{ - "value": 1.5, - }, - }, - { - name: "multiple spaces: 0. 9 -> 0.9", - input: `{"n": 0. 9}`, - wantRepair: `{"n": 0.9}`, - wantValid: true, - checkValues: map[string]float64{ - "n": 0.9, - }, - }, - { - name: "tab after decimal: 0.\t9 -> 0.9", - input: `{"n": 0.` + "\t" + `9}`, - wantRepair: `{"n": 0.9}`, - wantValid: true, - checkValues: map[string]float64{ - "n": 0.9, - }, - }, - { - name: "multiple malformed numbers in object", - input: `{"a": 0. 5, "b": 1. 23, "c": 99. 9}`, - wantRepair: `{"a": 0.5, "b": 1.23, "c": 99.9}`, - wantValid: true, - checkValues: map[string]float64{ - "a": 0.5, - "b": 1.23, - "c": 99.9, - }, - }, - { - name: "normal number unchanged", - input: `{"n": 0.9}`, - wantRepair: `{"n": 0.9}`, - wantValid: true, - checkValues: map[string]float64{ - "n": 0.9, - }, - }, - { - // NOTE: The regex also affects content inside strings. This is acceptable - // because: 1) it's rare for strings to contain "digit. digit" patterns, - // 2) the semantic meaning is preserved, and 3) the primary use case is - // fixing malformed numeric JSON values from LLM output. - name: "string with decimal point and space (gets modified)", - input: `{"s": "1. 2 is text"}`, - wantRepair: `{"s": "1.2 is text"}`, // NOTE: strings ARE modified (acceptable trade-off) - wantValid: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - repaired := repairJSON(tt.input) - - if repaired != tt.wantRepair { - t.Errorf("repairJSON() = %q, want %q", repaired, tt.wantRepair) - } - - // Try to parse the repaired JSON - result, err := ExtractAndParseJSON[map[string]any](repaired) - - if tt.wantValid && err != nil { - t.Errorf("repairJSON() produced invalid JSON: %v\nInput: %s\nRepaired: %s", err, tt.input, repaired) - return - } - - // Check specific values if provided - for key, want := range tt.checkValues { - if got, ok := result[key].(float64); ok { - if got != want { - t.Errorf("result[%q] = %v, want %v", key, got, want) - } - } else { - t.Errorf("result[%q] is not float64: %T", key, result[key]) - } - } - }) - } -} - -// TestExtractAndParseJSON_LLMCodeAnalysis simulates real LLM output that caused -// the "invalid character 'c'" error during bootstrap code analysis. -func TestExtractAndParseJSON_LLMCodeAnalysis(t *testing.T) { - type Evidence struct { - FilePath string `json:"file_path"` - StartLine int `json:"start_line"` - Snippet string `json:"snippet"` - } - - type Finding struct { - Title string `json:"title"` - Description string `json:"description"` - Evidence []Evidence `json:"evidence"` - } - - type AnalysisResult struct { - Decisions []Finding `json:"decisions"` - Patterns []Finding `json:"patterns"` - } - - // This simulates the kind of JSON that might contain file paths or code snippets - // with problematic escape sequences - input := `{ - "decisions": [{ - "title": "Use structured logging", - "description": "The codebase uses structured logging with fields", - "evidence": [{ - "file_path": "internal/bootstrap/scanner.go", - "start_line": 42, - "snippet": "log.WithFields(log.Fields{\"path\": path}).Info(\"scanning\")" - }] - }], - "patterns": [{ - "title": "Regex-based parsing", - "description": "Uses regex patterns like \s+ and \d+ for parsing", - "evidence": [{ - "file_path": "internal/utils/parser.go", - "start_line": 15, - "snippet": "regexp.MustCompile(` + "`" + `^\s*(\w+)\s*=\s*(.*)$` + "`" + `)" - }] - }] - }` - - result, err := ExtractAndParseJSON[AnalysisResult](input) - if err != nil { - t.Fatalf("ExtractAndParseJSON() failed on LLM-like output: %v", err) - } - - if len(result.Decisions) != 1 { - t.Errorf("Expected 1 decision, got %d", len(result.Decisions)) - } - if len(result.Patterns) != 1 { - t.Errorf("Expected 1 pattern, got %d", len(result.Patterns)) - } -} diff --git a/scripts/test_mcp_workspace.sh b/scripts/test_mcp_workspace.sh index ba0fea3..c49e205 100755 --- a/scripts/test_mcp_workspace.sh +++ b/scripts/test_mcp_workspace.sh @@ -1,7 +1,7 @@ #!/bin/bash # Test script for MCP workspace filtering functionality. # -# This script tests that the MCP recall tool correctly handles workspace filtering +# This script tests that the MCP ask tool correctly handles workspace filtering # by sending JSON-RPC requests to the local dev MCP server. # # Prerequisites: @@ -12,9 +12,9 @@ # ./scripts/test_mcp_workspace.sh # # The script tests: -# 1. recall without workspace filter (returns all) -# 2. recall with workspace="api" (returns api + root) -# 3. recall with workspace="api" and all=true (returns all, ignoring workspace) +# 1. ask without workspace filter (returns all) +# 2. ask with workspace="api" (returns api + root) +# 3. ask with workspace="api" and all=true (returns all, ignoring workspace) set -e @@ -114,13 +114,13 @@ func main() { fmt.Println("Example requests:") fmt.Println("") fmt.Println("No filter (all workspaces):") - fmt.Println(` {"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"recall","arguments":{"query":"pattern"}}}`) + fmt.Println(` {"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"ask","arguments":{"query":"pattern"}}}`) fmt.Println("") fmt.Println("With workspace filter:") - fmt.Println(` {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"recall","arguments":{"query":"pattern","workspace":"api"}}}`) + fmt.Println(` {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"ask","arguments":{"query":"pattern","workspace":"api"}}}`) fmt.Println("") fmt.Println("With workspace filter and all=true (ignores workspace):") - fmt.Println(` {"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"recall","arguments":{"query":"pattern","workspace":"api","all":true}}}`) + fmt.Println(` {"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"ask","arguments":{"query":"pattern","workspace":"api","all":true}}}`) } EOF @@ -137,7 +137,7 @@ echo "3. Knowledge service tests: go test ./internal/knowledge/... -run TestWork echo echo "For manual MCP testing with the local dev server:" echo "1. Ensure nodes exist in your memory DB" -echo "2. Run: echo '{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"tools/call\",\"params\":{\"name\":\"recall\",\"arguments\":{\"query\":\"pattern\",\"workspace\":\"api\"}}}' | ./bin/taskwing mcp" +echo "2. Run: echo '{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"tools/call\",\"params\":{\"name\":\"ask\",\"arguments\":{\"query\":\"pattern\",\"workspace\":\"api\"}}}' | ./bin/taskwing mcp" echo echo "Expected behavior:" echo "- workspace=\"api\" returns api nodes + root nodes" diff --git a/tests/integration/opencode_test.go b/tests/integration/opencode_test.go deleted file mode 100644 index 1eb9ae4..0000000 --- a/tests/integration/opencode_test.go +++ /dev/null @@ -1,301 +0,0 @@ -// Package integration contains end-to-end tests for TaskWing features. -package integration - -import ( - "encoding/json" - "os" - "os/exec" - "path/filepath" - "strings" - "testing" -) - -// ============================================================================= -// OpenCode Integration Tests -// ============================================================================= - -// TestOpenCode_BootstrapAndDoctor tests the complete OpenCode bootstrap and doctor flow. -// This validates: -// 1. Bootstrap creates opencode.json at project root -// 2. Bootstrap creates .opencode/commands/ structure (flat format per OpenCode docs) -// 3. Doctor command validates OpenCode configuration -// -// CRITICAL: Uses go run . instead of system-installed taskwing binary. -func TestOpenCode_BootstrapAndDoctor(t *testing.T) { - if testing.Short() { - t.Skip("Skipping integration test in short mode") - } - - // Create a temporary directory for the test project - tmpDir, err := os.MkdirTemp("", "taskwing-opencode-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - // Setup: Create a minimal project structure - fixture := setupOpenCodeFixture(t, tmpDir) - - t.Run("bootstrap_creates_opencode_artifacts", func(t *testing.T) { - // Test that installOpenCode creates the required files - // We test this directly by calling the function since bootstrap - // requires interactive prompts - testOpenCodeInstall(t, fixture.root) - }) - - t.Run("doctor_validates_opencode_config", func(t *testing.T) { - // Verify doctor can validate the OpenCode configuration - testOpenCodeDoctor(t, fixture.root) - }) - - t.Run("commands_structure_valid", func(t *testing.T) { - // Verify commands directory structure is correct - testOpenCodeCommands(t, fixture.root) - }) -} - -// openCodeFixture holds the test project structure -type openCodeFixture struct { - root string -} - -// setupOpenCodeFixture creates a minimal project structure for OpenCode testing. -func setupOpenCodeFixture(t *testing.T, tmpDir string) *openCodeFixture { - t.Helper() - - // Create root project directory - rootDir := filepath.Join(tmpDir, "test-project") - if err := os.MkdirAll(rootDir, 0755); err != nil { - t.Fatalf("failed to create project root: %v", err) - } - - // Create .taskwing/memory directory (simulate initialized project) - taskwingDir := filepath.Join(rootDir, ".taskwing", "memory") - if err := os.MkdirAll(taskwingDir, 0755); err != nil { - t.Fatalf("failed to create .taskwing/memory: %v", err) - } - - // Create a minimal .opencode directory structure - openCodeDir := filepath.Join(rootDir, ".opencode", "commands") - if err := os.MkdirAll(openCodeDir, 0755); err != nil { - t.Fatalf("failed to create .opencode/commands: %v", err) - } - - return &openCodeFixture{ - root: rootDir, - } -} - -// testOpenCodeInstall tests that OpenCode MCP installation creates correct artifacts. -func testOpenCodeInstall(t *testing.T, projectRoot string) { - t.Helper() - - // Create a valid opencode.json manually (simulating what installOpenCode does) - // This is necessary because installOpenCode requires the binary path - configPath := filepath.Join(projectRoot, "opencode.json") - - config := map[string]any{ - "$schema": "https://opencode.ai/config.json", - "mcp": map[string]any{ - "taskwing-mcp": map[string]any{ - "type": "local", - "command": []string{"./bin/taskwing", "mcp"}, - "timeout": 5000, - }, - }, - } - - content, err := json.MarshalIndent(config, "", " ") - if err != nil { - t.Fatalf("failed to marshal config: %v", err) - } - - if err := os.WriteFile(configPath, content, 0644); err != nil { - t.Fatalf("failed to write opencode.json: %v", err) - } - - // Verify opencode.json was created - if _, err := os.Stat(configPath); os.IsNotExist(err) { - t.Error("opencode.json was not created") - } - - // Verify JSON is valid - data, err := os.ReadFile(configPath) - if err != nil { - t.Fatalf("failed to read opencode.json: %v", err) - } - - var parsed map[string]any - if err := json.Unmarshal(data, &parsed); err != nil { - t.Errorf("opencode.json is not valid JSON: %v", err) - } - - // Verify structure - if _, ok := parsed["mcp"]; !ok { - t.Error("opencode.json missing 'mcp' section") - } -} - -// testOpenCodeDoctor tests that doctor can validate OpenCode configuration. -func testOpenCodeDoctor(t *testing.T, projectRoot string) { - t.Helper() - - // Read and validate opencode.json structure - configPath := filepath.Join(projectRoot, "opencode.json") - data, err := os.ReadFile(configPath) - if err != nil { - t.Fatalf("failed to read opencode.json: %v", err) - } - - var config map[string]any - if err := json.Unmarshal(data, &config); err != nil { - t.Fatalf("invalid JSON in opencode.json: %v", err) - } - - // Check schema - if schema, ok := config["$schema"].(string); !ok || schema != "https://opencode.ai/config.json" { - t.Errorf("schema = %v, want 'https://opencode.ai/config.json'", config["$schema"]) - } - - // Check MCP section - mcp, ok := config["mcp"].(map[string]any) - if !ok { - t.Fatal("mcp section is not a map") - } - - // Find taskwing-mcp entry - var found bool - for name, entry := range mcp { - if strings.HasPrefix(name, "taskwing-mcp") { - found = true - serverCfg, ok := entry.(map[string]any) - if !ok { - t.Errorf("server config for %s is not a map", name) - continue - } - - // Verify type is "local" - if serverCfg["type"] != "local" { - t.Errorf("type = %v, want 'local'", serverCfg["type"]) - } - - // Verify command is array - command, ok := serverCfg["command"].([]any) - if !ok { - t.Errorf("command is not an array: %T", serverCfg["command"]) - } - if len(command) < 2 { - t.Errorf("command array too short: %v", command) - } - } - } - - if !found { - t.Error("no taskwing-mcp entry found in mcp section") - } -} - -// testOpenCodeCommands tests that commands directory structure is valid. -// OpenCode uses flat structure: .opencode/commands/.md with description frontmatter -func testOpenCodeCommands(t *testing.T, projectRoot string) { - t.Helper() - - commandsDir := filepath.Join(projectRoot, ".opencode", "commands") - - // Create a test command to validate structure - cmdContent := `--- -description: Test command for integration testing ---- - -!taskwing slash test -` - cmdPath := filepath.Join(commandsDir, "tw-test.md") - if err := os.WriteFile(cmdPath, []byte(cmdContent), 0644); err != nil { - t.Fatalf("failed to write tw-test.md: %v", err) - } - - // Verify command file exists - if _, err := os.Stat(cmdPath); os.IsNotExist(err) { - t.Error("tw-test.md was not created") - } - - // Verify frontmatter is valid - content, err := os.ReadFile(cmdPath) - if err != nil { - t.Fatalf("failed to read tw-test.md: %v", err) - } - - contentStr := string(content) - - // Check frontmatter markers - if !strings.HasPrefix(contentStr, "---") { - t.Error("Command file missing frontmatter start marker") - } - - // Check required field (OpenCode only requires description) - if !strings.Contains(contentStr, "description:") { - t.Error("Command file missing 'description' field") - } -} - -// TestOpenCode_MCPServerConfig tests that MCP server configuration is correct. -// CRITICAL: Uses ./bin/taskwing or go run . - NOT system-installed binary. -func TestOpenCode_MCPServerConfig(t *testing.T) { - if testing.Short() { - t.Skip("Skipping integration test in short mode") - } - - // Create temp project - tmpDir, err := os.MkdirTemp("", "taskwing-mcp-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - // Create valid opencode.json - config := map[string]any{ - "$schema": "https://opencode.ai/config.json", - "mcp": map[string]any{ - "taskwing-mcp": map[string]any{ - "type": "local", - "command": []string{"./bin/taskwing", "mcp"}, - "timeout": 5000, - }, - }, - } - - configPath := filepath.Join(tmpDir, "opencode.json") - content, _ := json.MarshalIndent(config, "", " ") - if err := os.WriteFile(configPath, content, 0644); err != nil { - t.Fatalf("failed to write config: %v", err) - } - - // Validate JSON with jq if available (optional) - if _, err := exec.LookPath("jq"); err == nil { - cmd := exec.Command("jq", ".", configPath) - if err := cmd.Run(); err != nil { - t.Errorf("jq validation failed: %v", err) - } - } - - // Verify config can be parsed - data, err := os.ReadFile(configPath) - if err != nil { - t.Fatalf("failed to read config: %v", err) - } - - var parsed map[string]any - if err := json.Unmarshal(data, &parsed); err != nil { - t.Fatalf("invalid JSON: %v", err) - } - - // Verify command uses local binary, not system binary - mcp := parsed["mcp"].(map[string]any) - serverCfg := mcp["taskwing-mcp"].(map[string]any) - command := serverCfg["command"].([]any) - - commandStr := command[0].(string) - if commandStr == "taskwing" { - t.Error("command should use local binary (./bin/taskwing), not system binary") - } -} diff --git a/tests/integration/workspace_test.go b/tests/integration/workspace_test.go deleted file mode 100644 index d51dc66..0000000 --- a/tests/integration/workspace_test.go +++ /dev/null @@ -1,475 +0,0 @@ -// Package integration contains end-to-end tests for TaskWing features. -package integration - -import ( - "context" - "os" - "path/filepath" - "testing" - - "github.com/josephgoksu/TaskWing/internal/app" - "github.com/josephgoksu/TaskWing/internal/knowledge" - "github.com/josephgoksu/TaskWing/internal/llm" - "github.com/josephgoksu/TaskWing/internal/memory" - "github.com/josephgoksu/TaskWing/internal/project" -) - -// TestMonorepoWorkspace_EndToEnd tests the complete workspace-aware knowledge scoping -// flow for a monorepo structure. This validates: -// 1. Workspace detection in monorepo structures -// 2. Knowledge nodes are created with correct workspace tags -// 3. Recall filtering returns workspace-scoped results -// 4. Root knowledge is included when IncludeRoot=true -func TestMonorepoWorkspace_EndToEnd(t *testing.T) { - // Skip in short mode (for quick CI runs) - if testing.Short() { - t.Skip("Skipping integration test in short mode") - } - - // Create a temporary directory for the test monorepo - tmpDir, err := os.MkdirTemp("", "taskwing-monorepo-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - // Setup: Create a monorepo structure - monorepo := setupMonorepoFixture(t, tmpDir) - - // Test 1: Workspace detection - t.Run("workspace_detection", func(t *testing.T) { - ws, err := project.DetectWorkspace(monorepo.root) - if err != nil { - t.Fatalf("DetectWorkspace failed: %v", err) - } - - if ws.Type != project.WorkspaceTypeMonorepo { - t.Errorf("workspace type = %v, want monorepo", ws.Type) - } - - // Should detect all services - services := make(map[string]bool) - for _, svc := range ws.Services { - services[svc] = true - } - - for _, expected := range []string{"api", "web", "common"} { - if !services[expected] { - t.Errorf("expected service %q not detected", expected) - } - } - }) - - // Test 2: Create knowledge nodes with workspace tags - t.Run("knowledge_with_workspace_tags", func(t *testing.T) { - repo, err := memory.NewDefaultRepository(filepath.Join(monorepo.root, ".taskwing", "memory")) - if err != nil { - t.Fatalf("failed to create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - // Create nodes in different workspaces - testNodes := []memory.Node{ - { - ID: "dec-root-auth", - Type: memory.NodeTypeDecision, - Summary: "Use JWT for authentication", - Content: "All services will use JWT tokens for authentication. Tokens are verified by the API gateway.", - Workspace: "root", - }, - { - ID: "dec-root-db", - Type: memory.NodeTypeDecision, - Summary: "PostgreSQL as primary database", - Content: "PostgreSQL is the primary database. Each service has its own schema.", - Workspace: "root", - }, - { - ID: "pat-api-rest", - Type: memory.NodeTypePattern, - Summary: "REST API conventions", - Content: "API service follows RESTful conventions with versioned endpoints (/v1/, /v2/).", - Workspace: "api", - }, - { - ID: "con-api-rate", - Type: memory.NodeTypeConstraint, - Summary: "Rate limiting required", - Content: "All API endpoints must have rate limiting. Default: 100 req/min per user.", - Workspace: "api", - }, - { - ID: "pat-web-react", - Type: memory.NodeTypePattern, - Summary: "React component structure", - Content: "Web frontend uses React with functional components and hooks.", - Workspace: "web", - }, - { - ID: "pat-common-utils", - Type: memory.NodeTypePattern, - Summary: "Shared utility functions", - Content: "Common utilities are shared across services via the common package.", - Workspace: "common", - }, - } - - for _, node := range testNodes { - n := node - if err := repo.CreateNode(&n); err != nil { - t.Fatalf("failed to create node %s: %v", node.ID, err) - } - } - - // Verify nodes were created with correct workspaces - allNodes, err := repo.ListNodes("") - if err != nil { - t.Fatalf("ListNodes failed: %v", err) - } - - if len(allNodes) != len(testNodes) { - t.Errorf("created %d nodes, want %d", len(allNodes), len(testNodes)) - } - - workspaceCounts := make(map[string]int) - for _, n := range allNodes { - workspaceCounts[n.Workspace]++ - } - - expectedCounts := map[string]int{ - "root": 2, - "api": 2, - "web": 1, - "common": 1, - } - - for ws, want := range expectedCounts { - if got := workspaceCounts[ws]; got != want { - t.Errorf("workspace %q: got %d nodes, want %d", ws, got, want) - } - } - }) - - // Test 3: Recall filtering by workspace - t.Run("recall_workspace_filtering", func(t *testing.T) { - repo, err := memory.NewDefaultRepository(filepath.Join(monorepo.root, ".taskwing", "memory")) - if err != nil { - t.Fatalf("failed to create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - ctx := context.Background() - appCtx := app.NewContextWithConfig(repo, llm.Config{}) // No LLM needed for search - recallApp := app.NewRecallApp(appCtx) - - // Test: Search from "api" workspace with IncludeRoot=true - t.Run("api_with_root", func(t *testing.T) { - // Use ListNodesFiltered directly to verify workspace scoping - // (The recall Query uses NodeResponse which strips workspace for token efficiency) - nodes, err := repo.ListNodesFiltered(memory.NodeFilter{ - Workspace: "api", - IncludeRoot: true, - }) - if err != nil { - t.Fatalf("ListNodesFiltered failed: %v", err) - } - - // Should find api nodes + root nodes, NOT web/common nodes - foundWorkspaces := make(map[string]bool) - for _, n := range nodes { - foundWorkspaces[n.Workspace] = true - } - - if foundWorkspaces["web"] { - t.Error("should NOT include web workspace nodes") - } - if foundWorkspaces["common"] { - t.Error("should NOT include common workspace nodes") - } - - // Should have api and root nodes - if !foundWorkspaces["api"] { - t.Error("should include api workspace nodes") - } - if !foundWorkspaces["root"] { - t.Error("should include root workspace nodes when IncludeRoot=true") - } - - // Verify count: 2 api + 2 root = 4 - if len(nodes) != 4 { - t.Errorf("got %d nodes, want 4 (api + root)", len(nodes)) - } - }) - - // Test: Search from "api" workspace WITHOUT root - t.Run("api_without_root", func(t *testing.T) { - // Create a new knowledge service for direct testing - ks := knowledge.NewService(repo, llm.Config{}) - - results, err := ks.SearchWithFilter(ctx, "API", 10, memory.NodeFilter{ - Workspace: "api", - IncludeRoot: false, - }) - if err != nil { - t.Fatalf("SearchWithFilter failed: %v", err) - } - - for _, r := range results { - if r.Node.Workspace != "api" { - t.Errorf("got node from workspace %q, want only 'api'", r.Node.Workspace) - } - } - }) - - // Test: Search from root (empty workspace = all) - t.Run("root_sees_all", func(t *testing.T) { - // Empty workspace filter should return all nodes - nodes, err := repo.ListNodesFiltered(memory.NodeFilter{ - Workspace: "", // Empty = no filtering - }) - if err != nil { - t.Fatalf("ListNodesFiltered failed: %v", err) - } - - // Should find nodes from all workspaces - foundWorkspaces := make(map[string]bool) - for _, n := range nodes { - foundWorkspaces[n.Workspace] = true - } - - // Should see all 4 workspaces - expectedWorkspaces := []string{"root", "api", "web", "common"} - for _, ws := range expectedWorkspaces { - if !foundWorkspaces[ws] { - t.Errorf("expected workspace %q not found", ws) - } - } - - // Should have all 6 nodes - if len(nodes) != 6 { - t.Errorf("got %d nodes, want 6 (all nodes)", len(nodes)) - } - }) - - // Note: recallApp is still used for validation that the app layer works - _ = recallApp - }) - - // Test 4: ListNodesFiltered integration - t.Run("list_nodes_filtered", func(t *testing.T) { - repo, err := memory.NewDefaultRepository(filepath.Join(monorepo.root, ".taskwing", "memory")) - if err != nil { - t.Fatalf("failed to create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - // Test: List API workspace nodes with root - nodes, err := repo.ListNodesFiltered(memory.NodeFilter{ - Workspace: "api", - IncludeRoot: true, - }) - if err != nil { - t.Fatalf("ListNodesFiltered failed: %v", err) - } - - // Should have: 2 api + 2 root = 4 nodes - if len(nodes) != 4 { - t.Errorf("got %d nodes, want 4 (api + root)", len(nodes)) - } - - // Verify no web/common nodes - for _, n := range nodes { - if n.Workspace == "web" || n.Workspace == "common" { - t.Errorf("unexpected node from workspace %q", n.Workspace) - } - } - - // Test: List API workspace nodes without root - nodes, err = repo.ListNodesFiltered(memory.NodeFilter{ - Workspace: "api", - IncludeRoot: false, - }) - if err != nil { - t.Fatalf("ListNodesFiltered failed: %v", err) - } - - // Should have: 2 api nodes only - if len(nodes) != 2 { - t.Errorf("got %d nodes, want 2 (api only)", len(nodes)) - } - - for _, n := range nodes { - if n.Workspace != "api" { - t.Errorf("got node from workspace %q, want only 'api'", n.Workspace) - } - } - }) - - // Test 5: SearchFTSFiltered integration - t.Run("fts_workspace_filtering", func(t *testing.T) { - repo, err := memory.NewDefaultRepository(filepath.Join(monorepo.root, ".taskwing", "memory")) - if err != nil { - t.Fatalf("failed to create repository: %v", err) - } - defer func() { _ = repo.Close() }() - - // Search for "pattern" - should match multiple nodes - results, err := repo.SearchFTSFiltered("pattern", 10, memory.NodeFilter{ - Workspace: "api", - IncludeRoot: true, - }) - if err != nil { - t.Fatalf("SearchFTSFiltered failed: %v", err) - } - - // Should NOT find web-react pattern or common-utils pattern - for _, r := range results { - if r.Node.Workspace == "web" { - t.Error("should NOT include web workspace in api+root search") - } - if r.Node.Workspace == "common" { - t.Error("should NOT include common workspace in api+root search") - } - } - }) -} - -// TestWorkspaceDetectionFromCwd tests workspace detection when running from subdirectories. -// Note: This test requires a real git repository for accurate cwd detection, -// so it tests the basic behavior with the understanding that cwd-based detection -// relies on git internals that aren't fully simulated in test fixtures. -func TestWorkspaceDetectionFromCwd(t *testing.T) { - if testing.Short() { - t.Skip("Skipping integration test in short mode") - } - - tmpDir, err := os.MkdirTemp("", "taskwing-cwd-test-*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer func() { _ = os.RemoveAll(tmpDir) }() - - monorepo := setupMonorepoFixture(t, tmpDir) - - // Save current dir - origDir, err := os.Getwd() - if err != nil { - t.Fatalf("failed to get current dir: %v", err) - } - defer func() { _ = os.Chdir(origDir) }() - - // Test: From root - should return "root" since fixture doesn't have full git setup - t.Run("from_root", func(t *testing.T) { - if err := os.Chdir(monorepo.root); err != nil { - t.Fatalf("failed to chdir: %v", err) - } - - ws, err := project.DetectWorkspaceFromCwd() - if err != nil { - t.Fatalf("DetectWorkspaceFromCwd failed: %v", err) - } - - // Should return "root" for the monorepo root - if ws != "root" { - t.Errorf("workspace = %q, want 'root'", ws) - } - }) - - // Test: DetectWorkspace (not DetectWorkspaceFromCwd) correctly identifies services - t.Run("detect_workspace_services", func(t *testing.T) { - ws, err := project.DetectWorkspace(monorepo.root) - if err != nil { - t.Fatalf("DetectWorkspace failed: %v", err) - } - - // Should detect api, web, common as services - services := make(map[string]bool) - for _, svc := range ws.Services { - services[svc] = true - } - - for _, expected := range []string{"api", "web", "common"} { - if !services[expected] { - t.Errorf("expected service %q not detected", expected) - } - } - }) - - // Test: DetectWorkspaceFromPath with explicit path - t.Run("detect_from_path_api", func(t *testing.T) { - apiPath := filepath.Join(monorepo.root, "api") - ws, err := project.DetectWorkspaceFromPath(apiPath) - if err != nil { - t.Fatalf("DetectWorkspaceFromPath failed: %v", err) - } - - // Note: Without real git repo, this may return "root" - // The key test is that it doesn't error out - t.Logf("DetectWorkspaceFromPath(%s) = %q", apiPath, ws) - }) -} - -// monorepoFixture holds paths for a test monorepo structure. -type monorepoFixture struct { - root string - api string - web string - common string -} - -// setupMonorepoFixture creates a test monorepo structure with Go modules. -func setupMonorepoFixture(t *testing.T, baseDir string) monorepoFixture { - t.Helper() - - fixture := monorepoFixture{ - root: baseDir, - api: filepath.Join(baseDir, "api"), - web: filepath.Join(baseDir, "web"), - common: filepath.Join(baseDir, "common"), - } - - // Create directories - for _, dir := range []string{fixture.api, fixture.web, fixture.common} { - if err := os.MkdirAll(dir, 0755); err != nil { - t.Fatalf("failed to create directory %s: %v", dir, err) - } - } - - // Create root go.mod (monorepo root) - rootMod := `module example.com/monorepo - -go 1.21 -` - if err := os.WriteFile(filepath.Join(fixture.root, "go.mod"), []byte(rootMod), 0644); err != nil { - t.Fatalf("failed to write root go.mod: %v", err) - } - - // Create service go.mod files (markers for workspace detection) - for _, svc := range []struct { - path string - name string - }{ - {fixture.api, "api"}, - {fixture.web, "web"}, - {fixture.common, "common"}, - } { - modContent := "module example.com/monorepo/" + svc.name + "\n\ngo 1.21\n" - if err := os.WriteFile(filepath.Join(svc.path, "go.mod"), []byte(modContent), 0644); err != nil { - t.Fatalf("failed to write %s go.mod: %v", svc.name, err) - } - } - - // Create .git directory (makes it a monorepo rather than multi-repo) - gitDir := filepath.Join(fixture.root, ".git") - if err := os.MkdirAll(gitDir, 0755); err != nil { - t.Fatalf("failed to create .git directory: %v", err) - } - - // Create .taskwing/memory directory - memoryDir := filepath.Join(fixture.root, ".taskwing", "memory") - if err := os.MkdirAll(memoryDir, 0755); err != nil { - t.Fatalf("failed to create memory directory: %v", err) - } - - return fixture -}