diff --git a/cmd/cli/commands/run.go b/cmd/cli/commands/run.go index b585c9a82..7ec76d887 100644 --- a/cmd/cli/commands/run.go +++ b/cmd/cli/commands/run.go @@ -132,6 +132,7 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop. fmt.Fprintln(os.Stderr, " Ctrl + w Delete the word before the cursor") fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, " Ctrl + l Clear the screen") + fmt.Fprintln(os.Stderr, " Ctrl + x Open prompt in your default text editor") fmt.Fprintln(os.Stderr, " Ctrl + c Stop the model from responding") fmt.Fprintln(os.Stderr, " Ctrl + d Exit (/bye)") fmt.Fprintln(os.Stderr, "") diff --git a/cmd/cli/readline/editor.go b/cmd/cli/readline/editor.go new file mode 100644 index 000000000..ee20340f7 --- /dev/null +++ b/cmd/cli/readline/editor.go @@ -0,0 +1,117 @@ +package readline + +import ( + "errors" + "fmt" + "os" + "os/exec" + "runtime" + "strings" +) + +const ( + defaultEditor = "vi" + defaultShell = "/bin/sh" + windowsEditor = "notepad" + windowsShell = "cmd" +) + +func openInEditor(fd uintptr, termios any, content string) (string, error) { + if err := UnsetRawMode(fd, termios); err != nil { + return content, err + } + + edited, err := runEditor(content) + + if _, restoreErr := SetRawMode(fd); restoreErr != nil { + return content, errors.Join(err, restoreErr) + } + + if err != nil { + return content, err + } + + return edited, nil +} + +func platformize(linux, windows string) string { + if runtime.GOOS == "windows" { + return windows + } + return linux +} + +func defaultEnvShell() []string { + shell := os.Getenv("SHELL") + if shell == "" { + shell = platformize(defaultShell, windowsShell) + } + flag := "-c" + if shell == windowsShell { + flag = "/C" + } + return []string{shell, flag} +} + +func resolveEditor() ([]string, bool) { + editor := strings.TrimSpace(os.Getenv("EDITOR")) + if editor == "" { + editor = platformize(defaultEditor, windowsEditor) + } + + if !strings.Contains(editor, " ") { + return []string{editor}, false + } + + if !strings.ContainsAny(editor, "\"'\\") { + return strings.Split(editor, " "), false + } + + shell := defaultEnvShell() + return append(shell, editor), true +} + +func buildEditorCmd(filePath string) *exec.Cmd { + args, shell := resolveEditor() + + if shell { + // The editor string is the last element — append the file path to it + safeFilePath := strings.ReplaceAll(filePath, "'", "'\\''") + args[len(args)-1] = fmt.Sprintf("%s '%s'", args[len(args)-1], safeFilePath) + } else { + args = append(args, filePath) + } + + //nolint:gosec // $EDITOR is a user-controlled local env var, same trust model as git/kubectl + cmd := exec.Command(args[0], args[1:]...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + return cmd +} + +func runEditor(content string) (string, error) { + tmpFile, err := os.CreateTemp("", "docker-model-prompt-*.txt") + if err != nil { + return content, err + } + defer os.Remove(tmpFile.Name()) + + if _, err := tmpFile.WriteString(content); err != nil { + tmpFile.Close() + return content, err + } + tmpFile.Close() + + cmd := buildEditorCmd(tmpFile.Name()) + if err := cmd.Run(); err != nil { + return content, err + } + + edited, err := os.ReadFile(tmpFile.Name()) + if err != nil { + return content, err + } + + return strings.TrimRight(string(edited), "\r\n"), nil +} diff --git a/cmd/cli/readline/editor_test.go b/cmd/cli/readline/editor_test.go new file mode 100644 index 000000000..64489d813 --- /dev/null +++ b/cmd/cli/readline/editor_test.go @@ -0,0 +1,96 @@ +//go:build !windows + +package readline + +import ( + "os" + "path/filepath" + "testing" +) + +func createMockEditor(t *testing.T, scriptBody string) string { + t.Helper() + editorScript := filepath.Join(t.TempDir(), "mock-editor.sh") + if err := os.WriteFile(editorScript, []byte("#!/bin/sh\n"+scriptBody+"\n"), 0o755); err != nil { + t.Fatalf("failed to create mock editor: %v", err) + } + t.Setenv("EDITOR", editorScript) + return editorScript +} + +func TestRunEditor(t *testing.T) { + tests := []struct { + name string + mockEditorScript string + input string + expected string + }{ + { + name: "modifies content", + mockEditorScript: `printf " edited" >> "$1"`, + input: "hello docker model prompt", + expected: "hello docker model prompt edited", + }, + { + name: "empty content", + mockEditorScript: `printf "new content" > "$1"`, + input: "", + expected: "new content", + }, + { + name: "strips trailing newline", + mockEditorScript: `printf "edited\n" > "$1"`, + input: "", + expected: "edited", + }, + { + name: "strips trailing carriage return and newline", + mockEditorScript: `printf "edited\r\n" > "$1"`, + input: "", + expected: "edited", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + createMockEditor(t, tt.mockEditorScript) + + result, err := runEditor(tt.input) + if err != nil { + t.Fatalf("runEditor failed: %v", err) + } + + if result != tt.expected { + t.Errorf("expected %q, got %q", tt.expected, result) + } + }) + } +} + +func TestRunEditorReturnsOriginalContentOnFailure(t *testing.T) { + t.Setenv("EDITOR", "non_exists_editor") + + content := "docker model prompt hello" + result, err := runEditor(content) + if err == nil { + t.Fatal("expected error from nonexistent editor") + } + + if result != content { + t.Errorf("expected original content on failure, got %q", result) + } +} + +func TestRunEditorWithEditorArgs(t *testing.T) { + editorScript := createMockEditor(t, `printf "edited with args" > "$2"`) + t.Setenv("EDITOR", editorScript+" --wait") + + result, err := runEditor("original") + if err != nil { + t.Fatalf("runEditor failed: %v", err) + } + + if result != "edited with args" { + t.Errorf("expected %q, got %q", "edited with args", result) + } +} diff --git a/cmd/cli/readline/readline.go b/cmd/cli/readline/readline.go index d8121362e..66f315691 100644 --- a/cmd/cli/readline/readline.go +++ b/cmd/cli/readline/readline.go @@ -209,6 +209,14 @@ func (i *Instance) Readline() (string, error) { buf.ClearScreen() case CharCtrlW: buf.DeleteWord() + case CharCtrlX: + fd := os.Stdin.Fd() + edited, err := openInEditor(fd, i.Terminal.termios, buf.String()) + if err != nil { + fmt.Fprintf(os.Stderr, "error opening editor: %s\n", err) + break + } + buf.Replace([]rune(edited)) case CharCtrlZ: fd := os.Stdin.Fd() return handleCharCtrlZ(fd, i.Terminal.termios) diff --git a/cmd/cli/readline/types.go b/cmd/cli/readline/types.go index f4efa8d92..a18b71d32 100644 --- a/cmd/cli/readline/types.go +++ b/cmd/cli/readline/types.go @@ -24,6 +24,7 @@ const ( CharTranspose = 20 CharCtrlU = 21 CharCtrlW = 23 + CharCtrlX = 24 CharCtrlY = 25 CharCtrlZ = 26 CharEsc = 27