Skip to content

Commit 4abf7a3

Browse files
committed
feat(session): support local logs and command cancellation
1 parent 4145920 commit 4abf7a3

7 files changed

Lines changed: 355 additions & 33 deletions

File tree

cmd/new.go

Lines changed: 146 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,21 @@ package cmd
33

44
import (
55
"context"
6+
"encoding/json"
67
"errors"
78
"fmt"
89
"io"
910
"os"
1011
"os/signal"
12+
"sync"
1113
"syscall"
1214
"time"
1315

1416
"github.com/RealWhyKnot/Handoff/internal/audit"
1517
"github.com/RealWhyKnot/Handoff/internal/capabilities"
1618
"github.com/RealWhyKnot/Handoff/internal/dispatch"
1719
"github.com/RealWhyKnot/Handoff/internal/relay"
20+
"github.com/RealWhyKnot/Handoff/internal/supportlog"
1821
)
1922

2023
// Version is stamped from main.go.
@@ -24,6 +27,7 @@ var Version = "0.1.0"
2427
// and runs the host agent loop until Ctrl+C or relay disconnect.
2528
func New(args []string) {
2629
relayBase := defaultRelay()
30+
supportlog.Printf("session start relay=%s version=%s", relayBase, Version)
2731
fmt.Println("relay:", relayBase)
2832

2933
ctx, cancel := context.WithCancel(context.Background())
@@ -34,10 +38,12 @@ func New(args []string) {
3438
mint, err := relay.Mint(mintCtx, relayBase)
3539
mintCancel()
3640
if err != nil {
41+
supportlog.Printf("mint failed: %v", err)
3742
fmt.Fprintln(os.Stderr, "could not mint session:", err)
3843
os.Exit(1)
3944
}
4045
sid := shortSid(mint.ViewToken)
46+
supportlog.Printf("mint ok sid=%s view_url=%s", sid, mint.ViewURL)
4147
fmt.Println()
4248
fmt.Println("session live -- share the view URL with your helper:")
4349
fmt.Println(" ", mint.ViewURL)
@@ -48,6 +54,7 @@ func New(args []string) {
4854
// Audit log.
4955
log, err := audit.New()
5056
if err != nil {
57+
supportlog.Printf("audit log unavailable: %v", err)
5158
fmt.Fprintln(os.Stderr, "warning: audit log unavailable:", err)
5259
}
5360
defer func() {
@@ -59,13 +66,15 @@ func New(args []string) {
5966
// Dispatcher with all capabilities registered.
6067
router := dispatch.New()
6168
capabilities.RegisterAll(router)
69+
supportlog.Printf("capabilities registered count=%d", len(router.Kinds()))
6270
fmt.Printf("ready -- %d capabilities registered\n\n", len(router.Kinds()))
6371

6472
// Handle Ctrl+C.
6573
sig := make(chan os.Signal, 1)
6674
signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
6775
go func() {
6876
<-sig
77+
supportlog.Printf("shutdown requested by signal")
6978
fmt.Println("\nshutting down...")
7079
cancel()
7180
}()
@@ -75,55 +84,174 @@ func New(args []string) {
7584
bridge, err := relay.Dial(dialCtx, relayBase, mint.WriteToken)
7685
dialCancel()
7786
if err != nil {
87+
supportlog.Printf("bridge dial failed sid=%s: %v", sid, err)
7888
fmt.Fprintln(os.Stderr, "could not open bridge:", err)
7989
os.Exit(1)
8090
}
8191
defer bridge.Close()
92+
supportlog.Printf("bridge connected sid=%s", sid)
8293

8394
hostname, _ := os.Hostname()
8495
if err := bridge.SendHello(ctx, hostname, Version, router.Kinds()); err != nil {
96+
supportlog.Printf("hello failed sid=%s: %v", sid, err)
8597
fmt.Fprintln(os.Stderr, "hello failed:", err)
8698
os.Exit(1)
8799
}
100+
supportlog.Printf("hello sent sid=%s hostname=%s", sid, hostname)
88101

89-
// Main loop: receive commands, dispatch, send results.
102+
jobs := newJobRunner(ctx, sid, router, bridge, log)
103+
defer jobs.CancelAll()
104+
105+
// Main loop: receive commands and hand them to cancellable workers.
90106
for {
91107
cmd, err := bridge.Recv(ctx)
92108
if err != nil {
93109
if errors.Is(err, context.Canceled) || errors.Is(err, io.EOF) {
110+
supportlog.Printf("recv ended sid=%s err=%v", sid, err)
94111
return
95112
}
113+
supportlog.Printf("recv error sid=%s: %v", sid, err)
96114
fmt.Fprintln(os.Stderr, "recv error:", err)
97115
return
98116
}
117+
supportlog.Printf("command received sid=%s id=%s kind=%s", sid, cmd.ID, cmd.Kind)
118+
if cmd.Kind == "control.cancel" {
119+
targetID := readStringExtra(cmd.Extras, "target_id")
120+
if targetID == "" {
121+
supportlog.Printf("cancel control missing target sid=%s id=%s", sid, cmd.ID)
122+
continue
123+
}
124+
if jobs.Cancel(targetID) {
125+
fmt.Printf("[cancel] %s\n", targetID)
126+
supportlog.Printf("command cancel requested sid=%s id=%s", sid, targetID)
127+
} else {
128+
supportlog.Printf("cancel target not active sid=%s id=%s", sid, targetID)
129+
}
130+
continue
131+
}
99132
fmt.Printf("[cmd] %s kind=%s\n", cmd.ID, cmd.Kind)
133+
jobs.Start(cmd)
134+
}
135+
}
100136

101-
out := router.Dispatch(ctx, cmd.Kind, cmd.Extras)
137+
// jobRunner runs relay commands on their own goroutines and keeps one cancel
// function per in-flight command, so a job can be stopped by id (Cancel) or
// all jobs can be stopped together (CancelAll).
type jobRunner struct {
138+
rootCtx context.Context // parent of every job context; also passed when sending results
139+
sid string // session id included in support-log lines and audit entries
140+
router *dispatch.Router // dispatches each command by its kind
141+
bridge *relay.Bridge // used to send command results back
142+
audit *audit.Logger // may be nil; writeAudit returns early in that case
102143

103-
// Audit.
104-
if log != nil {
105-
res := "ok"
106-
if !out.OK {
107-
res = "err"
144+
mu sync.Mutex // guards active
145+
active map[string]context.CancelFunc // command id -> cancel function of its running job
146+
}
147+
148+
func newJobRunner(rootCtx context.Context, sid string, router *dispatch.Router, bridge *relay.Bridge, auditLog *audit.Logger) *jobRunner {
149+
return &jobRunner{
150+
rootCtx: rootCtx,
151+
sid: sid,
152+
router: router,
153+
bridge: bridge,
154+
audit: auditLog,
155+
active: map[string]context.CancelFunc{},
156+
}
157+
}
158+
159+
func (r *jobRunner) Start(cmd *relay.Command) {
160+
var jobCtx context.Context
161+
var cancel context.CancelFunc
162+
if cmd.TimeoutMS > 0 {
163+
jobCtx, cancel = context.WithTimeout(r.rootCtx, time.Duration(cmd.TimeoutMS)*time.Millisecond)
164+
} else {
165+
jobCtx, cancel = context.WithCancel(r.rootCtx)
166+
}
167+
168+
r.mu.Lock()
169+
r.active[cmd.ID] = cancel
170+
r.mu.Unlock()
171+
172+
go func() {
173+
defer func() {
174+
r.mu.Lock()
175+
delete(r.active, cmd.ID)
176+
r.mu.Unlock()
177+
cancel()
178+
}()
179+
180+
out := r.router.Dispatch(jobCtx, cmd.Kind, cmd.Extras)
181+
if err := jobCtx.Err(); err != nil {
182+
out.OK = false
183+
switch {
184+
case errors.Is(err, context.DeadlineExceeded):
185+
out.Error = "command timed out"
186+
case errors.Is(err, context.Canceled):
187+
out.Error = "command cancelled"
188+
default:
189+
out.Error = err.Error()
108190
}
109-
_ = log.Write(audit.Entry{
110-
SessionID: sid,
111-
Capability: cmd.Kind,
112-
Args: cmd.Extras,
113-
Consent: "session",
114-
Result: res,
115-
ElapsedMs: out.ElapsedMs,
116-
Detail: out.Error,
117-
})
118191
}
119192

120-
// Result back to the relay.
121-
if err := bridge.SendCommandResult(ctx, cmd.ID, out.OK, out.Result, out.Error, out.ElapsedMs); err != nil {
193+
r.writeAudit(cmd, out)
194+
if err := r.bridge.SendCommandResult(r.rootCtx, cmd.ID, out.OK, out.Result, out.Error, out.ElapsedMs); err != nil {
195+
supportlog.Printf("send result failed sid=%s id=%s: %v", r.sid, cmd.ID, err)
122196
fmt.Fprintln(os.Stderr, "could not send result:", err)
123197
return
124198
}
199+
supportlog.Printf("command result sent sid=%s id=%s ok=%v elapsed_ms=%d", r.sid, cmd.ID, out.OK, out.ElapsedMs)
125200
fmt.Printf(" -> ok=%v elapsed=%dms\n", out.OK, out.ElapsedMs)
201+
}()
202+
}
203+
204+
func (r *jobRunner) Cancel(id string) bool {
205+
r.mu.Lock()
206+
cancel := r.active[id]
207+
r.mu.Unlock()
208+
if cancel == nil {
209+
return false
210+
}
211+
cancel()
212+
return true
213+
}
214+
215+
func (r *jobRunner) CancelAll() {
216+
r.mu.Lock()
217+
cancels := make([]context.CancelFunc, 0, len(r.active))
218+
for _, cancel := range r.active {
219+
cancels = append(cancels, cancel)
220+
}
221+
r.mu.Unlock()
222+
for _, cancel := range cancels {
223+
cancel()
224+
}
225+
}
226+
227+
func (r *jobRunner) writeAudit(cmd *relay.Command, out dispatch.Outcome) {
228+
if r.audit == nil {
229+
return
230+
}
231+
res := "ok"
232+
if !out.OK {
233+
res = "err"
234+
}
235+
_ = r.audit.Write(audit.Entry{
236+
SessionID: r.sid,
237+
Capability: cmd.Kind,
238+
Args: cmd.Extras,
239+
Consent: "session",
240+
Result: res,
241+
ElapsedMs: out.ElapsedMs,
242+
Detail: out.Error,
243+
})
244+
}
245+
246+
// readStringExtra returns extras[name] decoded as a JSON string. It returns ""
// when extras is nil or the key is absent. Decode errors are ignored, and
// whatever was decoded (normally "") is returned.
func readStringExtra(extras map[string]json.RawMessage, name string) string {
	// Indexing a nil map is safe and reports the key as absent.
	raw, present := extras[name]
	if !present {
		return ""
	}
	var decoded string
	_ = json.Unmarshal(raw, &decoded) // best effort: a non-string value leaves decoded empty
	return decoded
}
128256

129257
func defaultRelay() string {

internal/relay/relay.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"net/url"
1818
"sort"
1919
"strings"
20+
"sync"
2021
"time"
2122

2223
"github.com/coder/websocket"
@@ -34,9 +35,10 @@ type MintResponse struct {
3435
// `id` and `kind` fields are fixed; any other keys in the JSON payload
3536
// are accessible via Extras for kind-specific arguments.
3637
type Command struct {
37-
ID string `json:"id"`
38-
Kind string `json:"kind"`
39-
Extras map[string]json.RawMessage `json:"-"`
38+
ID string `json:"id"`
39+
Kind string `json:"kind"`
40+
TimeoutMS int `json:"timeout_ms,omitempty"` // decoded from "timeout_ms" by UnmarshalJSON and removed from Extras
41+
Extras map[string]json.RawMessage `json:"-"` // remaining keys of the payload, set by UnmarshalJSON
4042
}
4143

4244
// UnmarshalJSON decodes a Command keeping the kind-specific fields as
@@ -54,6 +56,10 @@ func (c *Command) UnmarshalJSON(data []byte) error {
5456
_ = json.Unmarshal(v, &c.Kind)
5557
delete(raw, "kind")
5658
}
59+
if v, ok := raw["timeout_ms"]; ok {
60+
_ = json.Unmarshal(v, &c.TimeoutMS)
61+
delete(raw, "timeout_ms")
62+
}
5763
c.Extras = raw
5864
return nil
5965
}
@@ -91,7 +97,8 @@ func Mint(ctx context.Context, baseURL string) (*MintResponse, error) {
9197

9298
// Bridge is the long-lived WebSocket between this host and the relay.
9399
type Bridge struct {
94-
conn *websocket.Conn
100+
conn *websocket.Conn
101+
writeMu sync.Mutex // held by send around conn.Write, so writes happen one at a time
95102
}
96103

97104
// Dial opens the WebSocket to the relay's /ws endpoint with the write
@@ -185,6 +192,8 @@ func (b *Bridge) send(ctx context.Context, kind string, payload interface{}) err
185192
if err := json.NewEncoder(&buf).Encode(ev); err != nil {
186193
return err
187194
}
195+
b.writeMu.Lock()
196+
defer b.writeMu.Unlock()
188197
return b.conn.Write(ctx, websocket.MessageText, bytes.TrimRight(buf.Bytes(), "\n"))
189198
}
190199

internal/relay/relay_test.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,19 @@ func TestSendHelloIncludesProtocolAndCapabilities(t *testing.T) {
8383
t.Fatalf("capabilities = %#v, want %#v", event.Payload.Capabilities, wantCaps)
8484
}
8585
}
86+
87+
func TestCommandUnmarshalCapturesTimeoutOutsideExtras(t *testing.T) {
88+
var cmd Command
89+
if err := json.Unmarshal([]byte(`{"id":"cmd-1","kind":"ps.exec","timeout_ms":1500,"script":"Start-Sleep 30"}`), &cmd); err != nil {
90+
t.Fatalf("Unmarshal: %v", err)
91+
}
92+
if cmd.ID != "cmd-1" || cmd.Kind != "ps.exec" || cmd.TimeoutMS != 1500 {
93+
t.Fatalf("command = %#v", cmd)
94+
}
95+
if _, ok := cmd.Extras["timeout_ms"]; ok {
96+
t.Fatal("timeout_ms should not be forwarded to capability extras")
97+
}
98+
if _, ok := cmd.Extras["script"]; !ok {
99+
t.Fatal("script extra missing")
100+
}
101+
}

0 commit comments

Comments
 (0)