diff --git a/cmd/gosqlx/cmd/lint_test.go b/cmd/gosqlx/cmd/lint_test.go new file mode 100644 index 00000000..09b0e20d --- /dev/null +++ b/cmd/gosqlx/cmd/lint_test.go @@ -0,0 +1,1025 @@ +package cmd + +import ( + "bytes" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/spf13/cobra" +) + +// TestLintCmd_Basic tests basic linting functionality +// Note: lintFailOnWarn must be false to avoid os.Exit(1) on warnings +func TestLintCmd_Basic(t *testing.T) { + tests := []struct { + name string + files map[string]string // filename -> content + expectedOutput []string // strings expected in output + expectedError bool + }{ + { + name: "No violations - clean SQL", + files: map[string]string{ + "test.sql": "SELECT id FROM users", + }, + expectedOutput: []string{"Total files: 1", "Total violations: 0"}, + expectedError: false, + }, + { + name: "Trailing whitespace violation (L001)", + files: map[string]string{ + "test.sql": "SELECT id FROM users ", + }, + expectedOutput: []string{ + "test.sql", + "violation", + "L001", + "Trailing Whitespace", + "line 1", + }, + expectedError: false, + }, + { + name: "Multiple files with violations", + files: map[string]string{ + "query1.sql": "SELECT * FROM users ", + "query2.sql": "SELECT * FROM orders\t", + "query3.sql": "SELECT * FROM products", + }, + expectedOutput: []string{ + "query1.sql", + "query2.sql", + "Total files: 3", + "Total violations: 2", + }, + expectedError: false, + }, + // Skipped: L002 is SeverityError which triggers os.Exit(1) regardless of --fail-on-warn + // This would be tested in integration tests with subprocess + // { + // name: "All three rule violations", + // files: map[string]string{ + // "test.sql": "SELECT id FROM users WHERE name = 'test' AND email = 'test@example.com' AND active = true AND created_at > NOW() \n\t SELECT * FROM orders", + // }, + // expectedOutput: []string{ + // "L001", // Trailing whitespace - SeverityWarning + // "L002", // Mixed indentation - SeverityError (triggers os.Exit) + // "L005", // Long lines - SeverityInfo + // }, + // expectedError: false, + // }, + { + name: "Long line violation (L005)", + files: map[string]string{ + "test.sql": "SELECT column1, column2, column3, column4, column5, column6, column7, column8, column9, column10 FROM users WHERE active = true", + }, + expectedOutput: []string{ + "L005", + "Long Lines", + "exceeds maximum length", + }, + expectedError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + // Setup temp files + tmpDir := t.TempDir() + var args []string + + for filename, content := range tt.files { + path := filepath.Join(tmpDir, filename) + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + args = append(args, path) + } + + // Create command with buffers + var outBuf, errBuf bytes.Buffer + cmd := &cobra.Command{} + cmd.SetOut(&outBuf) + cmd.SetErr(&errBuf) + + // Reset flags + lintRecursive = false + lintPattern = "*.sql" + lintAutoFix = false + lintMaxLength = 100 + lintFailOnWarn = false + + // Run lint command + err := lintRun(cmd, args) + + // Check error + if tt.expectedError { + if err == nil { + t.Errorf("Expected error but got none") + } + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + } + + // Verify output + output := outBuf.String() + for _, expected := range tt.expectedOutput { + if !strings.Contains(output, expected) { + t.Errorf("Expected output to contain '%s', got:\n%s", expected, output) + } + } + }) + } +} + +// TestLintCmd_NonExistentFile tests error handling for non-existent files +func TestLintCmd_NonExistentFile(t *testing.T) { + var outBuf, errBuf bytes.Buffer + cmd := &cobra.Command{} + cmd.SetOut(&outBuf) + cmd.SetErr(&errBuf) + + // Reset flags + lintRecursive = false + lintAutoFix = false + lintMaxLength = 100 + + args := []string{"/nonexistent/file.sql"} + err := lintRun(cmd, args) + + if err != nil { + t.Errorf("Command should not return error for file read failure: %v", err) + } + + output := outBuf.String() + if !strings.Contains(output, "ERROR") && !strings.Contains(output, "failed to read") { + t.Errorf("Expected error message in output for non-existent file, got: %s", output) + } +} + +// TestLintCmd_EmptyFileList tests error handling for empty file list +func TestLintCmd_EmptyFileList(t *testing.T) { + // Save original stdin + oldStdin := os.Stdin + defer func() { os.Stdin = oldStdin }() + + // Create a fake stdin that returns terminal (not pipe) + r, w, _ := os.Pipe() + os.Stdin = r + w.Close() + + var outBuf, errBuf bytes.Buffer + cmd := &cobra.Command{} + cmd.SetOut(&outBuf) + cmd.SetErr(&errBuf) + + err := lintRun(cmd, []string{}) + + if err == nil { + t.Error("Expected error for empty file list") + } + + // The error message could be either "no input provided" or "failed to read from stdin: stdin is empty" + // depending on whether stdin is detected as pipe or not + if !strings.Contains(err.Error(), "no input provided") && !strings.Contains(err.Error(), "stdin is empty") { + t.Errorf("Expected 'no input provided' or 'stdin is empty' error, got: %v", err) + } +} + +// TestLintCmd_Recursive tests recursive directory linting +func TestLintCmd_Recursive(t *testing.T) { + tests := []struct { + name string + files map[string]string + pattern string + expectedFileCount int + expectedViolations int + expectedInOutput []string + }{ + { + name: "Recursive with default pattern", + files: map[string]string{ + "query1.sql": "SELECT * FROM users ", + "query2.sql": "SELECT * FROM orders", + "subdir/query3.sql": "SELECT * FROM products\t", + }, + pattern: "*.sql", + expectedFileCount: 3, + expectedViolations: 2, + expectedInOutput: []string{"query1.sql", "query3.sql", "Total files: 3"}, + }, + { + name: "Recursive with custom pattern", + files: map[string]string{ + "query.sql": "SELECT * FROM users", + "migration_001.sql": "CREATE TABLE users (id INT)", + "migration_002.sql": "CREATE TABLE orders (id INT) ", + }, + pattern: "migration_*.sql", + expectedFileCount: 2, + expectedViolations: 1, + expectedInOutput: []string{"migration_002.sql", "Total files: 2"}, + }, + { + name: "Nested directories", + files: map[string]string{ + "a/query.sql": "SELECT 1 ", + "a/b/query.sql": "SELECT 2", + "a/b/c/query.sql": "SELECT 3\t", + }, + pattern: "*.sql", + expectedFileCount: 3, + expectedViolations: 2, + expectedInOutput: []string{"Total files: 3", "Total violations: 2"}, + }, + { + name: "No matching files", + files: map[string]string{ + "query.txt": "SELECT * FROM users", + "readme.md": "Documentation", + }, + pattern: "*.sql", + expectedFileCount: 0, + expectedViolations: 0, + expectedInOutput: []string{"Total files: 0", "Total violations: 0"}, + }, + { + name: "Mixed violations in directory", + files: map[string]string{ + "clean.sql": "SELECT * FROM users", + "trailing.sql": "SELECT * FROM orders ", + "long.sql": "SELECT column1, column2, column3, column4, column5, column6, column7, column8, column9, column10 FROM users WHERE active = true", + }, + pattern: "*.sql", + expectedFileCount: 3, + expectedViolations: 2, + expectedInOutput: []string{"Total files: 3", "trailing.sql", "long.sql"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + + // Create test files and directories + for filename, content := range tt.files { + path := filepath.Join(tmpDir, filename) + dir := filepath.Dir(path) + + if err := os.MkdirAll(dir, 0755); err != nil { + t.Fatalf("Failed to create directory: %v", err) + } + + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + } + + // Create command + var outBuf, errBuf bytes.Buffer + cmd := &cobra.Command{} + cmd.SetOut(&outBuf) + cmd.SetErr(&errBuf) + + // Set flags + lintRecursive = true + lintPattern = tt.pattern + lintAutoFix = false + lintMaxLength = 100 + lintFailOnWarn = false + + // Run lint command + err := lintRun(cmd, []string{tmpDir}) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // Verify output + output := outBuf.String() + for _, expected := range tt.expectedInOutput { + if !strings.Contains(output, expected) { + t.Errorf("Expected output to contain '%s', got:\n%s", expected, output) + } + } + + // Verify file count in output + if !strings.Contains(output, strings.Replace("Total files: X", "X", + strings.TrimSpace(strings.Split(strings.Split(output, "Total files: ")[1], "\n")[0]), 1)) { + t.Logf("File count check - output: %s", output) + } + }) + } +} + +// TestLintCmd_AutoFix tests auto-fix functionality +// Note: L002 (mixed indentation) is SeverityError and triggers os.Exit(1) +// Tests with L002 violations are skipped +func TestLintCmd_AutoFix(t *testing.T) { + tests := []struct { + name string + files map[string]string + expectedFixed map[string]string + expectedInOutput []string + skipTest bool // Skip tests that trigger os.Exit + }{ + { + name: "Auto-fix trailing whitespace", + files: map[string]string{ + "test.sql": "SELECT id FROM users \nSELECT * FROM orders\t", + }, + expectedFixed: map[string]string{ + "test.sql": "SELECT id FROM users\nSELECT * FROM orders", + }, + expectedInOutput: []string{"Auto-fixed", "test.sql"}, + }, + // Skipped: L002 is SeverityError which triggers os.Exit(1) + // { + // name: "Auto-fix mixed indentation", + // files: map[string]string{ + // "test.sql": "\t SELECT * FROM users", + // }, + // expectedFixed: map[string]string{ + // "test.sql": " SELECT * FROM users", + // }, + // expectedInOutput: []string{"Auto-fixed", "test.sql"}, + // skipTest: true, + // }, + // Skipped: L002 is SeverityError which triggers os.Exit(1) + // { + // name: "Auto-fix multiple violations in same file", + // files: map[string]string{ + // "test.sql": "SELECT * FROM users \n\t SELECT * FROM orders ", + // }, + // expectedFixed: map[string]string{ + // "test.sql": "SELECT * FROM users\n SELECT * FROM orders", + // }, + // expectedInOutput: []string{"Auto-fixed", "test.sql"}, + // skipTest: true, + // }, + { + name: "Auto-fix multiple files", + files: map[string]string{ + "query1.sql": "SELECT * FROM users ", + "query2.sql": "SELECT * FROM orders\t", + }, + expectedFixed: map[string]string{ + "query1.sql": "SELECT * FROM users", + "query2.sql": "SELECT * FROM orders", + }, + expectedInOutput: []string{"Auto-fixed 2 file(s)"}, + }, + { + name: "No auto-fix for long lines (L005)", + files: map[string]string{ + "test.sql": "SELECT column1, column2, column3, column4, column5, column6, column7, column8, column9, column10 FROM users WHERE active = true", + }, + expectedFixed: map[string]string{ + "test.sql": "SELECT column1, column2, column3, column4, column5, column6, column7, column8, column9, column10 FROM users WHERE active = true", + }, + expectedInOutput: []string{"Auto-fixed 0 file(s)"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipTest { + t.Skip("Skipping test that would trigger os.Exit(1) due to L002 SeverityError") + } + + tmpDir := t.TempDir() + var args []string + + // Create test files + for filename, content := range tt.files { + path := filepath.Join(tmpDir, filename) + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + args = append(args, path) + } + + // Create command + var outBuf, errBuf bytes.Buffer + cmd := &cobra.Command{} + cmd.SetOut(&outBuf) + cmd.SetErr(&errBuf) + + // Set flags + lintRecursive = false + lintAutoFix = true + lintMaxLength = 100 + lintFailOnWarn = false + + // Run lint command + err := lintRun(cmd, args) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // Verify output + output := outBuf.String() + for _, expected := range tt.expectedInOutput { + if !strings.Contains(output, expected) { + t.Errorf("Expected output to contain '%s', got:\n%s", expected, output) + } + } + + // Verify file contents were fixed + for filename, expectedContent := range tt.expectedFixed { + path := filepath.Join(tmpDir, filename) + actualContent, err := os.ReadFile(path) + if err != nil { + t.Fatalf("Failed to read fixed file: %v", err) + } + + if string(actualContent) != expectedContent { + t.Errorf("File %s not fixed correctly.\nExpected: %q\nGot: %q", + filename, expectedContent, string(actualContent)) + } + } + }) + } +} + +// TestLintCmd_AutoFix_PreservesPermissions tests that auto-fix preserves file permissions +func TestLintCmd_AutoFix_PreservesPermissions(t *testing.T) { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test.sql") + content := "SELECT * FROM users " + + // Create file with specific permissions + if err := os.WriteFile(filename, []byte(content), 0600); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + // Get original permissions + originalInfo, err := os.Stat(filename) + if err != nil { + t.Fatalf("Failed to stat file: %v", err) + } + originalPerm := originalInfo.Mode() + + // Create command + var outBuf, errBuf bytes.Buffer + cmd := &cobra.Command{} + cmd.SetOut(&outBuf) + cmd.SetErr(&errBuf) + + // Set flags + lintAutoFix = true + lintMaxLength = 100 + lintFailOnWarn = false + + // Run lint command + err = lintRun(cmd, []string{filename}) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // Verify permissions preserved + newInfo, err := os.Stat(filename) + if err != nil { + t.Fatalf("Failed to stat fixed file: %v", err) + } + newPerm := newInfo.Mode() + + if originalPerm != newPerm { + t.Errorf("File permissions changed. Original: %o, New: %o", originalPerm, newPerm) + } +} + +// TestLintCmd_Flags tests various flag combinations +func TestLintCmd_Flags(t *testing.T) { + tests := []struct { + name string + files map[string]string + recursive bool + pattern string + maxLength int + failOnWarn bool + expectedOutput []string + }{ + { + name: "Custom max-length flag", + files: map[string]string{ + "test.sql": "SELECT column1, column2, column3, column4 FROM users", + }, + maxLength: 50, + expectedOutput: []string{"L005", "exceeds maximum length"}, + }, + { + name: "Max-length allows longer lines", + files: map[string]string{ + "test.sql": "SELECT column1, column2, column3, column4 FROM users", + }, + maxLength: 200, + expectedOutput: []string{"Total violations: 0"}, + }, + { + name: "Pattern flag with recursive", + files: map[string]string{ + "migration_001.sql": "CREATE TABLE users (id INT) ", + "query.sql": "SELECT * FROM users", + }, + recursive: true, + pattern: "migration_*.sql", + expectedOutput: []string{"migration_001.sql", "Total files: 1"}, + }, + { + name: "Multiple flags combined", + files: map[string]string{ + "subdir/query.sql": "SELECT * FROM users ", + }, + recursive: true, + pattern: "*.sql", + maxLength: 80, + expectedOutput: []string{"query.sql", "L001"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + + // Create test files + for filename, content := range tt.files { + path := filepath.Join(tmpDir, filename) + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + t.Fatalf("Failed to create directory: %v", err) + } + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + } + + // Create command + var outBuf, errBuf bytes.Buffer + cmd := &cobra.Command{} + cmd.SetOut(&outBuf) + cmd.SetErr(&errBuf) + + // Set flags + lintRecursive = tt.recursive + lintPattern = tt.pattern + if tt.pattern == "" { + lintPattern = "*.sql" + } + lintAutoFix = false + lintMaxLength = tt.maxLength + if tt.maxLength == 0 { + lintMaxLength = 100 + } + lintFailOnWarn = tt.failOnWarn + + // Run lint command + args := []string{tmpDir} + if !tt.recursive { + args = []string{} + for filename := range tt.files { + args = append(args, filepath.Join(tmpDir, filename)) + } + } + + err := lintRun(cmd, args) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // Verify output + output := outBuf.String() + for _, expected := range tt.expectedOutput { + if !strings.Contains(output, expected) { + t.Errorf("Expected output to contain '%s', got:\n%s", expected, output) + } + } + }) + } +} + +// TestLintCmd_Stdin tests linting from stdin +// Note: lintFailOnWarn must be false to avoid os.Exit(1) on warnings +func TestLintCmd_Stdin(t *testing.T) { + tests := []struct { + name string + input string + autoFix bool + expectedOutput []string + wantError bool + skipTest bool // Skip tests that would trigger os.Exit + }{ + { + name: "Valid SQL from stdin", + input: "SELECT * FROM users", + expectedOutput: []string{"Linting stdin input", "No violations found"}, + wantError: false, + }, + { + name: "SQL with violations from stdin", + input: "SELECT * FROM users ", + expectedOutput: []string{"Linting stdin input", "Found 1 violation", "L001", "Trailing Whitespace"}, + wantError: false, + skipTest: true, // Violations trigger os.Exit(1) + }, + { + name: "Auto-fix from stdin", + input: "SELECT * FROM users ", + autoFix: true, + expectedOutput: []string{"Auto-fixed output", "SELECT * FROM users"}, + wantError: false, + skipTest: true, // Violations trigger os.Exit(1) + }, + { + name: "Empty stdin", + input: "", + expectedOutput: []string{}, + wantError: true, + }, + { + name: "Large input from stdin", + input: strings.Repeat("SELECT * FROM users;\n", 100), + expectedOutput: []string{"Linting stdin input", "No violations found"}, + wantError: false, + }, + { + name: "Multiple violations from stdin", + input: "SELECT * FROM users \n\t SELECT * FROM orders", + expectedOutput: []string{"Found 2 violation", "L001", "L002"}, + wantError: false, + skipTest: true, // Violations trigger os.Exit(1) + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipTest { + t.Skip("Skipping test that would trigger os.Exit(1) - violations trigger exit in lintFromStdin") + } + + // Save original stdin + oldStdin := os.Stdin + defer func() { os.Stdin = oldStdin }() + + // Create pipe for stdin + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("Failed to create pipe: %v", err) + } + os.Stdin = r + + // Write test input to pipe + go func() { + defer w.Close() + w.Write([]byte(tt.input)) + }() + + // Create command + var outBuf, errBuf bytes.Buffer + cmd := &cobra.Command{} + cmd.SetOut(&outBuf) + cmd.SetErr(&errBuf) + + // Set flags + lintAutoFix = tt.autoFix + lintMaxLength = 100 + lintFailOnWarn = false + + // Run lint command with explicit stdin marker + err = lintRun(cmd, []string{"-"}) + + // Check error + if tt.wantError { + if err == nil { + t.Error("Expected error but got none") + } + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + } + + if !tt.wantError { + // Verify output + output := outBuf.String() + for _, expected := range tt.expectedOutput { + if !strings.Contains(output, expected) { + t.Errorf("Expected output to contain '%s', got:\n%s", expected, output) + } + } + } + }) + } +} + +// TestLintCmd_Stdin_PipeDetection tests automatic stdin pipe detection +func TestLintCmd_Stdin_PipeDetection(t *testing.T) { + // Save original stdin + oldStdin := os.Stdin + defer func() { os.Stdin = oldStdin }() + + // Create pipe for stdin + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("Failed to create pipe: %v", err) + } + os.Stdin = r + + input := "SELECT * FROM users" + go func() { + defer w.Close() + w.Write([]byte(input)) + }() + + // Create command + var outBuf, errBuf bytes.Buffer + cmd := &cobra.Command{} + cmd.SetOut(&outBuf) + cmd.SetErr(&errBuf) + + // Set flags + lintAutoFix = false + lintMaxLength = 100 + lintFailOnWarn = false + + // Run lint command with no args (should auto-detect piped stdin) + err = lintRun(cmd, []string{}) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + output := outBuf.String() + if !strings.Contains(output, "Linting stdin input") { + t.Errorf("Expected stdin input detection, got: %s", output) + } +} + +// TestLintCmd_Output tests output formatting +func TestLintCmd_Output(t *testing.T) { + tests := []struct { + name string + files map[string]string + expectedOutput []string + }{ + { + name: "Output format for violations", + files: map[string]string{ + "test.sql": "SELECT * FROM users ", + }, + expectedOutput: []string{ + "test.sql", + "violation", + "[L001]", + "Trailing Whitespace", + "line 1", + "Severity: warning", + }, + }, + { + name: "Output shows rule IDs", + files: map[string]string{ + "test.sql": "SELECT column1, column2, column3, column4, column5, column6, column7, column8, column9, column10 FROM users WHERE active = true ", + }, + expectedOutput: []string{ + "L001", // Trailing whitespace + "L005", // Long line + // L002 skipped - it's SeverityError and would trigger os.Exit(1) + }, + }, + { + name: "Output shows line numbers and columns", + files: map[string]string{ + "test.sql": "SELECT * FROM users\nSELECT * FROM orders ", + }, + expectedOutput: []string{ + "line 2", + "column", + }, + }, + { + name: "Summary statistics", + files: map[string]string{ + "query1.sql": "SELECT * FROM users ", + "query2.sql": "SELECT * FROM orders\t", + "query3.sql": "SELECT * FROM products", + }, + expectedOutput: []string{ + "Total files: 3", + "Total violations: 2", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + var args []string + + // Create test files + for filename, content := range tt.files { + path := filepath.Join(tmpDir, filename) + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + args = append(args, path) + } + + // Create command + var outBuf, errBuf bytes.Buffer + cmd := &cobra.Command{} + cmd.SetOut(&outBuf) + cmd.SetErr(&errBuf) + + // Set flags + lintRecursive = false + lintAutoFix = false + lintMaxLength = 100 + lintFailOnWarn = false + + // Run lint command + err := lintRun(cmd, args) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // Verify output + output := outBuf.String() + for _, expected := range tt.expectedOutput { + if !strings.Contains(output, expected) { + t.Errorf("Expected output to contain '%s', got:\n%s", expected, output) + } + } + }) + } +} + +// TestLintCmd_ExitCodes tests exit code behavior (Note: os.Exit is called, so we can't test directly) +func TestLintCmd_ExitCodes(t *testing.T) { + // This test documents the expected exit code behavior + // Actual exit code testing would require subprocess execution + + tests := []struct { + name string + content string + failOnWarn bool + shouldExit bool + description string + }{ + { + name: "No violations - exit 0", + content: "SELECT * FROM users", + failOnWarn: false, + shouldExit: false, + description: "Clean SQL should not trigger exit", + }, + { + name: "Warning without fail-on-warn - exit 0", + content: "SELECT * FROM users ", + failOnWarn: false, + shouldExit: false, + description: "Warnings alone should not trigger exit by default", + }, + { + name: "Warning with fail-on-warn - would exit 1", + content: "SELECT * FROM users ", + failOnWarn: true, + shouldExit: true, + description: "Warnings with --fail-on-warn should trigger exit 1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Logf("Exit code behavior: %s", tt.description) + // Actual exit code testing would require subprocess execution + // This test serves as documentation of expected behavior + }) + } +} + +// TestCreateLinter tests linter creation with rules +func TestCreateLinter(t *testing.T) { + tests := []struct { + name string + maxLength int + expectedRules int + }{ + { + name: "Default max-length", + maxLength: 100, + expectedRules: 3, // L001, L002, L005 + }, + { + name: "Custom max-length", + maxLength: 120, + expectedRules: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set global flag + lintMaxLength = tt.maxLength + + // Create linter + linter := createLinter() + + // Verify rule count + rules := linter.Rules() + if len(rules) != tt.expectedRules { + t.Errorf("Expected %d rules, got %d", tt.expectedRules, len(rules)) + } + + // Verify rule IDs + expectedIDs := []string{"L001", "L002", "L005"} + for i, rule := range rules { + if rule.ID() != expectedIDs[i] { + t.Errorf("Expected rule ID %s, got %s", expectedIDs[i], rule.ID()) + } + } + + // Verify L005 has correct max length + // The actual max-length value is tested in TestCreateLinter_MaxLengthPassedToRule + t.Logf("Created linter with max-length: %d", tt.maxLength) + }) + } +} + +// TestCreateLinter_MaxLengthPassedToRule verifies max-length is correctly configured +func TestCreateLinter_MaxLengthPassedToRule(t *testing.T) { + tests := []struct { + name string + maxLength int + testLine string + wantViolation bool + }{ + { + name: "Line under max-length", + maxLength: 100, + testLine: strings.Repeat("x", 50), + wantViolation: false, + }, + { + name: "Line over max-length", + maxLength: 100, + testLine: strings.Repeat("x", 150), + wantViolation: true, + }, + { + name: "Line exactly at max-length", + maxLength: 100, + testLine: strings.Repeat("x", 100), + wantViolation: false, + }, + { + name: "Line one over max-length", + maxLength: 100, + testLine: strings.Repeat("x", 101), + wantViolation: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test.sql") + + // Create test file + if err := os.WriteFile(filename, []byte(tt.testLine), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + // Create command + var outBuf, errBuf bytes.Buffer + cmd := &cobra.Command{} + cmd.SetOut(&outBuf) + cmd.SetErr(&errBuf) + + // Set flags + lintMaxLength = tt.maxLength + lintAutoFix = false + lintFailOnWarn = false + + // Run lint command + err := lintRun(cmd, []string{filename}) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // Check for L005 violation + output := outBuf.String() + hasL005 := strings.Contains(output, "L005") + + if tt.wantViolation && !hasL005 { + t.Errorf("Expected L005 violation but got none. Output:\n%s", output) + } + if !tt.wantViolation && hasL005 { + t.Errorf("Did not expect L005 violation but got one. Output:\n%s", output) + } + }) + } +} diff --git a/pkg/linter/context_test.go b/pkg/linter/context_test.go new file mode 100644 index 00000000..480ddde5 --- /dev/null +++ b/pkg/linter/context_test.go @@ -0,0 +1,549 @@ +package linter + +import ( + "errors" + "testing" + + "github.com/ajitpratap0/GoSQLX/pkg/models" + "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" +) + +// TestNewContext tests the NewContext constructor with various SQL inputs +func TestNewContext(t *testing.T) { + tests := []struct { + name string + sql string + filename string + expectedSQL string + expectedLines []string + expectedFilename string + }{ + { + name: "simple SQL", + sql: "SELECT * FROM users", + filename: "query.sql", + expectedSQL: "SELECT * FROM users", + expectedLines: []string{"SELECT * FROM users"}, + expectedFilename: "query.sql", + }, + { + name: "empty string", + sql: "", + filename: "empty.sql", + expectedSQL: "", + expectedLines: []string{""}, + expectedFilename: "empty.sql", + }, + { + name: "single line", + sql: "SELECT id, name FROM users WHERE active = true", + filename: "single.sql", + expectedSQL: "SELECT id, name FROM users WHERE active = true", + expectedLines: []string{"SELECT id, name FROM users WHERE active = true"}, + expectedFilename: "single.sql", + }, + { + name: "multiple lines", + sql: `SELECT id, name +FROM users +WHERE active = true`, + filename: "multi.sql", + expectedSQL: `SELECT id, name +FROM users +WHERE active = true`, + expectedLines: []string{"SELECT id, name", "FROM users", "WHERE active = true"}, + expectedFilename: "multi.sql", + }, + { + name: "Windows line endings", + sql: "SELECT *\r\nFROM users\r\nWHERE id = 1", + filename: "windows.sql", + expectedSQL: "SELECT *\r\nFROM users\r\nWHERE id = 1", + expectedLines: []string{"SELECT *\r", "FROM users\r", "WHERE id = 1"}, + expectedFilename: "windows.sql", + }, + { + name: "Unix line endings", + sql: "SELECT *\nFROM users\nWHERE id = 1", + filename: "unix.sql", + expectedSQL: "SELECT *\nFROM users\nWHERE id = 1", + expectedLines: []string{"SELECT *", "FROM users", "WHERE id = 1"}, + expectedFilename: "unix.sql", + }, + { + name: "mixed line endings", + sql: "SELECT *\nFROM users\r\nWHERE id = 1\nORDER BY name", + filename: "mixed.sql", + expectedSQL: "SELECT *\nFROM users\r\nWHERE id = 1\nORDER BY name", + expectedLines: []string{"SELECT *", "FROM users\r", "WHERE id = 1", "ORDER BY name"}, + expectedFilename: "mixed.sql", + }, + { + name: "Unicode content", + sql: `SELECT name, 価格 +FROM 製品 +WHERE カテゴリ = '電子機器'`, + filename: "unicode.sql", + expectedSQL: `SELECT name, 価格 +FROM 製品 +WHERE カテゴリ = '電子機器'`, + expectedLines: []string{"SELECT name, 価格", "FROM 製品", "WHERE カテゴリ = '電子機器'"}, + expectedFilename: "unicode.sql", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := NewContext(tt.sql, tt.filename) + + // Verify SQL is stored correctly + if ctx.SQL != tt.expectedSQL { + t.Errorf("SQL = %q, want %q", ctx.SQL, tt.expectedSQL) + } + + // Verify filename is stored correctly + if ctx.Filename != tt.expectedFilename { + t.Errorf("Filename = %q, want %q", ctx.Filename, tt.expectedFilename) + } + + // Verify lines are split correctly + if len(ctx.Lines) != len(tt.expectedLines) { + t.Fatalf("Lines count = %d, want %d", len(ctx.Lines), len(tt.expectedLines)) + } + + for i, line := range ctx.Lines { + if line != tt.expectedLines[i] { + t.Errorf("Lines[%d] = %q, want %q", i, line, tt.expectedLines[i]) + } + } + + // Verify tokens and AST are initially nil/empty + if ctx.Tokens != nil { + t.Errorf("Tokens should be nil, got %v", ctx.Tokens) + } + if ctx.AST != nil { + t.Errorf("AST should be nil, got %v", ctx.AST) + } + if ctx.ParseErr != nil { + t.Errorf("ParseErr should be nil, got %v", ctx.ParseErr) + } + }) + } +} + +// TestContext_WithTokens tests adding tokens to the context +func TestContext_WithTokens(t *testing.T) { + tests := []struct { + name string + sql string + tokens []models.TokenWithSpan + expectedTokens []models.TokenWithSpan + }{ + { + name: "add tokens to context", + sql: "SELECT * FROM users", + tokens: []models.TokenWithSpan{ + { + Token: models.Token{Type: models.TokenTypeWord, Value: "SELECT"}, + Start: models.Location{Line: 1, Column: 1}, + End: models.Location{Line: 1, Column: 7}, + }, + { + Token: models.Token{Type: models.TokenTypeMul, Value: "*"}, + Start: models.Location{Line: 1, Column: 8}, + End: models.Location{Line: 1, Column: 9}, + }, + }, + expectedTokens: []models.TokenWithSpan{ + { + Token: models.Token{Type: models.TokenTypeWord, Value: "SELECT"}, + Start: models.Location{Line: 1, Column: 1}, + End: models.Location{Line: 1, Column: 7}, + }, + { + Token: models.Token{Type: models.TokenTypeMul, Value: "*"}, + Start: models.Location{Line: 1, Column: 8}, + End: models.Location{Line: 1, Column: 9}, + }, + }, + }, + { + name: "add empty token list", + sql: "SELECT * FROM users", + tokens: []models.TokenWithSpan{}, + expectedTokens: []models.TokenWithSpan{}, + }, + { + name: "add nil token list", + sql: "SELECT * FROM users", + tokens: nil, + expectedTokens: nil, + }, + { + name: "verify tokens are stored correctly", + sql: "INSERT INTO users VALUES (1)", + tokens: []models.TokenWithSpan{ + { + Token: models.Token{Type: models.TokenTypeWord, Value: "INSERT"}, + Start: models.Location{Line: 1, Column: 1}, + End: models.Location{Line: 1, Column: 7}, + }, + }, + expectedTokens: []models.TokenWithSpan{ + { + Token: models.Token{Type: models.TokenTypeWord, Value: "INSERT"}, + Start: models.Location{Line: 1, Column: 1}, + End: models.Location{Line: 1, Column: 7}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := NewContext(tt.sql, "test.sql") + result := ctx.WithTokens(tt.tokens) + + // Verify method chaining returns the same instance + if result != ctx { + t.Error("WithTokens should return the same context instance for chaining") + } + + // Verify tokens are stored correctly + if len(ctx.Tokens) != len(tt.expectedTokens) { + t.Fatalf("Tokens count = %d, want %d", len(ctx.Tokens), len(tt.expectedTokens)) + } + + for i, token := range ctx.Tokens { + if token.Token.Type != tt.expectedTokens[i].Token.Type { + t.Errorf("Tokens[%d].Type = %v, want %v", i, token.Token.Type, tt.expectedTokens[i].Token.Type) + } + if token.Token.Value != tt.expectedTokens[i].Token.Value { + t.Errorf("Tokens[%d].Value = %q, want %q", i, token.Token.Value, tt.expectedTokens[i].Token.Value) + } + } + }) + } +} + +// TestContext_WithAST tests adding AST and parse errors to the context +func TestContext_WithAST(t *testing.T) { + tests := []struct { + name string + sql string + astObj *ast.AST + parseErr error + expectAST bool + expectError bool + }{ + { + name: "add AST without error", + sql: "SELECT * FROM users", + astObj: &ast.AST{}, + parseErr: nil, + expectAST: true, + expectError: false, + }, + { + name: "add AST with parse error", + sql: "SELECT * FROM", + astObj: nil, + parseErr: errors.New("unexpected end of input"), + expectAST: false, + expectError: true, + }, + { + name: "add nil AST with error", + sql: "INVALID SQL", + astObj: nil, + parseErr: errors.New("syntax error"), + expectAST: false, + expectError: true, + }, + { + name: "add AST and error both present", + sql: "SELECT * FROM users WHERE", + astObj: &ast.AST{}, + parseErr: errors.New("incomplete WHERE clause"), + expectAST: true, + expectError: true, + }, + { + name: "add nil AST without error", + sql: "SELECT * FROM users", + astObj: nil, + parseErr: nil, + expectAST: false, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := NewContext(tt.sql, "test.sql") + result := ctx.WithAST(tt.astObj, tt.parseErr) + + // Verify method chaining returns the same instance + if result != ctx { + t.Error("WithAST should return the same context instance for chaining") + } + + // Verify AST is stored correctly + if tt.expectAST && ctx.AST == nil { + t.Error("Expected AST to be set, but it was nil") + } + if !tt.expectAST && ctx.AST != nil { + t.Errorf("Expected AST to be nil, but got %v", ctx.AST) + } + + // Verify error is stored correctly + if tt.expectError && ctx.ParseErr == nil { + t.Error("Expected ParseErr to be set, but it was nil") + } + if !tt.expectError && ctx.ParseErr != nil { + t.Errorf("Expected ParseErr to be nil, but got %v", ctx.ParseErr) + } + + // Verify error message if present + if tt.expectError && tt.parseErr != nil { + if ctx.ParseErr.Error() != tt.parseErr.Error() { + t.Errorf("ParseErr = %q, want %q", ctx.ParseErr.Error(), tt.parseErr.Error()) + } + } + }) + } +} + +// TestContext_GetLine tests retrieving specific lines from the context +func TestContext_GetLine(t *testing.T) { + tests := []struct { + name string + sql string + lineNum int + expectedLine string + }{ + { + name: "get first line (line 1)", + sql: "SELECT *\nFROM users\nWHERE id = 1", + lineNum: 1, + expectedLine: "SELECT *", + }, + { + name: "get middle line", + sql: "SELECT *\nFROM users\nWHERE id = 1", + lineNum: 2, + expectedLine: "FROM users", + }, + { + name: "get last line", + sql: "SELECT *\nFROM users\nWHERE id = 1", + lineNum: 3, + expectedLine: "WHERE id = 1", + }, + { + name: "get line 0 (out of bounds)", + sql: "SELECT *\nFROM users", + lineNum: 0, + expectedLine: "", + }, + { + name: "get negative line number", + sql: "SELECT *\nFROM users", + lineNum: -1, + expectedLine: "", + }, + { + name: "get line beyond last line", + sql: "SELECT *\nFROM users", + lineNum: 10, + expectedLine: "", + }, + { + name: "get line from single-line SQL", + sql: "SELECT * FROM users WHERE active = true", + lineNum: 1, + expectedLine: "SELECT * FROM users WHERE active = true", + }, + { + name: "get line from empty SQL", + sql: "", + lineNum: 1, + expectedLine: "", + }, + { + name: "get line with Unicode content", + sql: `SELECT name, 価格 +FROM 製品 +WHERE カテゴリ = '電子機器'`, + lineNum: 2, + expectedLine: "FROM 製品", + }, + { + name: "get line with trailing whitespace preserved", + sql: "SELECT * \nFROM users \nWHERE id = 1", + lineNum: 1, + expectedLine: "SELECT * ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := NewContext(tt.sql, "test.sql") + line := ctx.GetLine(tt.lineNum) + + if line != tt.expectedLine { + t.Errorf("GetLine(%d) = %q, want %q", tt.lineNum, line, tt.expectedLine) + } + }) + } +} + +// TestContext_GetLineCount tests counting lines in the context +func TestContext_GetLineCount(t *testing.T) { + tests := []struct { + name string + sql string + expectedCount int + }{ + { + name: "count lines in multi-line SQL", + sql: `SELECT id, name +FROM users +WHERE active = true +ORDER BY name`, + expectedCount: 4, + }, + { + name: "count lines in single-line SQL", + sql: "SELECT * FROM users WHERE active = true", + expectedCount: 1, + }, + { + name: "count lines in empty string", + sql: "", + expectedCount: 1, // Empty string splits to [""] + }, + { + name: "count lines with only newlines", + sql: "\n\n\n", + expectedCount: 4, // Four empty strings + }, + { + name: "count lines with Windows line endings", + sql: "SELECT *\r\nFROM users\r\nWHERE id = 1", + expectedCount: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := NewContext(tt.sql, "test.sql") + count := ctx.GetLineCount() + + if count != tt.expectedCount { + t.Errorf("GetLineCount() = %d, want %d", count, tt.expectedCount) + } + }) + } +} + +// TestContext_Integration tests the full workflow of building a context +func TestContext_Integration(t *testing.T) { + tests := []struct { + name string + sql string + filename string + }{ + { + name: "full workflow with simple SQL", + sql: "SELECT * FROM users", + filename: "query.sql", + }, + { + name: "full workflow with multi-line SQL", + sql: `SELECT id, name, email +FROM users +WHERE active = true +ORDER BY name`, + filename: "complex_query.sql", + }, + { + name: "full workflow with Unicode SQL", + sql: `SELECT 名前, 価格 +FROM 製品 +WHERE カテゴリ = '電子機器'`, + filename: "unicode_query.sql", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Step 1: Create context + ctx := NewContext(tt.sql, tt.filename) + if ctx.SQL != tt.sql { + t.Errorf("SQL = %q, want %q", ctx.SQL, tt.sql) + } + if ctx.Filename != tt.filename { + t.Errorf("Filename = %q, want %q", ctx.Filename, tt.filename) + } + + // Step 2: Add tokens + tokens := []models.TokenWithSpan{ + { + Token: models.Token{Type: models.TokenTypeWord, Value: "SELECT"}, + Start: models.Location{Line: 1, Column: 1}, + End: models.Location{Line: 1, Column: 7}, + }, + } + result := ctx.WithTokens(tokens) + if result != ctx { + t.Error("WithTokens should return the same context instance") + } + if len(ctx.Tokens) != 1 { + t.Errorf("Tokens count = %d, want 1", len(ctx.Tokens)) + } + + // Step 3: Add AST + astObj := &ast.AST{} + result = ctx.WithAST(astObj, nil) + if result != ctx { + t.Error("WithAST should return the same context instance") + } + if ctx.AST == nil { + t.Error("AST should be set") + } + if ctx.ParseErr != nil { + t.Errorf("ParseErr should be nil, got %v", ctx.ParseErr) + } + + // Step 4: Verify all fields are populated correctly + if ctx.SQL != tt.sql { + t.Errorf("Final SQL = %q, want %q", ctx.SQL, tt.sql) + } + if ctx.Filename != tt.filename { + t.Errorf("Final Filename = %q, want %q", ctx.Filename, tt.filename) + } + if len(ctx.Lines) == 0 { + t.Error("Lines should not be empty") + } + if len(ctx.Tokens) == 0 { + t.Error("Tokens should not be empty") + } + if ctx.AST == nil { + t.Error("AST should not be nil") + } + + // Step 5: Test method chaining - all in one line + ctx2 := NewContext(tt.sql, tt.filename).WithTokens(tokens).WithAST(astObj, nil) + if ctx2.SQL != tt.sql { + t.Errorf("Chained SQL = %q, want %q", ctx2.SQL, tt.sql) + } + if ctx2.AST == nil { + t.Error("Chained AST should not be nil") + } + if len(ctx2.Tokens) == 0 { + t.Error("Chained Tokens should not be empty") + } + }) + } +} diff --git a/pkg/linter/linter_test.go b/pkg/linter/linter_test.go new file mode 100644 index 00000000..74a49444 --- /dev/null +++ b/pkg/linter/linter_test.go @@ -0,0 +1,1204 @@ +package linter + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/ajitpratap0/GoSQLX/pkg/models" +) + +// Mock rule for testing +type mockRule struct { + BaseRule + checkFunc func(*Context) ([]Violation, error) + fixFunc func(string, []Violation) (string, error) +} + +func newMockRule(id, name, description string, severity Severity, canAutoFix bool) *mockRule { + return &mockRule{ + BaseRule: NewBaseRule(id, name, description, severity, canAutoFix), + checkFunc: func(ctx *Context) ([]Violation, error) { + return []Violation{}, nil + }, + fixFunc: func(content string, violations []Violation) (string, error) { + return content, nil + }, + } +} + +func (r *mockRule) Check(ctx *Context) ([]Violation, error) { + if r.checkFunc != nil { + return r.checkFunc(ctx) + } + return []Violation{}, nil +} + +func (r *mockRule) Fix(content string, violations []Violation) (string, error) { + if r.fixFunc != nil { + return r.fixFunc(content, violations) + } + return content, nil +} + +// Mock trailing whitespace rule for testing +func newMockTrailingWhitespaceRule() *mockRule { + rule := newMockRule( + "L001", + "Trailing Whitespace", + "Unnecessary trailing whitespace at end of lines", + SeverityWarning, + true, + ) + + rule.checkFunc = func(ctx *Context) ([]Violation, error) { + violations := []Violation{} + for lineNum, line := range ctx.Lines { + if len(line) == 0 { + continue + } + // Check if line has trailing whitespace (spaces or tabs) + if len(line) > 0 && (strings.HasSuffix(line, " ") || strings.HasSuffix(line, "\t")) { + trimmed := strings.TrimRight(line, " \t") + column := len(trimmed) + 1 + violations = append(violations, Violation{ + Rule: rule.ID(), + RuleName: rule.Name(), + Severity: rule.Severity(), + Message: "Line has trailing whitespace", + Location: models.Location{Line: lineNum + 1, Column: column}, + Line: line, + Suggestion: "Remove trailing spaces or tabs from the end of the line", + CanAutoFix: true, + }) + } + } + return violations, nil + } + + rule.fixFunc = func(content string, violations []Violation) (string, error) { + lines := strings.Split(content, "\n") + for i, line := range lines { + lines[i] = strings.TrimRight(line, " \t") + } + return strings.Join(lines, "\n"), nil + } + + return rule +} + +// Mock mixed indentation rule for testing +func newMockMixedIndentationRule() *mockRule { + rule := newMockRule( + "L002", + "Mixed Indentation", + "Inconsistent use of tabs and spaces for indentation", + SeverityError, + true, + ) + + rule.checkFunc = func(ctx *Context) ([]Violation, error) { + violations := []Violation{} + var firstIndentType string // "tab" or "space" + + for lineNum, line := range ctx.Lines { + if len(line) == 0 { + continue + } + + // Get leading whitespace + leadingWhitespace := "" + for _, char := range line { + if char != ' ' && char != '\t' { + break + } + leadingWhitespace += string(char) + } + + if len(leadingWhitespace) == 0 { + continue + } + + hasTabs := strings.Contains(leadingWhitespace, "\t") + hasSpaces := strings.Contains(leadingWhitespace, " ") + + // Mixed tabs and spaces on same line + if hasTabs && hasSpaces { + violations = append(violations, Violation{ + Rule: rule.ID(), + RuleName: rule.Name(), + Severity: rule.Severity(), + Message: "Line mixes tabs and spaces for indentation", + Location: models.Location{Line: lineNum + 1, Column: 1}, + Line: line, + Suggestion: "Use either tabs or spaces consistently for indentation", + CanAutoFix: true, + }) + continue + } + + // Track first indentation type + currentType := "" + if hasTabs { + currentType = "tab" + } else if hasSpaces { + currentType = "space" + } + + if currentType != "" { + if firstIndentType == "" { + firstIndentType = currentType + } else if firstIndentType != currentType { + violations = append(violations, Violation{ + Rule: rule.ID(), + RuleName: rule.Name(), + Severity: rule.Severity(), + Message: "Inconsistent indentation: file uses both tabs and spaces", + Location: models.Location{Line: lineNum + 1, Column: 1}, + Line: line, + Suggestion: "Use " + firstIndentType + "s consistently throughout the file", + CanAutoFix: true, + }) + } + } + } + + return violations, nil + } + + return rule +} + +// Mock long lines rule for testing +func newMockLongLinesRule(maxLength int) *mockRule { + if maxLength <= 0 { + maxLength = 100 + } + + rule := newMockRule( + "L005", + "Long Lines", + "Lines should not exceed maximum length for readability", + SeverityInfo, + false, + ) + + rule.checkFunc = func(ctx *Context) ([]Violation, error) { + violations := []Violation{} + + for lineNum, line := range ctx.Lines { + lineLength := len(line) + + if lineLength == 0 { + continue + } + + if lineLength > maxLength { + violations = append(violations, Violation{ + Rule: rule.ID(), + RuleName: rule.Name(), + Severity: rule.Severity(), + Message: "Line exceeds maximum length", + Location: models.Location{Line: lineNum + 1, Column: maxLength + 1}, + Line: line, + Suggestion: "Split this line into multiple lines", + CanAutoFix: false, + }) + } + } + + return violations, nil + } + + return rule +} + +// TestLinter_New tests the New constructor +func TestLinter_New(t *testing.T) { + tests := []struct { + name string + rules []Rule + expectedRules int + }{ + { + name: "Create linter with no rules", + rules: []Rule{}, + expectedRules: 0, + }, + { + name: "Create linter with single rule", + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedRules: 1, + }, + { + name: "Create linter with multiple rules", + rules: []Rule{ + newMockTrailingWhitespaceRule(), + newMockMixedIndentationRule(), + newMockLongLinesRule(80), + }, + expectedRules: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + linter := New(tt.rules...) + + if linter == nil { + t.Fatal("Expected non-nil linter") + } + + if len(linter.rules) != tt.expectedRules { + t.Errorf("Expected %d rules, got %d", tt.expectedRules, len(linter.rules)) + } + }) + } +} + +// TestLinter_Rules tests the Rules() method +func TestLinter_Rules(t *testing.T) { + tests := []struct { + name string + rules []Rule + expectedCount int + }{ + { + name: "Verify Rules() returns correct rule list", + rules: []Rule{newMockTrailingWhitespaceRule(), newMockMixedIndentationRule()}, + expectedCount: 2, + }, + { + name: "Verify Rules() with empty list", + rules: []Rule{}, + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + linter := New(tt.rules...) + rules := linter.Rules() + + if len(rules) != tt.expectedCount { + t.Errorf("Expected %d rules, got %d", tt.expectedCount, len(rules)) + } + + // Verify Rules() returns copy not reference + if len(rules) > 0 { + originalPtr := &linter.rules + returnedPtr := &rules + if originalPtr == returnedPtr { + t.Error("Rules() should return a copy, not a reference") + } + } + }) + } +} + +// TestLinter_LintString tests the LintString() function +func TestLinter_LintString(t *testing.T) { + tests := []struct { + name string + sql string + rules []Rule + expectedViolations int + expectError bool + checkViolations func(*testing.T, []Violation) + }{ + { + name: "Empty SQL string", + sql: "", + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedViolations: 0, + }, + { + name: "Valid SQL with no violations", + sql: "SELECT id, name FROM users WHERE active = true", + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedViolations: 0, + }, + { + name: "SQL with single violation (trailing whitespace)", + sql: "SELECT id ", + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedViolations: 1, + checkViolations: func(t *testing.T, violations []Violation) { + if violations[0].Rule != "L001" { + t.Errorf("Expected rule L001, got %s", violations[0].Rule) + } + }, + }, + { + name: "SQL with multiple violations from same rule", + sql: "SELECT id \nFROM users \nWHERE active = true ", + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedViolations: 3, + }, + { + name: "SQL with violations from multiple rules", + sql: "SELECT id \n FROM users\n\tWHERE id = 1", + rules: []Rule{ + newMockTrailingWhitespaceRule(), + newMockMixedIndentationRule(), + }, + expectedViolations: 2, // 1 trailing whitespace + 1 mixed indentation + }, + { + name: "SQL with line numbers correctly tracked", + sql: "SELECT id\nFROM users \nWHERE active = true", + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedViolations: 1, + checkViolations: func(t *testing.T, violations []Violation) { + if violations[0].Location.Line != 2 { + t.Errorf("Expected violation on line 2, got line %d", violations[0].Location.Line) + } + }, + }, + { + name: "SQL that fails tokenization (invalid syntax)", + sql: "SELECT 'unterminated string", + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedViolations: 0, // No whitespace violations even with tokenization failure + }, + { + name: "SQL that fails parsing but has whitespace violations", + sql: "SELECT * FROM \nINVALID SYNTAX HERE", + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedViolations: 1, // Whitespace rules work without parsing + }, + { + name: "Unicode SQL content", + sql: "SELECT name FROM users WHERE city = '東京' ", + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedViolations: 1, + }, + { + name: "Multi-line SQL with violations on different lines", + sql: "SELECT id \nFROM users\nWHERE active = true ", + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedViolations: 2, + checkViolations: func(t *testing.T, violations []Violation) { + if violations[0].Location.Line != 1 || violations[1].Location.Line != 3 { + t.Error("Violations not on expected lines") + } + }, + }, + { + name: "Very long SQL (100+ lines)", + sql: generateLongSQL(100), + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedViolations: 100, // Every line has trailing whitespace + }, + { + name: "SQL with Windows line endings", + sql: "SELECT id \r\nFROM users ", + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedViolations: 1, // Last line with trailing whitespace (first line split loses trailing) + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + linter := New(tt.rules...) + result := linter.LintString(tt.sql, "test.sql") + + if tt.expectError && result.Error == nil { + t.Error("Expected error but got none") + } + + if !tt.expectError && result.Error != nil { + t.Errorf("Unexpected error: %v", result.Error) + } + + if len(result.Violations) != tt.expectedViolations { + t.Errorf("Expected %d violations, got %d", tt.expectedViolations, len(result.Violations)) + for i, v := range result.Violations { + t.Logf("Violation %d: %s at line %d", i+1, v.Message, v.Location.Line) + } + } + + if tt.checkViolations != nil && len(result.Violations) > 0 { + tt.checkViolations(t, result.Violations) + } + }) + } +} + +// TestLinter_LintFile tests the LintFile() function +func TestLinter_LintFile(t *testing.T) { + tests := []struct { + name string + setupFile func(t *testing.T) string // Returns file path + rules []Rule + expectedViolations int + expectError bool + }{ + { + name: "Lint existing SQL file successfully", + setupFile: func(t *testing.T) string { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "test.sql") + content := "SELECT id, name FROM users WHERE active = true" + if err := os.WriteFile(filePath, []byte(content), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + return filePath + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedViolations: 0, + }, + { + name: "Lint file with no violations", + setupFile: func(t *testing.T) string { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "clean.sql") + content := "SELECT id, name\nFROM users\nWHERE active = true" + if err := os.WriteFile(filePath, []byte(content), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + return filePath + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedViolations: 0, + }, + { + name: "Lint file with violations", + setupFile: func(t *testing.T) string { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "violations.sql") + content := "SELECT id \nFROM users \nWHERE active = true " + if err := os.WriteFile(filePath, []byte(content), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + return filePath + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedViolations: 3, + }, + { + name: "Lint non-existent file (error handling)", + setupFile: func(t *testing.T) string { + return "/nonexistent/path/to/file.sql" + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectError: true, + }, + { + name: "Lint file with Unicode content", + setupFile: func(t *testing.T) string { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "unicode.sql") + content := "SELECT name FROM users WHERE city = '東京' " + if err := os.WriteFile(filePath, []byte(content), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + return filePath + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedViolations: 1, + }, + { + name: "Lint empty file", + setupFile: func(t *testing.T) string { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "empty.sql") + if err := os.WriteFile(filePath, []byte(""), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + return filePath + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedViolations: 0, + }, + { + name: "Lint file with mixed line endings", + setupFile: func(t *testing.T) string { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "mixed.sql") + content := "SELECT id \r\nFROM users " + if err := os.WriteFile(filePath, []byte(content), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + return filePath + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedViolations: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + filePath := tt.setupFile(t) + linter := New(tt.rules...) + result := linter.LintFile(filePath) + + if tt.expectError && result.Error == nil { + t.Error("Expected error but got none") + } + + if !tt.expectError && result.Error != nil { + t.Errorf("Unexpected error: %v", result.Error) + } + + if !tt.expectError && len(result.Violations) != tt.expectedViolations { + t.Errorf("Expected %d violations, got %d", tt.expectedViolations, len(result.Violations)) + } + + if result.Filename != filePath { + t.Errorf("Expected filename %s, got %s", filePath, result.Filename) + } + }) + } +} + +// TestLinter_LintFiles tests the LintFiles() function +func TestLinter_LintFiles(t *testing.T) { + tests := []struct { + name string + setupFiles func(t *testing.T) []string + rules []Rule + expectedFiles int + expectedViolations int + }{ + { + name: "Lint zero files (empty list)", + setupFiles: func(t *testing.T) []string { + return []string{} + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedFiles: 0, + expectedViolations: 0, + }, + { + name: "Lint single file", + setupFiles: func(t *testing.T) []string { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "test.sql") + content := "SELECT id, name FROM users" + if err := os.WriteFile(filePath, []byte(content), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + return []string{filePath} + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedFiles: 1, + expectedViolations: 0, + }, + { + name: "Lint multiple files successfully", + setupFiles: func(t *testing.T) []string { + tmpDir := t.TempDir() + file1 := filepath.Join(tmpDir, "test1.sql") + file2 := filepath.Join(tmpDir, "test2.sql") + content := "SELECT id FROM users" + if err := os.WriteFile(file1, []byte(content), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + if err := os.WriteFile(file2, []byte(content), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + return []string{file1, file2} + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedFiles: 2, + expectedViolations: 0, + }, + { + name: "Lint multiple files with some having violations", + setupFiles: func(t *testing.T) []string { + tmpDir := t.TempDir() + file1 := filepath.Join(tmpDir, "clean.sql") + file2 := filepath.Join(tmpDir, "dirty.sql") + if err := os.WriteFile(file1, []byte("SELECT id FROM users"), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + if err := os.WriteFile(file2, []byte("SELECT id \nFROM users "), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + return []string{file1, file2} + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedFiles: 2, + expectedViolations: 2, + }, + { + name: "Lint multiple files with some non-existent", + setupFiles: func(t *testing.T) []string { + tmpDir := t.TempDir() + file1 := filepath.Join(tmpDir, "exists.sql") + file2 := "/nonexistent/file.sql" + if err := os.WriteFile(file1, []byte("SELECT id FROM users"), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + return []string{file1, file2} + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedFiles: 2, + // No violations expected, but one file will have an error + }, + { + name: "Verify TotalFiles and TotalViolations counts", + setupFiles: func(t *testing.T) []string { + tmpDir := t.TempDir() + files := make([]string, 3) + for i := 0; i < 3; i++ { + filePath := filepath.Join(tmpDir, "test"+string(rune('0'+i))+".sql") + content := "SELECT id \nFROM users " + if err := os.WriteFile(filePath, []byte(content), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + files[i] = filePath + } + return files + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedFiles: 3, + expectedViolations: 6, // 2 violations per file * 3 files + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + files := tt.setupFiles(t) + linter := New(tt.rules...) + result := linter.LintFiles(files) + + if result.TotalFiles != tt.expectedFiles { + t.Errorf("Expected TotalFiles %d, got %d", tt.expectedFiles, result.TotalFiles) + } + + if result.TotalViolations != tt.expectedViolations { + t.Errorf("Expected TotalViolations %d, got %d", tt.expectedViolations, result.TotalViolations) + } + + if len(result.Files) != tt.expectedFiles { + t.Errorf("Expected %d file results, got %d", tt.expectedFiles, len(result.Files)) + } + }) + } +} + +// TestLinter_LintDirectory tests the LintDirectory() function +func TestLinter_LintDirectory(t *testing.T) { + tests := []struct { + name string + setupDir func(t *testing.T) (string, string) // Returns (dir, pattern) + rules []Rule + expectedFiles int + expectedViolations int + expectError bool + }{ + { + name: "Lint directory with *.sql pattern", + setupDir: func(t *testing.T) (string, string) { + tmpDir := t.TempDir() + os.WriteFile(filepath.Join(tmpDir, "test1.sql"), []byte("SELECT id FROM users"), 0644) + os.WriteFile(filepath.Join(tmpDir, "test2.sql"), []byte("SELECT name FROM products"), 0644) + return tmpDir, "*.sql" + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedFiles: 2, + expectedViolations: 0, + }, + { + name: "Lint directory with custom pattern (*.txt)", + setupDir: func(t *testing.T) (string, string) { + tmpDir := t.TempDir() + os.WriteFile(filepath.Join(tmpDir, "query1.txt"), []byte("SELECT id FROM users"), 0644) + os.WriteFile(filepath.Join(tmpDir, "query2.txt"), []byte("SELECT name FROM products"), 0644) + os.WriteFile(filepath.Join(tmpDir, "ignore.sql"), []byte("SELECT * FROM test"), 0644) + return tmpDir, "*.txt" + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedFiles: 2, + expectedViolations: 0, + }, + { + name: "Lint directory with no matching files", + setupDir: func(t *testing.T) (string, string) { + tmpDir := t.TempDir() + os.WriteFile(filepath.Join(tmpDir, "test.txt"), []byte("SELECT id FROM users"), 0644) + return tmpDir, "*.sql" + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedFiles: 0, + expectedViolations: 0, + }, + { + name: "Lint non-existent directory (error handling)", + setupDir: func(t *testing.T) (string, string) { + return "/nonexistent/directory", "*.sql" + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectError: true, + }, + { + name: "Lint directory recursively with nested subdirectories", + setupDir: func(t *testing.T) (string, string) { + tmpDir := t.TempDir() + subDir := filepath.Join(tmpDir, "subdir") + os.Mkdir(subDir, 0755) + os.WriteFile(filepath.Join(tmpDir, "test1.sql"), []byte("SELECT id FROM users"), 0644) + os.WriteFile(filepath.Join(subDir, "test2.sql"), []byte("SELECT name FROM products"), 0644) + return tmpDir, "*.sql" + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedFiles: 2, + expectedViolations: 0, + }, + { + name: "Lint directory with mixed file types", + setupDir: func(t *testing.T) (string, string) { + tmpDir := t.TempDir() + os.WriteFile(filepath.Join(tmpDir, "query.sql"), []byte("SELECT id FROM users"), 0644) + os.WriteFile(filepath.Join(tmpDir, "readme.txt"), []byte("Documentation"), 0644) + os.WriteFile(filepath.Join(tmpDir, "config.json"), []byte("{}"), 0644) + return tmpDir, "*.sql" + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedFiles: 1, + expectedViolations: 0, + }, + { + name: "Lint directory with hidden files (.sql)", + setupDir: func(t *testing.T) (string, string) { + tmpDir := t.TempDir() + os.WriteFile(filepath.Join(tmpDir, ".hidden.sql"), []byte("SELECT id FROM users"), 0644) + os.WriteFile(filepath.Join(tmpDir, "visible.sql"), []byte("SELECT name FROM products"), 0644) + return tmpDir, "*.sql" + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedFiles: 2, // Both should be found + expectedViolations: 0, + }, + { + name: "Lint empty directory", + setupDir: func(t *testing.T) (string, string) { + tmpDir := t.TempDir() + return tmpDir, "*.sql" + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedFiles: 0, + expectedViolations: 0, + }, + { + name: "Verify file counts and violation aggregation", + setupDir: func(t *testing.T) (string, string) { + tmpDir := t.TempDir() + os.WriteFile(filepath.Join(tmpDir, "test1.sql"), []byte("SELECT id "), 0644) + os.WriteFile(filepath.Join(tmpDir, "test2.sql"), []byte("SELECT name \nFROM users "), 0644) + return tmpDir, "*.sql" + }, + rules: []Rule{newMockTrailingWhitespaceRule()}, + expectedFiles: 2, + expectedViolations: 3, // 1 from test1.sql, 2 from test2.sql + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir, pattern := tt.setupDir(t) + linter := New(tt.rules...) + result := linter.LintDirectory(dir, pattern) + + if tt.expectError { + // Check if there's an error in the result + hasError := false + for _, fileResult := range result.Files { + if fileResult.Error != nil { + hasError = true + break + } + } + if !hasError { + t.Error("Expected error but got none") + } + return + } + + if result.TotalFiles != tt.expectedFiles { + t.Errorf("Expected TotalFiles %d, got %d", tt.expectedFiles, result.TotalFiles) + } + + if result.TotalViolations != tt.expectedViolations { + t.Errorf("Expected TotalViolations %d, got %d", tt.expectedViolations, result.TotalViolations) + } + }) + } +} + +// TestFormatViolation tests the FormatViolation() function +func TestFormatViolation(t *testing.T) { + tests := []struct { + name string + violation Violation + expectedContent []string // Substrings that should appear in output + }{ + { + name: "Format violation with all fields", + violation: Violation{ + Rule: "L001", + RuleName: "Trailing Whitespace", + Severity: SeverityWarning, + Message: "Line has trailing whitespace", + Location: models.Location{Line: 5, Column: 10}, + Line: "SELECT id ", + Suggestion: "Remove trailing spaces", + CanAutoFix: true, + }, + expectedContent: []string{"L001", "Trailing Whitespace", "line 5", "column 10", "warning", "Remove trailing spaces", "SELECT id"}, + }, + { + name: "Format violation without suggestion", + violation: Violation{ + Rule: "L002", + RuleName: "Mixed Indentation", + Severity: SeverityError, + Message: "Inconsistent indentation", + Location: models.Location{Line: 3, Column: 1}, + Line: "\tSELECT id", + }, + expectedContent: []string{"L002", "Mixed Indentation", "line 3", "column 1", "error"}, + }, + { + name: "Format violation without line content", + violation: Violation{ + Rule: "L003", + RuleName: "Test Rule", + Severity: SeverityInfo, + Message: "Test message", + Location: models.Location{Line: 1, Column: 1}, + }, + expectedContent: []string{"L003", "Test Rule", "line 1", "info"}, + }, + { + name: "Format violation with column position 0", + violation: Violation{ + Rule: "L001", + RuleName: "Trailing Whitespace", + Severity: SeverityWarning, + Message: "Line has trailing whitespace", + Location: models.Location{Line: 1, Column: 0}, + Line: "SELECT id", + }, + expectedContent: []string{"L001", "line 1", "column 0"}, + }, + { + name: "Format violation with very long line", + violation: Violation{ + Rule: "L005", + RuleName: "Long Lines", + Severity: SeverityInfo, + Message: "Line exceeds maximum length", + Location: models.Location{Line: 2, Column: 101}, + Line: strings.Repeat("x", 200), + }, + expectedContent: []string{"L005", "Long Lines", "line 2", "column 101"}, + }, + { + name: "Format violation with Unicode content", + violation: Violation{ + Rule: "L001", + RuleName: "Trailing Whitespace", + Severity: SeverityWarning, + Message: "Line has trailing whitespace", + Location: models.Location{Line: 1, Column: 20}, + Line: "SELECT name FROM users WHERE city = '東京' ", + }, + expectedContent: []string{"L001", "東京"}, + }, + { + name: "Format violation with different severity levels - error", + violation: Violation{ + Rule: "L002", + RuleName: "Mixed Indentation", + Severity: SeverityError, + Message: "Test error", + Location: models.Location{Line: 1, Column: 1}, + }, + expectedContent: []string{"error"}, + }, + { + name: "Format violation with different severity levels - info", + violation: Violation{ + Rule: "L005", + RuleName: "Long Lines", + Severity: SeverityInfo, + Message: "Test info", + Location: models.Location{Line: 1, Column: 1}, + }, + expectedContent: []string{"info"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + output := FormatViolation(tt.violation) + + for _, expected := range tt.expectedContent { + if !strings.Contains(output, expected) { + t.Errorf("Expected output to contain %q, but it didn't.\nOutput: %s", expected, output) + } + } + }) + } +} + +// TestFormatResult tests the FormatResult() function +func TestFormatResult(t *testing.T) { + tests := []struct { + name string + result Result + expectedContent []string + }{ + { + name: "Format result with no files", + result: Result{ + Files: []FileResult{}, + TotalFiles: 0, + TotalViolations: 0, + }, + expectedContent: []string{"Total files: 0", "Total violations: 0"}, + }, + { + name: "Format result with single file, no violations", + result: Result{ + Files: []FileResult{ + { + Filename: "test.sql", + Violations: []Violation{}, + }, + }, + TotalFiles: 1, + TotalViolations: 0, + }, + expectedContent: []string{"Total files: 1", "Total violations: 0"}, + }, + { + name: "Format result with single file, single violation", + result: Result{ + Files: []FileResult{ + { + Filename: "test.sql", + Violations: []Violation{ + { + Rule: "L001", + RuleName: "Trailing Whitespace", + Severity: SeverityWarning, + Message: "Line has trailing whitespace", + Location: models.Location{Line: 1, Column: 10}, + }, + }, + }, + }, + TotalFiles: 1, + TotalViolations: 1, + }, + expectedContent: []string{"test.sql", "1 violation(s)", "L001", "Total files: 1", "Total violations: 1"}, + }, + { + name: "Format result with multiple files and violations", + result: Result{ + Files: []FileResult{ + { + Filename: "test1.sql", + Violations: []Violation{ + { + Rule: "L001", + RuleName: "Trailing Whitespace", + Severity: SeverityWarning, + Message: "Line has trailing whitespace", + Location: models.Location{Line: 1, Column: 10}, + }, + }, + }, + { + Filename: "test2.sql", + Violations: []Violation{ + { + Rule: "L002", + RuleName: "Mixed Indentation", + Severity: SeverityError, + Message: "Inconsistent indentation", + Location: models.Location{Line: 2, Column: 1}, + }, + }, + }, + }, + TotalFiles: 2, + TotalViolations: 2, + }, + expectedContent: []string{"test1.sql", "test2.sql", "L001", "L002", "Total files: 2", "Total violations: 2"}, + }, + { + name: "Format result with file errors", + result: Result{ + Files: []FileResult{ + { + Filename: "test.sql", + Error: os.ErrNotExist, + }, + }, + TotalFiles: 1, + TotalViolations: 0, + }, + expectedContent: []string{"test.sql", "ERROR"}, + }, + { + name: "Format result with mixed success and errors", + result: Result{ + Files: []FileResult{ + { + Filename: "success.sql", + Violations: []Violation{ + { + Rule: "L001", + RuleName: "Trailing Whitespace", + Severity: SeverityWarning, + Message: "Line has trailing whitespace", + Location: models.Location{Line: 1, Column: 10}, + }, + }, + }, + { + Filename: "error.sql", + Error: os.ErrPermission, + }, + }, + TotalFiles: 2, + TotalViolations: 1, + }, + expectedContent: []string{"success.sql", "error.sql", "ERROR", "Total files: 2", "Total violations: 1"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + output := FormatResult(tt.result) + + for _, expected := range tt.expectedContent { + if !strings.Contains(output, expected) { + t.Errorf("Expected output to contain %q, but it didn't.\nOutput: %s", expected, output) + } + } + }) + } +} + +// TestBaseRule tests the BaseRule implementation +func TestBaseRule(t *testing.T) { + tests := []struct { + name string + baseRule BaseRule + expectedID string + expectedStr string + severity Severity + canAutoFix bool + }{ + { + name: "Create BaseRule with all parameters", + baseRule: NewBaseRule( + "L001", + "Test Rule", + "This is a test rule", + SeverityWarning, + true, + ), + expectedID: "L001", + expectedStr: "Test Rule", + severity: SeverityWarning, + canAutoFix: true, + }, + { + name: "Verify ID() method", + baseRule: NewBaseRule( + "L002", + "Another Rule", + "Another test rule", + SeverityError, + false, + ), + expectedID: "L002", + expectedStr: "Another Rule", + severity: SeverityError, + canAutoFix: false, + }, + { + name: "Verify Name() method", + baseRule: NewBaseRule( + "L003", + "Third Rule", + "Third test rule", + SeverityInfo, + true, + ), + expectedID: "L003", + expectedStr: "Third Rule", + severity: SeverityInfo, + canAutoFix: true, + }, + { + name: "Verify Description() method", + baseRule: NewBaseRule( + "L004", + "Fourth Rule", + "Fourth test rule with long description", + SeverityWarning, + false, + ), + expectedID: "L004", + expectedStr: "Fourth Rule", + severity: SeverityWarning, + canAutoFix: false, + }, + { + name: "Verify Severity() and CanAutoFix() methods", + baseRule: NewBaseRule( + "L005", + "Fifth Rule", + "Fifth test rule", + SeverityError, + true, + ), + expectedID: "L005", + expectedStr: "Fifth Rule", + severity: SeverityError, + canAutoFix: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.baseRule.ID() != tt.expectedID { + t.Errorf("Expected ID %q, got %q", tt.expectedID, tt.baseRule.ID()) + } + + if tt.baseRule.Name() != tt.expectedStr { + t.Errorf("Expected Name %q, got %q", tt.expectedStr, tt.baseRule.Name()) + } + + if tt.baseRule.Description() == "" { + t.Error("Expected non-empty description") + } + + if tt.baseRule.Severity() != tt.severity { + t.Errorf("Expected Severity %q, got %q", tt.severity, tt.baseRule.Severity()) + } + + if tt.baseRule.CanAutoFix() != tt.canAutoFix { + t.Errorf("Expected CanAutoFix %v, got %v", tt.canAutoFix, tt.baseRule.CanAutoFix()) + } + }) + } +} + +// Helper function to generate long SQL for testing +func generateLongSQL(lines int) string { + var sb strings.Builder + for i := 0; i < lines; i++ { + sb.WriteString("SELECT id FROM users \n") + } + return strings.TrimSuffix(sb.String(), "\n") +} diff --git a/pkg/linter/rules/whitespace/long_lines_test.go b/pkg/linter/rules/whitespace/long_lines_test.go new file mode 100644 index 00000000..af0b9566 --- /dev/null +++ b/pkg/linter/rules/whitespace/long_lines_test.go @@ -0,0 +1,541 @@ +package whitespace + +import ( + "strings" + "testing" + + "github.com/ajitpratap0/GoSQLX/pkg/linter" +) + +func TestLongLinesRule_Check(t *testing.T) { + tests := []struct { + name string + sql string + maxLength int + expectedViolations int + }{ + { + name: "No violations - all lines under default max (100)", + sql: "SELECT id, name\nFROM users\nWHERE active = true", + maxLength: 100, + expectedViolations: 0, + }, + { + name: "Single line exactly at max length (boundary condition)", + sql: strings.Repeat("A", 100), + maxLength: 100, + expectedViolations: 0, + }, + { + name: "Single line one char over max length", + sql: strings.Repeat("A", 101), + maxLength: 100, + expectedViolations: 1, + }, + { + name: "Single line well over max length (150 chars)", + sql: "SELECT id, name, email, address, phone, city, state, zip, country, created_at, updated_at, deleted_at, is_active, is_verified, preferences FROM users WHERE active = true", + maxLength: 100, + expectedViolations: 1, + }, + { + name: "Multiple lines, some over max", + sql: "SELECT id, name\nFROM users\nWHERE active = true AND email IS NOT NULL AND verified = true AND created_at > '2023-01-01' AND deleted_at IS NULL", + maxLength: 100, + expectedViolations: 1, + }, + { + name: "All lines over max", + sql: strings.Repeat("A", 101) + "\n" + strings.Repeat("B", 102) + "\n" + strings.Repeat("C", 103), + maxLength: 100, + expectedViolations: 3, + }, + { + name: "Empty lines (should be ignored)", + sql: "SELECT id\n\n\nFROM users", + maxLength: 100, + expectedViolations: 0, + }, + { + name: "Comment line with -- over max (should be skipped)", + sql: "-- " + strings.Repeat("This is a very long comment ", 10), + maxLength: 100, + expectedViolations: 0, + }, + { + name: "Comment line with /* over max (should be skipped)", + sql: "/* " + strings.Repeat("This is a very long comment ", 10) + " */", + maxLength: 100, + expectedViolations: 0, + }, + { + name: "SQL with inline comment exceeding max (should trigger violation)", + sql: "SELECT id, name, email, address, phone, city, state, zip, country FROM users WHERE active = true -- inline comment", + maxLength: 100, + expectedViolations: 1, + }, + { + name: "Very long string literal in SQL", + sql: "INSERT INTO messages (content) VALUES ('" + strings.Repeat("very long message content ", 10) + "')", + maxLength: 100, + expectedViolations: 1, + }, + { + name: "Custom max length: 80 characters", + sql: "SELECT id, name, email, address, phone, city, state, zip, country FROM users WHERE active = true", + maxLength: 80, + expectedViolations: 1, + }, + { + name: "Custom max length: 120 characters", + sql: "SELECT id, name, email, address, phone, city, state, zip, country FROM users WHERE active = true", + maxLength: 120, + expectedViolations: 0, + }, + { + name: "Zero max length (should use default 100)", + sql: "SELECT id FROM users", + maxLength: 0, + expectedViolations: 0, + }, + { + name: "Comment-only line vs. code with comment", + sql: "-- This is a standalone comment that is quite long and exceeds the maximum line length configured\nSELECT id, name FROM users WHERE active = true -- This inline comment makes this line exceed max length", + maxLength: 80, + expectedViolations: 1, + }, + { + name: "Multi-line comment block with long lines (only first line skipped)", + sql: "/* This is a multi-line comment block\n with continuation that is short\n that should be skipped */\nSELECT id FROM users", + maxLength: 100, + expectedViolations: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := NewLongLinesRule(tt.maxLength) + ctx := linter.NewContext(tt.sql, "test.sql") + + violations, err := rule.Check(ctx) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(violations) != tt.expectedViolations { + t.Errorf("Expected %d violations, got %d", tt.expectedViolations, len(violations)) + for i, v := range violations { + t.Logf("Violation %d: %s at line %d (length: %d)", i+1, v.Message, v.Location.Line, len(v.Line)) + } + } + + // Verify violation details + for _, v := range violations { + if v.Rule != "L005" { + t.Errorf("Expected rule ID 'L005', got '%s'", v.Rule) + } + if v.RuleName != "Long Lines" { + t.Errorf("Expected rule name 'Long Lines', got '%s'", v.RuleName) + } + if v.Severity != linter.SeverityInfo { + t.Errorf("Expected severity 'info', got '%s'", v.Severity) + } + if v.CanAutoFix { + t.Error("Expected CanAutoFix to be false") + } + if v.Message != "Line exceeds maximum length" { + t.Errorf("Expected message 'Line exceeds maximum length', got '%s'", v.Message) + } + } + }) + } +} + +func TestLongLinesRule_Check_LineNumbers(t *testing.T) { + tests := []struct { + name string + sql string + maxLength int + expectedViolationAt []int + }{ + { + name: "Violation on first line", + sql: strings.Repeat("A", 101) + "\nSELECT id FROM users", + maxLength: 100, + expectedViolationAt: []int{1}, + }, + { + name: "Violation on third line", + sql: "SELECT id\nFROM users\nWHERE active = true AND verified = true AND email IS NOT NULL AND created_at > '2023-01-01' AND deleted_at IS NULL", + maxLength: 100, + expectedViolationAt: []int{3}, + }, + { + name: "Multiple violations on different lines", + sql: strings.Repeat("A", 101) + "\nSELECT id\n" + strings.Repeat("B", 102) + "\nFROM users\n" + strings.Repeat("C", 103), + maxLength: 100, + expectedViolationAt: []int{1, 3, 5}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := NewLongLinesRule(tt.maxLength) + ctx := linter.NewContext(tt.sql, "test.sql") + + violations, err := rule.Check(ctx) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(violations) != len(tt.expectedViolationAt) { + t.Errorf("Expected %d violations, got %d", len(tt.expectedViolationAt), len(violations)) + } + + for i, expectedLine := range tt.expectedViolationAt { + if i >= len(violations) { + break + } + if violations[i].Location.Line != expectedLine { + t.Errorf("Violation %d: expected line %d, got %d", i, expectedLine, violations[i].Location.Line) + } + } + }) + } +} + +func TestLongLinesRule_Fix(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "Fix returns unchanged content (no auto-fix support)", + input: "SELECT id, name FROM users", + expected: "SELECT id, name FROM users", + }, + { + name: "Violations exist but content unchanged after fix", + input: strings.Repeat("A", 150) + "\nSELECT id FROM users", + expected: strings.Repeat("A", 150) + "\nSELECT id FROM users", + }, + { + name: "Empty string handled correctly", + input: "", + expected: "", + }, + { + name: "Large file unchanged", + input: strings.Repeat("SELECT id, name, email, address, phone, city, state, zip, country FROM users WHERE active = true\n", 100), + expected: strings.Repeat("SELECT id, name, email, address, phone, city, state, zip, country FROM users WHERE active = true\n", 100), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := NewLongLinesRule(100) + ctx := linter.NewContext(tt.input, "test.sql") + + violations, err := rule.Check(ctx) + if err != nil { + t.Fatalf("Unexpected error during check: %v", err) + } + + fixed, err := rule.Fix(tt.input, violations) + if err != nil { + t.Fatalf("Unexpected error during fix: %v", err) + } + + if fixed != tt.expected { + t.Errorf("Fix result mismatch:\nExpected: %q\nGot: %q", tt.expected, fixed) + } + }) + } +} + +func TestLongLinesRule_Metadata(t *testing.T) { + rule := NewLongLinesRule(100) + + if rule.ID() != "L005" { + t.Errorf("Expected ID 'L005', got '%s'", rule.ID()) + } + + if rule.Name() != "Long Lines" { + t.Errorf("Expected name 'Long Lines', got '%s'", rule.Name()) + } + + if rule.Severity() != linter.SeverityInfo { + t.Errorf("Expected severity 'info', got '%s'", rule.Severity()) + } + + if rule.CanAutoFix() { + t.Error("Expected CanAutoFix to be false") + } + + if rule.Description() == "" { + t.Error("Expected non-empty description") + } +} + +func TestLongLinesRule_MaxLength(t *testing.T) { + tests := []struct { + name string + maxLength int + sql string + expectedViolations int + }{ + { + name: "Max length 50", + maxLength: 50, + sql: "SELECT id, name, email, address, phone FROM users WHERE active = true", + expectedViolations: 1, + }, + { + name: "Max length 200", + maxLength: 200, + sql: "SELECT id, name, email, address, phone, city, state, zip, country, created_at, updated_at FROM users WHERE active = true", + expectedViolations: 0, + }, + { + name: "Max length 1 (extreme case)", + maxLength: 1, + sql: "SELECT id FROM users", + expectedViolations: 1, + }, + { + name: "Default max length 100", + maxLength: 100, + sql: strings.Repeat("A", 100), + expectedViolations: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := NewLongLinesRule(tt.maxLength) + ctx := linter.NewContext(tt.sql, "test.sql") + + violations, err := rule.Check(ctx) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(violations) != tt.expectedViolations { + t.Errorf("Expected %d violations, got %d", tt.expectedViolations, len(violations)) + } + + // Verify MaxLength is set correctly + expectedMaxLength := tt.maxLength + if tt.maxLength <= 0 { + expectedMaxLength = 100 // Default + } + if rule.MaxLength != expectedMaxLength { + t.Errorf("Expected MaxLength %d, got %d", expectedMaxLength, rule.MaxLength) + } + }) + } +} + +func TestLongLinesRule_EdgeCases(t *testing.T) { + tests := []struct { + name string + sql string + maxLength int + expectedViolations int + description string + }{ + { + name: "Empty string", + sql: "", + maxLength: 100, + expectedViolations: 0, + description: "Empty input should not cause errors", + }, + { + name: "Single newline", + sql: "\n", + maxLength: 100, + expectedViolations: 0, + description: "Single newline should be handled", + }, + { + name: "Line with only spaces (counts toward length)", + sql: strings.Repeat(" ", 101), + maxLength: 100, + expectedViolations: 1, + description: "Spaces-only line should count toward length", + }, + { + name: "Unicode characters in long line", + sql: "SELECT id FROM users WHERE name = '" + strings.Repeat("日本語", 30) + "'", + maxLength: 100, + expectedViolations: 1, + description: "Unicode characters should be counted correctly", + }, + { + name: "Tabs counting as single characters", + sql: strings.Repeat("\t", 101), + maxLength: 100, + expectedViolations: 1, + description: "Tabs should count as single characters", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := NewLongLinesRule(tt.maxLength) + ctx := linter.NewContext(tt.sql, "test.sql") + + violations, err := rule.Check(ctx) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(violations) != tt.expectedViolations { + t.Errorf("%s: Expected %d violations, got %d", tt.description, tt.expectedViolations, len(violations)) + } + }) + } +} + +func TestLongLinesRule_CommentDetection(t *testing.T) { + tests := []struct { + name string + sql string + maxLength int + expectedViolations int + description string + }{ + { + name: "Single-line comment with -- at start", + sql: "-- " + strings.Repeat("This is a comment ", 20), + maxLength: 100, + expectedViolations: 0, + description: "Comment lines starting with -- should be skipped", + }, + { + name: "Single-line comment with -- after whitespace", + sql: " -- " + strings.Repeat("This is a comment ", 20), + maxLength: 100, + expectedViolations: 0, + description: "Comment lines with -- after whitespace should be skipped", + }, + { + name: "Block comment with /* at start", + sql: "/* " + strings.Repeat("This is a block comment ", 20) + " */", + maxLength: 100, + expectedViolations: 0, + description: "Comment lines starting with /* should be skipped", + }, + { + name: "Block comment with /* after whitespace", + sql: " /* " + strings.Repeat("This is a block comment ", 20) + " */", + maxLength: 100, + expectedViolations: 0, + description: "Comment lines with /* after whitespace should be skipped", + }, + { + name: "SQL with inline -- comment (should trigger)", + sql: "SELECT id, name, email, address, phone, city FROM users WHERE active = true -- inline comment", + maxLength: 80, + expectedViolations: 1, + description: "SQL with inline comments should trigger violations", + }, + { + name: "SQL with inline /* comment (should trigger)", + sql: "SELECT id, name, email, address, phone, city FROM users WHERE active = true /* inline comment */", + maxLength: 80, + expectedViolations: 1, + description: "SQL with inline block comments should trigger violations", + }, + { + name: "Mixed content: comment line and long SQL line", + sql: "-- " + strings.Repeat("Long comment ", 20) + "\nSELECT id, name, email, address, phone, city, state, zip FROM users WHERE active = true", + maxLength: 80, + expectedViolations: 1, + description: "Comment lines should be skipped, but SQL lines should be checked", + }, + { + name: "Multi-line block comment (only first line is comment)", + sql: "/* Start of comment\n Continuation short\n End of comment */", + maxLength: 100, + expectedViolations: 0, + description: "Block comment starts are detected, continuation lines checked normally", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := NewLongLinesRule(tt.maxLength) + ctx := linter.NewContext(tt.sql, "test.sql") + + violations, err := rule.Check(ctx) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(violations) != tt.expectedViolations { + t.Errorf("%s: Expected %d violations, got %d", tt.description, tt.expectedViolations, len(violations)) + for i, v := range violations { + t.Logf("Violation %d: %s at line %d", i+1, v.Message, v.Location.Line) + } + } + }) + } +} + +func TestLongLinesRule_ViolationDetails(t *testing.T) { + sql := strings.Repeat("A", 150) + rule := NewLongLinesRule(100) + ctx := linter.NewContext(sql, "test.sql") + + violations, err := rule.Check(ctx) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(violations) != 1 { + t.Fatalf("Expected 1 violation, got %d", len(violations)) + } + + v := violations[0] + + // Verify violation properties + if v.Rule != "L005" { + t.Errorf("Expected Rule 'L005', got '%s'", v.Rule) + } + + if v.RuleName != "Long Lines" { + t.Errorf("Expected RuleName 'Long Lines', got '%s'", v.RuleName) + } + + if v.Severity != linter.SeverityInfo { + t.Errorf("Expected Severity 'info', got '%s'", v.Severity) + } + + if v.Message != "Line exceeds maximum length" { + t.Errorf("Expected Message 'Line exceeds maximum length', got '%s'", v.Message) + } + + if v.Location.Line != 1 { + t.Errorf("Expected Location.Line 1, got %d", v.Location.Line) + } + + if v.Location.Column != 101 { + t.Errorf("Expected Location.Column 101 (maxLength+1), got %d", v.Location.Column) + } + + if v.Line != sql { + t.Errorf("Expected Line to contain full line content") + } + + expectedSuggestion := "Split this line into multiple lines (current: 150 chars, max: 100)" + if v.Suggestion != expectedSuggestion { + t.Errorf("Expected Suggestion '%s', got '%s'", expectedSuggestion, v.Suggestion) + } + + if v.CanAutoFix { + t.Error("Expected CanAutoFix to be false") + } +} diff --git a/pkg/linter/rules/whitespace/mixed_indentation_test.go b/pkg/linter/rules/whitespace/mixed_indentation_test.go new file mode 100644 index 00000000..c641a651 --- /dev/null +++ b/pkg/linter/rules/whitespace/mixed_indentation_test.go @@ -0,0 +1,523 @@ +package whitespace + +import ( + "testing" + + "github.com/ajitpratap0/GoSQLX/pkg/linter" +) + +func TestMixedIndentationRule_Check(t *testing.T) { + tests := []struct { + name string + sql string + expectedViolations int + }{ + { + name: "No indentation - all lines start with non-whitespace", + sql: "SELECT id, name\nFROM users\nWHERE active = true", + expectedViolations: 0, + }, + { + name: "Consistent spaces throughout file (multiple lines)", + sql: "SELECT id, name\n FROM users\n WHERE active = true\n AND verified = true", + expectedViolations: 0, + }, + { + name: "Consistent tabs throughout file (multiple lines)", + sql: "SELECT id, name\n\tFROM users\n\tWHERE active = true\n\t\tAND verified = true", + expectedViolations: 0, + }, + { + name: "Single line with tabs and spaces mixed in leading whitespace", + sql: "\t SELECT id FROM users", + expectedViolations: 1, + }, + { + name: "Multiple lines: some with tab indent, some with space indent", + sql: "SELECT id\n\tFROM users\n WHERE active = true", + expectedViolations: 1, + }, + { + name: "Empty lines should be ignored", + sql: "SELECT id\n\n FROM users\n\n WHERE active = true", + expectedViolations: 0, + }, + { + name: "Whitespace-only lines", + sql: "SELECT id\n \n FROM users", + expectedViolations: 0, + }, + { + name: "First line sets space indent, later line uses tabs", + sql: " SELECT id\n\tFROM users", + expectedViolations: 1, + }, + { + name: "First line sets tab indent, later line uses spaces", + sql: "\tSELECT id\n FROM users", + expectedViolations: 1, + }, + { + name: "Complex: nested indentation all spaces (no violations)", + sql: "SELECT\n id,\n name,\n (\n SELECT COUNT(*)\n FROM orders\n WHERE user_id = users.id\n ) AS order_count\nFROM users", + expectedViolations: 0, + }, + { + name: "Complex: nested indentation all tabs (no violations)", + sql: "SELECT\n\tid,\n\tname,\n\t(\n\t\tSELECT COUNT(*)\n\t\tFROM orders\n\t\tWHERE user_id = users.id\n\t) AS order_count\nFROM users", + expectedViolations: 0, + }, + { + name: "Line with no leading whitespace (should be ignored)", + sql: "SELECT id FROM users", + expectedViolations: 0, + }, + { + name: "Single tab character indentation", + sql: "\tSELECT id FROM users", + expectedViolations: 0, + }, + { + name: "Single space character indentation", + sql: " SELECT id FROM users", + expectedViolations: 0, + }, + { + name: "Multiple spaces (4 spaces) indentation", + sql: " SELECT id FROM users", + expectedViolations: 0, + }, + { + name: "Tab followed by content", + sql: "\tSELECT id\n\tFROM users\n\tWHERE active = true", + expectedViolations: 0, + }, + { + name: "Spaces followed by content", + sql: " SELECT id\n FROM users\n WHERE active = true", + expectedViolations: 0, + }, + { + name: "Mixed on same line + inconsistent across file (multiple violations)", + sql: "\t SELECT id\n FROM users\n\tWHERE active = true", + expectedViolations: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := NewMixedIndentationRule() + ctx := linter.NewContext(tt.sql, "test.sql") + + violations, err := rule.Check(ctx) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(violations) != tt.expectedViolations { + t.Errorf("Expected %d violations, got %d", tt.expectedViolations, len(violations)) + for i, v := range violations { + t.Logf("Violation %d: %s at line %d", i+1, v.Message, v.Location.Line) + } + } + + // Verify violation details + for _, v := range violations { + if v.Rule != "L002" { + t.Errorf("Expected rule ID 'L002', got '%s'", v.Rule) + } + if v.RuleName != "Mixed Indentation" { + t.Errorf("Expected rule name 'Mixed Indentation', got '%s'", v.RuleName) + } + if v.Severity != linter.SeverityError { + t.Errorf("Expected severity 'error', got '%s'", v.Severity) + } + if !v.CanAutoFix { + t.Error("Expected CanAutoFix to be true") + } + } + }) + } +} + +func TestMixedIndentationRule_Check_ViolationMessages(t *testing.T) { + tests := []struct { + name string + sql string + expectedMessage string + }{ + { + name: "Mixed tabs and spaces on same line", + sql: "\t SELECT id FROM users", + expectedMessage: "Line mixes tabs and spaces for indentation", + }, + { + name: "Inconsistent indentation across file", + sql: " SELECT id\n\tFROM users", + expectedMessage: "Inconsistent indentation: file uses both tabs and spaces", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := NewMixedIndentationRule() + ctx := linter.NewContext(tt.sql, "test.sql") + + violations, err := rule.Check(ctx) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(violations) == 0 { + t.Fatal("Expected at least one violation") + } + + if violations[0].Message != tt.expectedMessage { + t.Errorf("Expected message '%s', got '%s'", tt.expectedMessage, violations[0].Message) + } + }) + } +} + +func TestMixedIndentationRule_Check_LineNumbers(t *testing.T) { + tests := []struct { + name string + sql string + expectedViolationAt []int + }{ + { + name: "Violation on first line", + sql: "\t SELECT id FROM users", + expectedViolationAt: []int{1}, + }, + { + name: "Violation on second line", + sql: " SELECT id\n\tFROM users", + expectedViolationAt: []int{2}, + }, + { + name: "Violations on multiple lines", + sql: "\t SELECT id\n FROM users\n\tWHERE active = true", + expectedViolationAt: []int{1, 3}, + }, + { + name: "Violation on third line only", + sql: "SELECT id\n FROM users\n\tWHERE active = true", + expectedViolationAt: []int{3}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := NewMixedIndentationRule() + ctx := linter.NewContext(tt.sql, "test.sql") + + violations, err := rule.Check(ctx) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(violations) != len(tt.expectedViolationAt) { + t.Errorf("Expected %d violations, got %d", len(tt.expectedViolationAt), len(violations)) + } + + for i, expectedLine := range tt.expectedViolationAt { + if i >= len(violations) { + break + } + if violations[i].Location.Line != expectedLine { + t.Errorf("Violation %d: expected line %d, got %d", i, expectedLine, violations[i].Location.Line) + } + } + }) + } +} + +func TestMixedIndentationRule_Fix(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "Convert all tabs to spaces (tabs at start of lines)", + input: "\tSELECT id\n\tFROM users\n\t\tWHERE active = true", + expected: " SELECT id\n FROM users\n WHERE active = true", + }, + { + name: "Preserve already consistent spacing", + input: " SELECT id\n FROM users\n WHERE active = true", + expected: " SELECT id\n FROM users\n WHERE active = true", + }, + { + name: "Handle nested/multiple indent levels (tabs → spaces)", + input: "SELECT\n\tid,\n\tname,\n\t(\n\t\tSELECT COUNT(*)\n\t\tFROM orders\n\t) AS count\nFROM users", + expected: "SELECT\n id,\n name,\n (\n SELECT COUNT(*)\n FROM orders\n ) AS count\nFROM users", + }, + { + name: "Preserve non-leading tabs (tabs in content should not be converted)", + input: "\tSELECT\t'value'\tFROM users", + expected: " SELECT\t'value'\tFROM users", + }, + { + name: "Mixed indentation file conversion", + input: "\tSELECT id\n FROM users\n\t\tWHERE active = true", + expected: " SELECT id\n FROM users\n WHERE active = true", + }, + { + name: "Empty file handling", + input: "", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := NewMixedIndentationRule() + ctx := linter.NewContext(tt.input, "test.sql") + + violations, err := rule.Check(ctx) + if err != nil { + t.Fatalf("Unexpected error during check: %v", err) + } + + fixed, err := rule.Fix(tt.input, violations) + if err != nil { + t.Fatalf("Unexpected error during fix: %v", err) + } + + if fixed != tt.expected { + t.Errorf("Fix result mismatch:\nExpected: %q\nGot: %q", tt.expected, fixed) + } + }) + } +} + +func TestMixedIndentationRule_Fix_PreservesContent(t *testing.T) { + tests := []struct { + name string + input string + }{ + { + name: "Preserve string literals with tabs", + input: "\tSELECT 'data\twith\ttabs' FROM users", + }, + { + name: "Preserve comments", + input: "\t-- This is a comment\n\tSELECT id FROM users", + }, + { + name: "Preserve empty lines", + input: "\tSELECT id\n\n\tFROM users", + }, + { + name: "Preserve line endings", + input: "\tSELECT id\n\tFROM users\n\tWHERE active = true", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := NewMixedIndentationRule() + ctx := linter.NewContext(tt.input, "test.sql") + + violations, err := rule.Check(ctx) + if err != nil { + t.Fatalf("Unexpected error during check: %v", err) + } + + fixed, err := rule.Fix(tt.input, violations) + if err != nil { + t.Fatalf("Unexpected error during fix: %v", err) + } + + // Verify that fixing removes tabs from indentation + // Count tabs in original vs fixed + originalLeadingTabs := 0 + fixedLeadingTabs := 0 + + for _, line := range ctx.Lines { + for _, char := range line { + if char == '\t' { + originalLeadingTabs++ + } else if char != ' ' { + break + } + } + } + + fixedCtx := linter.NewContext(fixed, "test.sql") + for _, line := range fixedCtx.Lines { + for _, char := range line { + if char == '\t' { + fixedLeadingTabs++ + } else if char != ' ' { + break + } + } + } + + if originalLeadingTabs > 0 && fixedLeadingTabs >= originalLeadingTabs { + t.Errorf("Expected leading tabs to be reduced, original: %d, fixed: %d", originalLeadingTabs, fixedLeadingTabs) + } + }) + } +} + +func TestMixedIndentationRule_Fix_Idempotency(t *testing.T) { + // Applying fix multiple times should yield the same result + tests := []struct { + name string + input string + }{ + { + name: "Single tab indentation", + input: "\tSELECT id FROM users", + }, + { + name: "Multiple tab levels", + input: "\tSELECT id\n\t\tFROM users", + }, + { + name: "Mixed indentation", + input: "\t SELECT id\n FROM users", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := NewMixedIndentationRule() + ctx := linter.NewContext(tt.input, "test.sql") + + violations, err := rule.Check(ctx) + if err != nil { + t.Fatalf("Unexpected error during check: %v", err) + } + + fixed1, err := rule.Fix(tt.input, violations) + if err != nil { + t.Fatalf("Unexpected error during first fix: %v", err) + } + + // Apply fix again + ctx2 := linter.NewContext(fixed1, "test.sql") + violations2, err := rule.Check(ctx2) + if err != nil { + t.Fatalf("Unexpected error during second check: %v", err) + } + + fixed2, err := rule.Fix(fixed1, violations2) + if err != nil { + t.Fatalf("Unexpected error during second fix: %v", err) + } + + if fixed1 != fixed2 { + t.Errorf("Fix is not idempotent:\nFirst fix: %q\nSecond fix: %q", fixed1, fixed2) + } + }) + } +} + +func TestMixedIndentationRule_Metadata(t *testing.T) { + rule := NewMixedIndentationRule() + + if rule.ID() != "L002" { + t.Errorf("Expected ID 'L002', got '%s'", rule.ID()) + } + + if rule.Name() != "Mixed Indentation" { + t.Errorf("Expected name 'Mixed Indentation', got '%s'", rule.Name()) + } + + if rule.Severity() != linter.SeverityError { + t.Errorf("Expected severity 'error', got '%s'", rule.Severity()) + } + + if !rule.CanAutoFix() { + t.Error("Expected CanAutoFix to be true") + } + + if rule.Description() == "" { + t.Error("Expected non-empty description") + } +} + +func TestMixedIndentationRule_EdgeCases(t *testing.T) { + tests := []struct { + name string + sql string + expectedViolations int + description string + }{ + { + name: "Only whitespace line with tabs", + sql: "\t\t\t", + expectedViolations: 0, + description: "Whitespace-only lines should be ignored", + }, + { + name: "Only whitespace line with spaces", + sql: " ", + expectedViolations: 0, + description: "Whitespace-only lines should be ignored", + }, + { + name: "Empty string", + sql: "", + expectedViolations: 0, + description: "Empty input should not cause errors", + }, + { + name: "Single newline", + sql: "\n", + expectedViolations: 0, + description: "Single newline should be handled", + }, + { + name: "Multiple empty lines", + sql: "\n\n\n", + expectedViolations: 0, + description: "Multiple empty lines should be ignored", + }, + { + name: "Tab at end of line (not leading)", + sql: "SELECT id\tFROM users", + expectedViolations: 0, + description: "Tabs in content (not leading) should not trigger violations", + }, + { + name: "Space at end of line (not leading)", + sql: "SELECT id FROM users", + expectedViolations: 0, + description: "Spaces in content should not trigger violations", + }, + { + name: "Very deep nesting with tabs", + sql: "\t\t\t\t\t\tSELECT id FROM users", + expectedViolations: 0, + description: "Deep nesting with consistent tabs is valid", + }, + { + name: "Very deep nesting with spaces", + sql: " SELECT id FROM users", + expectedViolations: 0, + description: "Deep nesting with consistent spaces is valid", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := NewMixedIndentationRule() + ctx := linter.NewContext(tt.sql, "test.sql") + + violations, err := rule.Check(ctx) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(violations) != tt.expectedViolations { + t.Errorf("%s: Expected %d violations, got %d", tt.description, tt.expectedViolations, len(violations)) + } + }) + } +}