Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions cmd/wsh/cmd/wshcmd-connserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ const (

var connServerRouter bool
var connServerRouterDomainSocket bool
var connServerRouterTCP bool
var connServerConnName string
var connServerDev bool
var ConnServerWshRouter *wshutil.WshRouter
Expand All @@ -54,6 +55,7 @@ var connServerInitialEnv map[string]string
func init() {
serverCmd.Flags().BoolVar(&connServerRouter, "router", false, "run in local router mode (stdio upstream)")
serverCmd.Flags().BoolVar(&connServerRouterDomainSocket, "router-domainsocket", false, "run in local router mode (domain socket upstream)")
serverCmd.Flags().BoolVar(&connServerRouterTCP, "router-tcp", false, "run in local router mode (tcp upstream)")
serverCmd.Flags().StringVar(&connServerConnName, "conn", "", "connection name")
serverCmd.Flags().BoolVar(&connServerDev, "dev", false, "enable dev mode with file logging and PID in logs")
rootCmd.AddCommand(serverCmd)
Expand Down Expand Up @@ -394,6 +396,114 @@ func serverRunRouterDomainSocket(jwtToken string) error {
select {}
}

func serverRunRouterTCP(jwtToken string) error {
log.Printf("starting connserver router (tcp upstream)")

// extract tcp address from JWT token (unverified - we're on the client side)
tcpAddr, err := wshutil.ExtractUnverifiedSocketName(jwtToken)
if err != nil {
return fmt.Errorf("error extracting tcp address from JWT: %v", err)
}

// connect to the forwarded tcp port
conn, err := net.Dial("tcp", tcpAddr)
if err != nil {
return fmt.Errorf("error connecting to tcp upstream %s: %v", tcpAddr, err)
}

// create router
router := wshutil.NewWshRouter()
ConnServerWshRouter = router

// create proxy for the tcp connection
upstreamProxy := wshutil.MakeRpcProxy("connserver-upstream")

// goroutine to write to the tcp connection
go func() {
defer func() {
panichandler.PanicHandler("serverRunRouterTCP:WriteLoop", recover())
}()
writeErr := wshutil.AdaptOutputChToStream(upstreamProxy.ToRemoteCh, conn)
if writeErr != nil {
log.Printf("error writing to upstream tcp connection: %v\n", writeErr)
}
}()

// goroutine to read from the tcp connection
go func() {
defer func() {
panichandler.PanicHandler("serverRunRouterTCP:ReadLoop", recover())
}()
defer func() {
log.Printf("upstream tcp connection closed, shutting down")
wshutil.DoShutdown("", 0, true)
}()
wshutil.AdaptStreamToMsgCh(conn, upstreamProxy.FromRemoteCh, nil)
}()

// register the tcp connection as upstream
router.RegisterUpstream(upstreamProxy)

// use the router's control RPC to authenticate with upstream
controlRpc := router.GetControlRpc()

// authenticate with the upstream router using the JWT
_, err = wshclient.AuthenticateCommand(controlRpc, jwtToken, &wshrpc.RpcOpts{Route: wshutil.ControlRootRoute})
if err != nil {
return fmt.Errorf("error authenticating with upstream: %v", err)
}
log.Printf("authenticated with upstream router")

// fetch and set JWT public key
log.Printf("trying to get JWT public key")
jwtPublicKeyB64, err := wshclient.GetJwtPublicKeyCommand(controlRpc, nil)
if err != nil {
return fmt.Errorf("error getting jwt public key: %v", err)
}
jwtPublicKeyBytes, err := base64.StdEncoding.DecodeString(jwtPublicKeyB64)
if err != nil {
return fmt.Errorf("error decoding jwt public key: %v", err)
}
err = wavejwt.SetPublicKey(jwtPublicKeyBytes)
if err != nil {
return fmt.Errorf("error setting jwt public key: %v", err)
}
log.Printf("got JWT public key")

// now setup the connserver rpc client
client, bareRouteId, err := setupConnServerRpcClientWithRouter(router, tcpAddr)
if err != nil {
return fmt.Errorf("error setting up connserver rpc client: %v", err)
}
wshfs.RpcClient = client
wshfs.RpcClientRouteId = bareRouteId

// set up the local domain socket listener for local wsh commands
unixListener, err := MakeRemoteUnixListener()
if err != nil {
return fmt.Errorf("cannot create unix listener: %v", err)
}
log.Printf("unix listener started")
go func() {
defer func() {
panichandler.PanicHandler("serverRunRouterTCP:runListener", recover())
}()
runListener(unixListener, router)
}()

// run the sysinfo loop
go func() {
defer func() {
panichandler.PanicHandler("serverRunRouterTCP:RunSysInfoLoop", recover())
}()
wshremote.RunSysInfoLoop(client, connServerConnName)
}()
startJobLogCleanup()

log.Printf("running server (router-tcp mode), successfully started")
select {}
}

func serverRunNormal(jwtToken string) error {
sockName, err := wshutil.ExtractUnverifiedSocketName(jwtToken)
if err != nil {
Expand Down Expand Up @@ -491,6 +601,20 @@ func serverRun(cmd *cobra.Command, args []string) error {
}
return err
}
if connServerRouterTCP {
jwtToken, err := askForJwtToken()
if err != nil {
if logFile != nil {
fmt.Fprintf(logFile, "askForJwtToken error: %v\n", err)
}
return err
}
err = serverRunRouterTCP(jwtToken)
if err != nil && logFile != nil {
fmt.Fprintf(logFile, "serverRunRouterTCP error: %v\n", err)
}
return err
}
jwtToken, err := askForJwtToken()
if err != nil {
if logFile != nil {
Expand Down
3 changes: 2 additions & 1 deletion pkg/aiusechat/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,8 @@ func GenerateTabStateAndTools(ctx context.Context, tabid string, widgetAccess bo
}
if viewTypes["term"] {
tools = append(tools, GetTermGetScrollbackToolDefinition(tabid))
// tools = append(tools, GetTermCommandOutputToolDefinition(tabid))
tools = append(tools, GetTermCommandOutputToolDefinition(tabid))
tools = append(tools, GetTermSendCommandToolDefinition(tabid))
}
if viewTypes["web"] {
tools = append(tools, GetWebNavigateToolDefinition(tabid))
Expand Down
115 changes: 115 additions & 0 deletions pkg/aiusechat/tools_term.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package aiusechat

import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"strings"
Expand Down Expand Up @@ -241,6 +242,120 @@ func parseTermCommandOutputInput(input any) (*TermCommandOutputToolInput, error)
return result, nil
}

type TermSendCommandToolInput struct {
WidgetId string `json:"widget_id"`
Command string `json:"command"`
WaitForOutput bool `json:"wait_for_output,omitempty"`
}

func parseTermSendCommandInput(input any) (*TermSendCommandToolInput, error) {
result := &TermSendCommandToolInput{}
if input == nil {
return nil, fmt.Errorf("widget_id and command are required")
}
inputBytes, err := json.Marshal(input)
if err != nil {
return nil, fmt.Errorf("failed to marshal input: %w", err)
}
if err := json.Unmarshal(inputBytes, result); err != nil {
return nil, fmt.Errorf("failed to unmarshal input: %w", err)
}
if result.WidgetId == "" {
return nil, fmt.Errorf("widget_id is required")
}
if result.Command == "" {
return nil, fmt.Errorf("command is required")
}
return result, nil
}

func GetTermSendCommandToolDefinition(tabId string) uctypes.ToolDefinition {
return uctypes.ToolDefinition{
Name: "term_send_command",
DisplayName: "Run Command in Terminal",
Description: "Execute a shell command in an open terminal widget. Sends the command text followed by Enter. If wait_for_output is true, returns the terminal scrollback after a short delay so you can see the result. Requires user approval before execution.",
ToolLogName: "term:sendcommand",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"widget_id": map[string]any{
"type": "string",
"description": "8-character widget ID of the terminal widget to run the command in",
},
"command": map[string]any{
"type": "string",
"description": "The shell command to execute",
},
"wait_for_output": map[string]any{
"type": "boolean",
"description": "If true, wait briefly and return terminal output after the command runs (default: true)",
},
},
"required": []string{"widget_id", "command"},
"additionalProperties": false,
},
ToolCallDesc: func(input any, output any, toolUseData *uctypes.UIMessageDataToolUse) string {
parsed, err := parseTermSendCommandInput(input)
if err != nil {
return fmt.Sprintf("error parsing input: %v", err)
}
return fmt.Sprintf("running in terminal %s: %s", parsed.WidgetId, parsed.Command)
},
ToolApproval: func(input any) string {
return uctypes.ApprovalNeedsApproval
},
ToolAnyCallback: func(input any, toolUseData *uctypes.UIMessageDataToolUse) (any, error) {
parsed, err := parseTermSendCommandInput(input)
if err != nil {
return nil, err
}

ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFn()

fullBlockId, err := wcore.ResolveBlockIdFromPrefix(ctx, tabId, parsed.WidgetId)
if err != nil {
return nil, fmt.Errorf("terminal widget %q not found: %w", parsed.WidgetId, err)
}

inputBytes := []byte(parsed.Command + "\r")
inputData64 := base64.StdEncoding.EncodeToString(inputBytes)

rpcClient := wshclient.GetBareRpcClient()
err = wshclient.ControllerInputCommand(
rpcClient,
wshrpc.CommandBlockInputData{
BlockId: fullBlockId,
InputData64: inputData64,
},
&wshrpc.RpcOpts{},
)
if err != nil {
return nil, fmt.Errorf("failed to send command to terminal: %w", err)
}

waitForOutput := parsed.WaitForOutput || true
if waitForOutput {
time.Sleep(2 * time.Second)
output, err := getTermScrollbackOutput(
tabId,
parsed.WidgetId,
wshrpc.CommandTermGetScrollbackLinesData{
LineStart: 0,
LineEnd: 50,
},
)
if err != nil {
return map[string]any{"sent": true, "note": "command sent; could not read output"}, nil
}
return map[string]any{"sent": true, "output": output}, nil
}

return map[string]any{"sent": true}, nil
},
}
}

func GetTermCommandOutputToolDefinition(tabId string) uctypes.ToolDefinition {
return uctypes.ToolDefinition{
Name: "term_command_output",
Expand Down
5 changes: 2 additions & 3 deletions pkg/aiusechat/usechat-prompts.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@ var SystemPromptText_OpenAI = strings.Join([]string{
`If a request would execute dangerous or destructive actions, warn briefly and provide a safer alternative.`,
`If output is very long, prefer a brief summary plus a copy-ready fenced block or offer a follow-up chunking strategy.`,

`You can write and edit local files on disk using available tools, but you cannot execute shell commands, run code in the terminal, or read/write remote files.`,
`If the user asks you to execute commands or run code, or deal with remote files say that these features aren't available yet AND make sure to say that they are coming soon (stay tuned for updates).`,
`Instead, show them exactly what command or code they could copy-paste to run manually.`,
`You can read and write local files, capture screenshots, browse the web, read terminal scrollback, and execute shell commands in open terminal widgets using the term_send_command tool. Always prefer to gather context (read files, get scrollback) before acting. For destructive or side-effectful commands, briefly explain what the command will do before calling the tool — the user will be shown an approval prompt before execution.`,
`If the user asks you to run a command and a terminal widget is available, call term_send_command rather than just showing them the command to copy-paste. If no terminal widget is open, tell them to open one and then re-ask.`,

// Final reminder
`You have NO API access to widgets or Wave unless provided via an explicit tool.`,
Expand Down
Loading