Skip to content

Commit 27ac131

Browse files
authored
Implement AI "stop" -- in the client, open ai responses/chat, and gemini backends (#2704)
1 parent 2450ed0 commit 27ac131

17 files changed

Lines changed: 2028 additions & 115 deletions

aiprompts/conn-arch.md

Lines changed: 612 additions & 0 deletions
Large diffs are not rendered by default.

aiprompts/fe-conn-arch.md

Lines changed: 1007 additions & 0 deletions
Large diffs are not rendered by default.

frontend/app/aipanel/aipanel.tsx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import { waveAIHasSelection } from "@/app/aipanel/waveai-focus-utils";
66
import { ErrorBoundary } from "@/app/element/errorboundary";
77
import { atoms, getSettingsKeyAtom } from "@/app/store/global";
88
import { globalStore } from "@/app/store/jotaiStore";
9-
import { useTabModel } from "@/app/store/tab-model";
9+
import { maybeUseTabModel } from "@/app/store/tab-model";
1010
import { checkKeyPressed, keydownWrapper } from "@/util/keyutil";
1111
import { isMacOS, isWindows } from "@/util/platformutil";
1212
import { cn } from "@/util/util";
@@ -255,7 +255,7 @@ const AIPanelComponentInner = memo(() => {
255255
const isFocused = jotai.useAtomValue(model.isWaveAIFocusedAtom);
256256
const telemetryEnabled = jotai.useAtomValue(getSettingsKeyAtom("telemetry:enabled")) ?? false;
257257
const isPanelVisible = jotai.useAtomValue(model.getPanelVisibleAtom());
258-
const tabModel = useTabModel();
258+
const tabModel = maybeUseTabModel();
259259
const defaultMode = jotai.useAtomValue(getSettingsKeyAtom("waveai:defaultmode")) ?? "waveai@balanced";
260260
const aiModeConfigs = jotai.useAtomValue(model.aiModeConfigs);
261261

frontend/app/aipanel/aipanelinput.tsx

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -169,24 +169,35 @@ export const AIPanelInput = memo(({ onSubmit, status, model }: AIPanelInputProps
169169
<i className="fa fa-paperclip text-sm"></i>
170170
</button>
171171
</Tooltip>
172-
<Tooltip content="Send message (Enter)" placement="top" divClassName="absolute bottom-1.5 right-1">
173-
<button
174-
type="submit"
175-
disabled={status !== "ready" || !input.trim()}
176-
className={cn(
177-
"w-5 h-5 transition-colors flex items-center justify-center",
178-
status !== "ready" || !input.trim()
179-
? "text-gray-400"
180-
: "text-accent/80 hover:text-accent cursor-pointer"
181-
)}
182-
>
183-
{status === "streaming" ? (
184-
<i className="fa fa-spinner fa-spin text-sm"></i>
185-
) : (
172+
{status === "streaming" ? (
173+
<Tooltip content="Stop Response" placement="top" divClassName="absolute bottom-1.5 right-1">
174+
<button
175+
type="button"
176+
onClick={() => model.stopResponse()}
177+
className={cn(
178+
"w-5 h-5 transition-colors flex items-center justify-center",
179+
"text-green-500 hover:text-green-400 cursor-pointer"
180+
)}
181+
>
182+
<i className="fa fa-square text-sm"></i>
183+
</button>
184+
</Tooltip>
185+
) : (
186+
<Tooltip content="Send message (Enter)" placement="top" divClassName="absolute bottom-1.5 right-1">
187+
<button
188+
type="submit"
189+
disabled={status !== "ready" || !input.trim()}
190+
className={cn(
191+
"w-5 h-5 transition-colors flex items-center justify-center",
192+
status !== "ready" || !input.trim()
193+
? "text-gray-400"
194+
: "text-accent/80 hover:text-accent cursor-pointer"
195+
)}
196+
>
186197
<i className="fa fa-paper-plane text-sm"></i>
187-
)}
188-
</button>
189-
</Tooltip>
198+
</button>
199+
</Tooltip>
200+
)}
190201
</div>
191202
</form>
192203
</div>

frontend/app/aipanel/waveai-model.tsx

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,12 @@ export interface DroppedFile {
4242

4343
export class WaveAIModel {
4444
private static instance: WaveAIModel | null = null;
45-
private inputRef: React.RefObject<AIPanelInputRef> | null = null;
46-
private scrollToBottomCallback: (() => void) | null = null;
47-
private useChatSendMessage: UseChatSendMessageType | null = null;
48-
private useChatSetMessages: UseChatSetMessagesType | null = null;
49-
private useChatStatus: ChatStatus = "ready";
50-
private useChatStop: (() => void) | null = null;
45+
inputRef: React.RefObject<AIPanelInputRef> | null = null;
46+
scrollToBottomCallback: (() => void) | null = null;
47+
useChatSendMessage: UseChatSendMessageType | null = null;
48+
useChatSetMessages: UseChatSetMessagesType | null = null;
49+
useChatStatus: ChatStatus = "ready";
50+
useChatStop: (() => void) | null = null;
5151
// Used for injecting Wave-specific message data into DefaultChatTransport's prepareSendMessagesRequest
5252
realMessage: AIMessage | null = null;
5353
orefContext: ORef;
@@ -324,6 +324,29 @@ export class WaveAIModel {
324324
}
325325
}
326326

327+
async reloadChatFromBackend(chatIdValue: string): Promise<WaveUIMessage[]> {
328+
const chatData = await RpcApi.GetWaveAIChatCommand(TabRpcClient, { chatid: chatIdValue });
329+
const messages: UIMessage[] = chatData?.messages ?? [];
330+
globalStore.set(this.isChatEmptyAtom, messages.length === 0);
331+
return messages as WaveUIMessage[];
332+
}
333+
334+
async stopResponse() {
335+
this.useChatStop?.();
336+
await new Promise((resolve) => setTimeout(resolve, 500));
337+
338+
const chatIdValue = globalStore.get(this.chatId);
339+
if (!chatIdValue) {
340+
return;
341+
}
342+
try {
343+
const messages = await this.reloadChatFromBackend(chatIdValue);
344+
this.useChatSetMessages?.(messages);
345+
} catch (error) {
346+
console.error("Failed to reload chat after stop:", error);
347+
}
348+
}
349+
327350
getAndClearMessage(): AIMessage | null {
328351
const msg = this.realMessage;
329352
this.realMessage = null;
@@ -448,10 +471,7 @@ export class WaveAIModel {
448471
}
449472

450473
try {
451-
const chatData = await RpcApi.GetWaveAIChatCommand(TabRpcClient, { chatid: chatIdValue });
452-
const messages: UIMessage[] = chatData?.messages ?? [];
453-
globalStore.set(this.isChatEmptyAtom, messages.length === 0);
454-
return messages as WaveUIMessage[]; // this is safe just different RPC type vs the FE type, but they are compatible
474+
return await this.reloadChatFromBackend(chatIdValue);
455475
} catch (error) {
456476
console.error("Failed to load chat:", error);
457477
this.setError("Failed to load chat. Starting new chat...");

frontend/app/app.tsx

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import { ClientModel } from "@/app/store/client-model";
55
import { GlobalModel } from "@/app/store/global-model";
6+
import { getTabModelByTabId, TabModelContext } from "@/app/store/tab-model";
67
import { Workspace } from "@/app/workspace/workspace";
78
import { ContextMenuModel } from "@/store/contextmenu";
89
import { atoms, createBlock, getSettingsPrefixAtom, globalStore, isDev, removeFlashError } from "@/store/global";
@@ -31,12 +32,15 @@ const dlog = debug("wave:app");
3132
const focusLog = debug("wave:focus");
3233

3334
const App = ({ onFirstRender }: { onFirstRender: () => void }) => {
35+
const tabId = useAtomValue(atoms.staticTabId);
3436
useEffect(() => {
3537
onFirstRender();
3638
}, []);
3739
return (
3840
<Provider store={globalStore}>
39-
<AppInner />
41+
<TabModelContext.Provider value={getTabModelByTabId(tabId)}>
42+
<AppInner />
43+
</TabModelContext.Provider>
4044
</Provider>
4145
);
4246
};

frontend/app/store/tab-model.ts

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ import { globalStore } from "./jotaiStore";
77
import * as WOS from "./wos";
88

99
const tabModelCache = new Map<string, TabModel>();
10-
const activeTabIdAtom = atom<string>(null) as PrimitiveAtom<string>;
10+
export const activeTabIdAtom = atom<string>(null) as PrimitiveAtom<string>;
1111

12-
class TabModel {
12+
export class TabModel {
1313
tabId: string;
1414
tabAtom: Atom<Tab>;
1515
tabNumBlocksAtom: Atom<number>;
@@ -40,7 +40,7 @@ class TabModel {
4040
}
4141
}
4242

43-
function getTabModelByTabId(tabId: string): TabModel {
43+
export function getTabModelByTabId(tabId: string): TabModel {
4444
let model = tabModelCache.get(tabId);
4545
if (model == null) {
4646
model = new TabModel(tabId);
@@ -49,22 +49,24 @@ function getTabModelByTabId(tabId: string): TabModel {
4949
return model;
5050
}
5151

52-
function getActiveTabModel(): TabModel | null {
52+
export function getActiveTabModel(): TabModel | null {
5353
const activeTabId = globalStore.get(activeTabIdAtom);
5454
if (activeTabId == null) {
5555
return null;
5656
}
5757
return getTabModelByTabId(activeTabId);
5858
}
5959

60-
const TabModelContext = createContext<TabModel | undefined>(undefined);
60+
export const TabModelContext = createContext<TabModel | undefined>(undefined);
6161

62-
function useTabModel(): TabModel {
62+
export function useTabModel(): TabModel {
6363
const model = useContext(TabModelContext);
6464
if (model == null) {
6565
throw new Error("useTabModel must be used within a TabModelProvider");
6666
}
6767
return model;
6868
}
6969

70-
export { activeTabIdAtom, getActiveTabModel, getTabModelByTabId, TabModel, TabModelContext, useTabModel };
70+
export function maybeUseTabModel(): TabModel {
71+
return useContext(TabModelContext);
72+
}

frontend/app/workspace/workspace.tsx

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import { CenteredDiv } from "@/app/element/quickelems";
77
import { ModalsRenderer } from "@/app/modals/modalsrenderer";
88
import { TabBar } from "@/app/tab/tabbar";
99
import { TabContent } from "@/app/tab/tabcontent";
10-
import { getTabModelByTabId, TabModelContext } from "@/app/store/tab-model";
1110
import { Widgets } from "@/app/workspace/widgets";
1211
import { WorkspaceLayoutModel } from "@/app/workspace/workspace-layout-model";
1312
import { atoms, getApi } from "@/store/global";
@@ -70,7 +69,7 @@ const WorkspaceElem = memo(() => {
7069
className="overflow-hidden"
7170
>
7271
<div ref={aiPanelWrapperRef} className="w-full h-full">
73-
<AIPanel />
72+
{tabId !== "" && <AIPanel />}
7473
</div>
7574
</Panel>
7675
<PanelResizeHandle className="w-0.5 bg-transparent hover:bg-zinc-500/20 transition-colors" />
@@ -79,9 +78,7 @@ const WorkspaceElem = memo(() => {
7978
<CenteredDiv>No Active Tab</CenteredDiv>
8079
) : (
8180
<div className="flex flex-row h-full">
82-
<TabModelContext.Provider value={getTabModelByTabId(tabId)}>
83-
<TabContent key={tabId} tabId={tabId} />
84-
</TabModelContext.Provider>
81+
<TabContent key={tabId} tabId={tabId} />
8582
<Widgets />
8683
</div>
8784
)}

pkg/aiusechat/chatstore/chatstore.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package chatstore
55

66
import (
77
"fmt"
8+
"slices"
89
"sync"
910

1011
"github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
@@ -109,3 +110,20 @@ func (cs *ChatStore) PostMessage(chatId string, aiOpts *uctypes.AIOptsType, mess
109110

110111
return nil
111112
}
113+
114+
func (cs *ChatStore) RemoveMessage(chatId string, messageId string) bool {
115+
cs.lock.Lock()
116+
defer cs.lock.Unlock()
117+
118+
chat := cs.chats[chatId]
119+
if chat == nil {
120+
return false
121+
}
122+
123+
initialLen := len(chat.NativeMessages)
124+
chat.NativeMessages = slices.DeleteFunc(chat.NativeMessages, func(msg uctypes.GenAIMessage) bool {
125+
return msg.GetMessageId() == messageId
126+
})
127+
128+
return len(chat.NativeMessages) < initialLen
129+
}

pkg/aiusechat/gemini/gemini-backend.go

Lines changed: 24 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -42,45 +42,6 @@ func ensureAltSse(endpoint string) (string, error) {
4242
return endpoint, nil
4343
}
4444

45-
// UpdateToolUseData updates the tool use data for a specific tool call in the chat
46-
func UpdateToolUseData(chatId string, toolCallId string, toolUseData uctypes.UIMessageDataToolUse) error {
47-
chat := chatstore.DefaultChatStore.Get(chatId)
48-
if chat == nil {
49-
return fmt.Errorf("chat not found: %s", chatId)
50-
}
51-
52-
for _, genMsg := range chat.NativeMessages {
53-
chatMsg, ok := genMsg.(*GeminiChatMessage)
54-
if !ok {
55-
continue
56-
}
57-
58-
for i, part := range chatMsg.Parts {
59-
if part.FunctionCall != nil && part.ToolUseData != nil && part.ToolUseData.ToolCallId == toolCallId {
60-
// Update the message with new tool use data
61-
updatedMsg := &GeminiChatMessage{
62-
MessageId: chatMsg.MessageId,
63-
Role: chatMsg.Role,
64-
Parts: make([]GeminiMessagePart, len(chatMsg.Parts)),
65-
Usage: chatMsg.Usage,
66-
}
67-
copy(updatedMsg.Parts, chatMsg.Parts)
68-
updatedMsg.Parts[i].ToolUseData = &toolUseData
69-
70-
aiOpts := &uctypes.AIOptsType{
71-
APIType: chat.APIType,
72-
Model: chat.Model,
73-
APIVersion: chat.APIVersion,
74-
}
75-
76-
return chatstore.DefaultChatStore.PostMessage(chatId, aiOpts, updatedMsg)
77-
}
78-
}
79-
}
80-
81-
return fmt.Errorf("tool call with ID %s not found in chat %s", toolCallId, chatId)
82-
}
83-
8445
// appendPartToLastUserMessage appends a text part to the last user message in the contents slice
8546
func appendPartToLastUserMessage(contents []GeminiContent, text string) {
8647
for i := len(contents) - 1; i >= 0; i-- {
@@ -347,6 +308,14 @@ func processGeminiStream(
347308
if errors.Is(err, io.EOF) {
348309
break
349310
}
311+
if sseHandler.Err() != nil {
312+
partialMsg := extractPartialGeminiMessage(msgID, textBuilder.String())
313+
return &uctypes.WaveStopReason{
314+
Kind: uctypes.StopKindCanceled,
315+
ErrorType: "client_disconnect",
316+
ErrorText: "client disconnected",
317+
}, partialMsg, nil
318+
}
350319
_ = sseHandler.AiMsgError(fmt.Sprintf("stream decode error: %v", err))
351320
return &uctypes.WaveStopReason{
352321
Kind: uctypes.StopKindError,
@@ -512,3 +481,19 @@ func processGeminiStream(
512481

513482
return stopReason, assistantMsg, nil
514483
}
484+
485+
func extractPartialGeminiMessage(msgID string, text string) *GeminiChatMessage {
486+
if text == "" {
487+
return nil
488+
}
489+
490+
return &GeminiChatMessage{
491+
MessageId: msgID,
492+
Role: "model",
493+
Parts: []GeminiMessagePart{
494+
{
495+
Text: text,
496+
},
497+
},
498+
}
499+
}

0 commit comments

Comments
 (0)