Skip to content

Commit 6fdacfa

Browse files
authored
Merge pull request #14 from josephgoksu/fix/mcp-validation-bugs-and-code-quality
fix: MCP validation bugs, panic elimination, and code quality improvements
2 parents 82561e2 + 87946fb commit 6fdacfa

15 files changed

Lines changed: 572 additions & 77 deletions

File tree

cmd/config.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,10 @@ func detectProjectRoot() *project.Context {
104104
}
105105

106106
// Store in config package for GetMemoryBasePath and other consumers
107-
config.SetProjectContext(ctx)
107+
if err := config.SetProjectContext(ctx); err != nil {
108+
fmt.Fprintf(os.Stderr, "Warning: failed to set project context: %v\n", err)
109+
return nil
110+
}
108111

109112
// Log in verbose mode
110113
if viper.GetBool("verbose") && ctx.RootPath != cwd {

cmd/hook.go

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -436,22 +436,32 @@ Tasks Completed: %d
436436
`, session.SessionID, int(elapsed.Minutes()), session.TasksCompleted)
437437

438438
// Remove session file
439-
sessionPath := getHookSessionPath()
440-
_ = os.Remove(sessionPath)
439+
sessionPath, err := getHookSessionPath()
440+
if err == nil {
441+
_ = os.Remove(sessionPath)
442+
}
441443

442444
return nil
443445
}
444446

445447
// Session persistence helpers
446448

447-
func getHookSessionPath() string {
449+
func getHookSessionPath() (string, error) {
448450
// Hook commands use GetMemoryBasePathOrGlobal since they may run
449451
// before project context is fully established (e.g., SessionStart)
450-
return filepath.Join(config.GetMemoryBasePathOrGlobal(), "hook_session.json")
452+
memoryPath, err := config.GetMemoryBasePathOrGlobal()
453+
if err != nil {
454+
return "", fmt.Errorf("get memory path: %w", err)
455+
}
456+
return filepath.Join(memoryPath, "hook_session.json"), nil
451457
}
452458

453459
func loadHookSession() (*HookSession, error) {
454-
data, err := os.ReadFile(getHookSessionPath())
460+
sessionPath, err := getHookSessionPath()
461+
if err != nil {
462+
return nil, err
463+
}
464+
data, err := os.ReadFile(sessionPath)
455465
if err != nil {
456466
return nil, err
457467
}
@@ -470,7 +480,10 @@ func saveHookSession(session *HookSession) error {
470480
return err
471481
}
472482

473-
sessionPath := getHookSessionPath()
483+
sessionPath, err := getHookSessionPath()
484+
if err != nil {
485+
return err
486+
}
474487
// Ensure directory exists
475488
if err := os.MkdirAll(filepath.Dir(sessionPath), 0755); err != nil {
476489
return fmt.Errorf("create session directory: %w", err)

cmd/mcp_server.go

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,10 @@ func mcpFormattedErrorResponse(formattedError string) (*mcpsdk.CallToolResultFor
9797
func initMCPRepository() (*memory.Repository, error) {
9898
// MCP server is a special case - it may run in sandboxed environments
9999
// where project context isn't available. Use the fallback-enabled path.
100-
memoryPath := config.GetMemoryBasePathOrGlobal()
100+
memoryPath, err := config.GetMemoryBasePathOrGlobal()
101+
if err != nil {
102+
return nil, fmt.Errorf("determine memory path: %w", err)
103+
}
101104

102105
repo, err := memory.NewDefaultRepository(memoryPath)
103106
if err != nil {
@@ -214,7 +217,13 @@ func runMCPServer(ctx context.Context) error {
214217
- next: Get next pending task from plan (use auto_start=true to claim immediately)
215218
- current: Get current in-progress task for session
216219
- start: Claim a specific task by ID
217-
- complete: Mark task as completed with summary`,
220+
- complete: Mark task as completed with summary
221+
222+
REQUIRED FIELDS BY ACTION:
223+
- next: session_id (required)
224+
- current: session_id (required)
225+
- start: task_id (required), session_id (required)
226+
- complete: task_id (required)`,
218227
}
219228
mcpsdk.AddTool(server, taskTool, func(ctx context.Context, session *mcpsdk.ServerSession, params *mcpsdk.CallToolParamsFor[mcppresenter.TaskToolParams]) (*mcpsdk.CallToolResultFor[any], error) {
220229
result, err := mcppresenter.HandleTaskTool(ctx, repo, params.Arguments)
@@ -233,7 +242,12 @@ func runMCPServer(ctx context.Context) error {
233242
Description: `Unified plan creation tool. Use action parameter to select operation:
234243
- clarify: Refine goal with clarifying questions (loop until is_ready_to_plan=true)
235244
- generate: Create plan with tasks from enriched goal
236-
- audit: Verify completed plan with build/test/semantic checks (auto-fixes failures)`,
245+
- audit: Verify completed plan with build/test/semantic checks (auto-fixes failures)
246+
247+
REQUIRED FIELDS BY ACTION:
248+
- clarify: goal (required)
249+
- generate: goal (required), enriched_goal (required) - call clarify first to get enriched_goal
250+
- audit: none required (defaults to active plan)`,
237251
}
238252
mcpsdk.AddTool(server, planTool, func(ctx context.Context, session *mcpsdk.ServerSession, params *mcpsdk.CallToolParamsFor[mcppresenter.PlanToolParams]) (*mcpsdk.CallToolResultFor[any], error) {
239253
result, err := mcppresenter.HandlePlanTool(ctx, repo, params.Arguments)

cmd/plan.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,13 +357,12 @@ func printPlanTable(plans []task.Plan) {
357357
goal = goal[:57] + "..."
358358
}
359359
// Tasks count - service ListPlans probably returns plans without tasks or with?
360-
// task.Repository interface implies ListPlans returns []Plan which contains Tasks?
361-
// SQLite implementation usually does. If not, we might check tasks length.
362-
// Assuming populated for now or length 0.
360+
// ListPlans sets TaskCount but leaves Tasks nil for efficiency.
361+
// Use GetTaskCount() to get the count regardless of how the plan was loaded.
363362
fmt.Printf("%-18s %-12s %-6d %s\n",
364363
idStyle.Render(p.ID),
365364
dateStyle.Render(p.CreatedAt.Format("2006-01-02")),
366-
len(p.Tasks),
365+
p.GetTaskCount(),
367366
goalStyle.Render(goal))
368367
}
369368
fmt.Printf("\n%s\n", lipgloss.NewStyle().Foreground(lipgloss.Color("241")).Render(fmt.Sprintf("Total: %d plan(s)", len(plans))))

internal/config/paths.go

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@ var GetGlobalConfigDir = func() (string, error) {
3838

3939
// SetProjectContext sets the detected project context for use by GetMemoryBasePath.
4040
// This MUST be called during CLI initialization before any command that needs project context.
41-
func SetProjectContext(ctx *project.Context) {
41+
// Returns error if ctx is nil.
42+
func SetProjectContext(ctx *project.Context) error {
4243
if ctx == nil {
43-
panic("SetProjectContext called with nil context")
44+
return errors.New("SetProjectContext called with nil context")
4445
}
4546
projectContextMu.Lock()
4647
defer projectContextMu.Unlock()
4748
projectContext = ctx
49+
return nil
4850
}
4951

5052
// ClearProjectContext resets the project context. Only use in tests.
@@ -62,14 +64,14 @@ func GetProjectContext() *project.Context {
6264
return projectContext
6365
}
6466

65-
// MustGetProjectContext returns the project context or panics if not set.
66-
// Use this when project context is required and absence is a programming error.
67-
func MustGetProjectContext() *project.Context {
67+
// GetProjectContextOrError returns the project context or an error if not set.
68+
// Use this when project context is required.
69+
func GetProjectContextOrError() (*project.Context, error) {
6870
ctx := GetProjectContext()
6971
if ctx == nil {
70-
panic(ErrProjectContextNotSet)
72+
return nil, ErrProjectContextNotSet
7173
}
72-
return ctx
74+
return ctx, nil
7375
}
7476

7577
// DetectAndSetProjectContext detects the project root and sets it.
@@ -90,7 +92,9 @@ func DetectAndSetProjectContext() (*project.Context, error) {
9092
return nil, fmt.Errorf("%w: %v", ErrDetectionFailed, err)
9193
}
9294

93-
SetProjectContext(ctx)
95+
if err := SetProjectContext(ctx); err != nil {
96+
return nil, fmt.Errorf("set project context: %w", err)
97+
}
9498
return ctx, nil
9599
}
96100

@@ -133,19 +137,18 @@ func GetMemoryBasePath() (string, error) {
133137
//
134138
// ALL OTHER COMMANDS should use GetMemoryBasePath() which enforces fail-fast behavior.
135139
// Using this function inappropriately masks project detection failures.
136-
func GetMemoryBasePathOrGlobal() string {
140+
func GetMemoryBasePathOrGlobal() (string, error) {
137141
path, err := GetMemoryBasePath()
138142
if err == nil {
139-
return path
143+
return path, nil
140144
}
141145

142146
// Only fall back to global for non-project commands
143147
dir, err := GetGlobalConfigDir()
144148
if err != nil {
145-
// This is a critical failure - can't determine any valid path
146-
panic(fmt.Sprintf("cannot determine memory path: %v", err))
149+
return "", fmt.Errorf("cannot determine memory path: %w", err)
147150
}
148-
return filepath.Join(dir, "memory")
151+
return filepath.Join(dir, "memory"), nil
149152
}
150153

151154
// GetProjectRoot returns the detected project root path.
@@ -160,13 +163,3 @@ func GetProjectRoot() (string, error) {
160163
}
161164
return ctx.RootPath, nil
162165
}
163-
164-
// MustGetProjectRoot returns the project root or panics.
165-
// Use when project root is required and absence is a programming error.
166-
func MustGetProjectRoot() string {
167-
root, err := GetProjectRoot()
168-
if err != nil {
169-
panic(err)
170-
}
171-
return root
172-
}

internal/config/paths_test.go

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
package config
2+
3+
import (
4+
"errors"
5+
"testing"
6+
7+
"github.com/josephgoksu/TaskWing/internal/project"
8+
)
9+
10+
func TestSetProjectContext_NilReturnsError(t *testing.T) {
11+
// Clear any existing context
12+
ClearProjectContext()
13+
14+
err := SetProjectContext(nil)
15+
if err == nil {
16+
t.Fatal("expected error for nil context, got nil")
17+
}
18+
19+
// Verify error message is helpful
20+
if err.Error() != "SetProjectContext called with nil context" {
21+
t.Errorf("unexpected error message: %s", err.Error())
22+
}
23+
}
24+
25+
func TestSetProjectContext_ValidContext(t *testing.T) {
26+
// Clear any existing context
27+
ClearProjectContext()
28+
defer ClearProjectContext()
29+
30+
ctx := &project.Context{
31+
RootPath: "/test/path",
32+
MarkerType: project.MarkerGit,
33+
}
34+
35+
err := SetProjectContext(ctx)
36+
if err != nil {
37+
t.Fatalf("unexpected error: %v", err)
38+
}
39+
40+
// Verify context was set
41+
got := GetProjectContext()
42+
if got == nil {
43+
t.Fatal("expected context to be set")
44+
}
45+
if got.RootPath != ctx.RootPath {
46+
t.Errorf("expected RootPath %q, got %q", ctx.RootPath, got.RootPath)
47+
}
48+
}
49+
50+
func TestGetProjectContextOrError_NotSet(t *testing.T) {
51+
ClearProjectContext()
52+
53+
ctx, err := GetProjectContextOrError()
54+
if err == nil {
55+
t.Fatal("expected error when context not set")
56+
}
57+
if !errors.Is(err, ErrProjectContextNotSet) {
58+
t.Errorf("expected ErrProjectContextNotSet, got: %v", err)
59+
}
60+
if ctx != nil {
61+
t.Error("expected nil context")
62+
}
63+
}
64+
65+
func TestGetProjectContextOrError_Set(t *testing.T) {
66+
ClearProjectContext()
67+
defer ClearProjectContext()
68+
69+
expected := &project.Context{RootPath: "/test"}
70+
_ = SetProjectContext(expected)
71+
72+
ctx, err := GetProjectContextOrError()
73+
if err != nil {
74+
t.Fatalf("unexpected error: %v", err)
75+
}
76+
if ctx != expected {
77+
t.Error("context does not match expected")
78+
}
79+
}
80+
81+
func TestGetProjectRoot_NotSet(t *testing.T) {
82+
ClearProjectContext()
83+
84+
root, err := GetProjectRoot()
85+
if err == nil {
86+
t.Fatal("expected error when context not set")
87+
}
88+
if !errors.Is(err, ErrProjectContextNotSet) {
89+
t.Errorf("expected ErrProjectContextNotSet, got: %v", err)
90+
}
91+
if root != "" {
92+
t.Errorf("expected empty root, got: %s", root)
93+
}
94+
}
95+
96+
func TestGetProjectRoot_EmptyRootPath(t *testing.T) {
97+
ClearProjectContext()
98+
defer ClearProjectContext()
99+
100+
ctx := &project.Context{RootPath: ""}
101+
_ = SetProjectContext(ctx)
102+
103+
root, err := GetProjectRoot()
104+
if err == nil {
105+
t.Fatal("expected error for empty RootPath")
106+
}
107+
if root != "" {
108+
t.Errorf("expected empty root, got: %s", root)
109+
}
110+
}
111+
112+
func TestGetProjectRoot_Valid(t *testing.T) {
113+
ClearProjectContext()
114+
defer ClearProjectContext()
115+
116+
expected := "/my/project"
117+
ctx := &project.Context{RootPath: expected}
118+
_ = SetProjectContext(ctx)
119+
120+
root, err := GetProjectRoot()
121+
if err != nil {
122+
t.Fatalf("unexpected error: %v", err)
123+
}
124+
if root != expected {
125+
t.Errorf("expected %q, got %q", expected, root)
126+
}
127+
}
128+
129+
func TestGetMemoryBasePath_NotSet(t *testing.T) {
130+
ClearProjectContext()
131+
132+
path, err := GetMemoryBasePath()
133+
if err == nil {
134+
t.Fatal("expected error when context not set")
135+
}
136+
if !errors.Is(err, ErrProjectContextNotSet) {
137+
t.Errorf("expected ErrProjectContextNotSet, got: %v", err)
138+
}
139+
if path != "" {
140+
t.Errorf("expected empty path, got: %s", path)
141+
}
142+
}
143+
144+
func TestGetMemoryBasePathOrGlobal_FallsBackToGlobal(t *testing.T) {
145+
ClearProjectContext()
146+
147+
// Should fall back to global without error
148+
path, err := GetMemoryBasePathOrGlobal()
149+
if err != nil {
150+
t.Fatalf("unexpected error: %v", err)
151+
}
152+
if path == "" {
153+
t.Error("expected non-empty path")
154+
}
155+
// Should contain "memory" in the path
156+
if len(path) < 6 || path[len(path)-6:] != "memory" {
157+
t.Errorf("expected path to end with 'memory', got: %s", path)
158+
}
159+
}
160+
161+
func TestGetMemoryBasePathOrGlobal_GlobalDirError(t *testing.T) {
162+
ClearProjectContext()
163+
164+
// Save original function
165+
original := GetGlobalConfigDir
166+
defer func() { GetGlobalConfigDir = original }()
167+
168+
// Mock to return error
169+
GetGlobalConfigDir = func() (string, error) {
170+
return "", errors.New("test error: cannot get home dir")
171+
}
172+
173+
path, err := GetMemoryBasePathOrGlobal()
174+
if err == nil {
175+
t.Fatal("expected error when global config dir fails")
176+
}
177+
if path != "" {
178+
t.Errorf("expected empty path on error, got: %s", path)
179+
}
180+
}

0 commit comments

Comments
 (0)