|
4 | 4 | package cmd |
5 | 5 |
|
6 | 6 | import ( |
| 7 | + "encoding/base64" |
7 | 8 | "fmt" |
8 | 9 | "io" |
| 10 | + "mime" |
9 | 11 | "os" |
10 | | - "strings" |
| 12 | + "path/filepath" |
11 | 13 |
|
12 | 14 | "github.com/spf13/cobra" |
13 | | - "github.com/wavetermdev/waveterm/pkg/waveobj" |
14 | 15 | "github.com/wavetermdev/waveterm/pkg/wshrpc" |
15 | 16 | "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" |
16 | 17 | "github.com/wavetermdev/waveterm/pkg/wshutil" |
17 | 18 | ) |
18 | 19 |
|
19 | 20 | var aiCmd = &cobra.Command{ |
20 | | - Use: "ai [-] [message...]", |
21 | | - Short: "Send a message to an AI block", |
| 21 | + Use: "ai [options] [files...]", |
| 22 | + Short: "Append content to Wave AI sidebar prompt", |
| 23 | + Long: `Append content to Wave AI sidebar prompt (does not auto-submit by default) |
| 24 | +
|
| 25 | +Arguments: |
| 26 | + files... Files to attach (use '-' for stdin) |
| 27 | +
|
| 28 | +Examples: |
| 29 | + git diff | wsh ai - # Pipe diff to AI, ask question in UI |
| 30 | + wsh ai main.go # Attach file, ask question in UI |
| 31 | + wsh ai *.go -m "find bugs" # Attach files with message |
| 32 | + wsh ai -s - -m "review" < log.txt # Stdin + message, auto-submit |
| 33 | + wsh ai -n config.json # New chat with file attached`, |
22 | 34 | RunE: aiRun, |
23 | 35 | PreRunE: preRunSetupRpcClient, |
24 | 36 | DisableFlagsInUseLine: true, |
25 | 37 | } |
26 | 38 |
|
27 | | -var aiFileFlags []string |
| 39 | +var aiMessageFlag string |
| 40 | +var aiSubmitFlag bool |
28 | 41 | var aiNewBlockFlag bool |
29 | 42 |
|
30 | 43 | func init() { |
31 | 44 | rootCmd.AddCommand(aiCmd) |
32 | | - aiCmd.Flags().BoolVarP(&aiNewBlockFlag, "new", "n", false, "create a new AI block") |
33 | | - aiCmd.Flags().StringArrayVarP(&aiFileFlags, "file", "f", nil, "attach file content (use '-' for stdin)") |
| 45 | + aiCmd.Flags().StringVarP(&aiMessageFlag, "message", "m", "", "optional message/question to append after files") |
| 46 | + aiCmd.Flags().BoolVarP(&aiSubmitFlag, "submit", "s", false, "submit the prompt immediately after appending") |
| 47 | + aiCmd.Flags().BoolVarP(&aiNewBlockFlag, "new", "n", false, "create a new AI chat instead of using existing") |
34 | 48 | } |
35 | 49 |
|
36 | | -func encodeFile(builder *strings.Builder, file io.Reader, fileName string) error { |
37 | | - data, err := io.ReadAll(file) |
38 | | - if err != nil { |
39 | | - return fmt.Errorf("error reading file: %w", err) |
| 50 | +func getMimeType(filename string) string { |
| 51 | + ext := filepath.Ext(filename) |
| 52 | + if ext == "" { |
| 53 | + return "text/plain" |
40 | 54 | } |
41 | | - // Start delimiter with the file name |
42 | | - builder.WriteString(fmt.Sprintf("\n@@@start file %q\n", fileName)) |
43 | | - // Read the file content and write it to the builder |
44 | | - builder.Write(data) |
45 | | - // End delimiter with the file name |
46 | | - builder.WriteString(fmt.Sprintf("\n@@@end file %q\n\n", fileName)) |
47 | | - return nil |
| 55 | + mimeType := mime.TypeByExtension(ext) |
| 56 | + if mimeType == "" { |
| 57 | + return "text/plain" |
| 58 | + } |
| 59 | + return mimeType |
| 60 | +} |
| 61 | + |
| 62 | +func getMaxFileSize(mimeType string) (int, string) { |
| 63 | + if mimeType == "application/pdf" { |
| 64 | + return 5 * 1024 * 1024, "5MB" |
| 65 | + } |
| 66 | + if mimeType[:6] == "image/" { |
| 67 | + return 7 * 1024 * 1024, "7MB" |
| 68 | + } |
| 69 | + return 200 * 1024, "200KB" |
48 | 70 | } |
49 | 71 |
|
50 | 72 | func aiRun(cmd *cobra.Command, args []string) (rtnErr error) { |
51 | 73 | defer func() { |
52 | 74 | sendActivity("ai", rtnErr == nil) |
53 | 75 | }() |
54 | 76 |
|
55 | | - if len(args) == 0 { |
| 77 | + if len(args) == 0 && aiMessageFlag == "" { |
56 | 78 | OutputHelpMessage(cmd) |
57 | | - return fmt.Errorf("no message provided") |
| 79 | + return fmt.Errorf("no files or message provided") |
58 | 80 | } |
59 | 81 |
|
| 82 | + const maxBatchSize = 7 * 1024 * 1024 |
| 83 | + const largeFileThreshold = 1 * 1024 * 1024 |
| 84 | + const maxFileCount = 15 |
| 85 | + const rpcTimeout = 30000 |
| 86 | + |
| 87 | + var allFiles []wshrpc.AIAttachedFile |
60 | 88 | var stdinUsed bool |
61 | | - var message strings.Builder |
62 | 89 |
|
63 | | - // Handle file attachments first |
64 | | - for _, file := range aiFileFlags { |
65 | | - if file == "-" { |
| 90 | + if len(args) > maxFileCount { |
| 91 | + return fmt.Errorf("too many files (maximum %d files allowed)", maxFileCount) |
| 92 | + } |
| 93 | + |
| 94 | + for _, filePath := range args { |
| 95 | + var data []byte |
| 96 | + var fileName string |
| 97 | + var mimeType string |
| 98 | + var err error |
| 99 | + |
| 100 | + if filePath == "-" { |
66 | 101 | if stdinUsed { |
67 | 102 | return fmt.Errorf("stdin (-) can only be used once") |
68 | 103 | } |
69 | 104 | stdinUsed = true |
70 | | - if err := encodeFile(&message, os.Stdin, "<stdin>"); err != nil { |
| 105 | + |
| 106 | + data, err = io.ReadAll(os.Stdin) |
| 107 | + if err != nil { |
71 | 108 | return fmt.Errorf("reading from stdin: %w", err) |
72 | 109 | } |
| 110 | + fileName = "stdin" |
| 111 | + mimeType = "text/plain" |
73 | 112 | } else { |
74 | | - fd, err := os.Open(file) |
| 113 | + fileInfo, err := os.Stat(filePath) |
75 | 114 | if err != nil { |
76 | | - return fmt.Errorf("opening file %s: %w", file, err) |
| 115 | + return fmt.Errorf("accessing file %s: %w", filePath, err) |
77 | 116 | } |
78 | | - defer fd.Close() |
79 | | - if err := encodeFile(&message, fd, file); err != nil { |
80 | | - return fmt.Errorf("reading file %s: %w", file, err) |
| 117 | + if fileInfo.IsDir() { |
| 118 | + return fmt.Errorf("%s is a directory, not a file", filePath) |
81 | 119 | } |
| 120 | + |
| 121 | + data, err = os.ReadFile(filePath) |
| 122 | + if err != nil { |
| 123 | + return fmt.Errorf("reading file %s: %w", filePath, err) |
| 124 | + } |
| 125 | + fileName = filepath.Base(filePath) |
| 126 | + mimeType = getMimeType(filePath) |
82 | 127 | } |
83 | | - } |
84 | 128 |
|
85 | | - // Default to "waveai" block |
86 | | - isDefaultBlock := blockArg == "" |
87 | | - if isDefaultBlock { |
88 | | - blockArg = "view@waveai" |
| 129 | + maxSize, sizeStr := getMaxFileSize(mimeType) |
| 130 | + if len(data) > maxSize { |
| 131 | + return fmt.Errorf("file %s exceeds maximum size of %s for %s files", fileName, sizeStr, mimeType) |
| 132 | + } |
| 133 | + |
| 134 | + allFiles = append(allFiles, wshrpc.AIAttachedFile{ |
| 135 | + Name: fileName, |
| 136 | + Type: mimeType, |
| 137 | + Size: len(data), |
| 138 | + Data64: base64.StdEncoding.EncodeToString(data), |
| 139 | + }) |
89 | 140 | } |
90 | | - var fullORef *waveobj.ORef |
91 | | - var err error |
92 | | - if !aiNewBlockFlag { |
93 | | - fullORef, err = resolveSimpleId(blockArg) |
| 141 | + |
| 142 | + tabId := os.Getenv("WAVETERM_TABID") |
| 143 | + if tabId == "" { |
| 144 | + return fmt.Errorf("WAVETERM_TABID environment variable not set") |
94 | 145 | } |
95 | | - if (err != nil && isDefaultBlock) || aiNewBlockFlag { |
96 | | - // Create new AI block if default block doesn't exist |
97 | | - data := &wshrpc.CommandCreateBlockData{ |
98 | | - BlockDef: &waveobj.BlockDef{ |
99 | | - Meta: map[string]interface{}{ |
100 | | - waveobj.MetaKey_View: "waveai", |
101 | | - }, |
102 | | - }, |
103 | | - Focused: true, |
104 | | - } |
105 | 146 |
|
106 | | - newORef, err := wshclient.CreateBlockCommand(RpcClient, *data, &wshrpc.RpcOpts{Timeout: 2000}) |
107 | | - if err != nil { |
108 | | - return fmt.Errorf("creating AI block: %w", err) |
| 147 | + route := wshutil.MakeTabRouteId(tabId) |
| 148 | + |
| 149 | + if aiNewBlockFlag { |
| 150 | + newChatData := wshrpc.CommandWaveAIAddContextData{ |
| 151 | + NewChat: true, |
109 | 152 | } |
110 | | - fullORef = &newORef |
111 | | - // Wait for the block's route to be available |
112 | | - gotRoute, err := wshclient.WaitForRouteCommand(RpcClient, wshrpc.CommandWaitForRouteData{ |
113 | | - RouteId: wshutil.MakeFeBlockRouteId(fullORef.OID), |
114 | | - WaitMs: 4000, |
115 | | - }, &wshrpc.RpcOpts{Timeout: 5000}) |
| 153 | + err := wshclient.WaveAIAddContextCommand(RpcClient, newChatData, &wshrpc.RpcOpts{ |
| 154 | + Route: route, |
| 155 | + Timeout: rpcTimeout, |
| 156 | + }) |
116 | 157 | if err != nil { |
117 | | - return fmt.Errorf("waiting for AI block: %w", err) |
| 158 | + return fmt.Errorf("creating new chat: %w", err) |
118 | 159 | } |
119 | | - if !gotRoute { |
120 | | - return fmt.Errorf("AI block route could not be established") |
121 | | - } |
122 | | - } else if err != nil { |
123 | | - return fmt.Errorf("resolving block: %w", err) |
124 | 160 | } |
125 | 161 |
|
126 | | - // Create the route for this block |
127 | | - route := wshutil.MakeFeBlockRouteId(fullORef.OID) |
| 162 | + var smallFiles []wshrpc.AIAttachedFile |
| 163 | + var smallFilesSize int |
128 | 164 |
|
129 | | - // Then handle main message |
130 | | - if args[0] == "-" { |
131 | | - if stdinUsed { |
132 | | - return fmt.Errorf("stdin (-) can only be used once") |
133 | | - } |
134 | | - data, err := io.ReadAll(os.Stdin) |
135 | | - if err != nil { |
136 | | - return fmt.Errorf("reading from stdin: %w", err) |
137 | | - } |
138 | | - message.Write(data) |
139 | | - |
140 | | - // Also include any remaining arguments (excluding the "-" itself) |
141 | | - if len(args) > 1 { |
142 | | - if message.Len() > 0 { |
143 | | - message.WriteString(" ") |
| 165 | + for _, file := range allFiles { |
| 166 | + if file.Size > largeFileThreshold { |
| 167 | + contextData := wshrpc.CommandWaveAIAddContextData{ |
| 168 | + Files: []wshrpc.AIAttachedFile{file}, |
| 169 | + } |
| 170 | + err := wshclient.WaveAIAddContextCommand(RpcClient, contextData, &wshrpc.RpcOpts{ |
| 171 | + Route: route, |
| 172 | + Timeout: rpcTimeout, |
| 173 | + }) |
| 174 | + if err != nil { |
| 175 | + return fmt.Errorf("adding file %s: %w", file.Name, err) |
144 | 176 | } |
145 | | - message.WriteString(strings.Join(args[1:], " ")) |
| 177 | + } else { |
| 178 | + smallFilesSize += file.Size |
| 179 | + if smallFilesSize > maxBatchSize { |
| 180 | + return fmt.Errorf("small files total size exceeds maximum batch size of 7MB") |
| 181 | + } |
| 182 | + smallFiles = append(smallFiles, file) |
146 | 183 | } |
147 | | - } else { |
148 | | - message.WriteString(strings.Join(args, " ")) |
149 | 184 | } |
150 | 185 |
|
151 | | - if message.Len() == 0 { |
152 | | - return fmt.Errorf("message is empty") |
153 | | - } |
154 | | - if message.Len() > 50*1024 { |
155 | | - return fmt.Errorf("current max message size is 50k") |
| 186 | + finalContextData := wshrpc.CommandWaveAIAddContextData{ |
| 187 | + Files: smallFiles, |
| 188 | + Text: aiMessageFlag, |
| 189 | + Submit: aiSubmitFlag, |
156 | 190 | } |
157 | 191 |
|
158 | | - messageData := wshrpc.AiMessageData{ |
159 | | - Message: message.String(), |
160 | | - } |
161 | | - err = wshclient.AiSendMessageCommand(RpcClient, messageData, &wshrpc.RpcOpts{ |
| 192 | + err := wshclient.WaveAIAddContextCommand(RpcClient, finalContextData, &wshrpc.RpcOpts{ |
162 | 193 | Route: route, |
163 | | - Timeout: 2000, |
| 194 | + Timeout: rpcTimeout, |
164 | 195 | }) |
165 | 196 | if err != nil { |
166 | | - return fmt.Errorf("sending message: %w", err) |
| 197 | + return fmt.Errorf("adding context: %w", err) |
167 | 198 | } |
168 | 199 |
|
169 | 200 | return nil |
|
0 commit comments