Skip to content

Commit e61d8d9

Browse files
committed
feat(channel): improve approval messages with protocol and MCP tool context
1 parent a506cb3 commit e61d8d9

12 files changed

Lines changed: 167 additions & 70 deletions

File tree

internal/api/server_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ func TestGetApiApprovals_WithPending(t *testing.T) {
322322

323323
// Fire a request in a goroutine (blocks until resolved or timeout)
324324
go func() {
325-
_, _ = broker.Request("api.github.com", 443, 30*time.Second)
325+
_, _ = broker.Request("api.github.com", 443, "", 30*time.Second)
326326
}()
327327

328328
// Wait briefly for the request to register
@@ -371,7 +371,7 @@ func TestPostResolve_Success(t *testing.T) {
371371
// Create a pending request
372372
resultCh := make(chan channel.Response, 1)
373373
go func() {
374-
resp, _ := broker.Request("example.com", 443, 30*time.Second)
374+
resp, _ := broker.Request("example.com", 443, "", 30*time.Second)
375375
resultCh <- resp
376376
}()
377377
time.Sleep(50 * time.Millisecond)
@@ -486,7 +486,7 @@ func TestPostResolve_AlwaysAllow(t *testing.T) {
486486

487487
resultCh := make(chan channel.Response, 1)
488488
go func() {
489-
resp, _ := broker.Request("example.com", 443, 30*time.Second)
489+
resp, _ := broker.Request("example.com", 443, "", 30*time.Second)
490490
resultCh <- resp
491491
}()
492492
time.Sleep(50 * time.Millisecond)

internal/channel/broker.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,19 @@ func (b *Broker) now() time.Time {
107107
return time.Now()
108108
}
109109

110+
// RequestOption configures optional fields on an ApprovalRequest.
111+
type RequestOption func(*ApprovalRequest)
112+
113+
// WithToolArgs sets the truncated tool arguments on an MCP approval request.
114+
func WithToolArgs(args string) RequestOption {
115+
return func(r *ApprovalRequest) {
116+
r.ToolArgs = args
117+
}
118+
}
119+
110120
// Request sends an approval request to all channels and blocks until one
111121
// responds or the timeout expires. Returns the first response received.
112-
func (b *Broker) Request(dest string, port int, timeout time.Duration) (Response, error) {
122+
func (b *Broker) Request(dest string, port int, protocol string, timeout time.Duration, opts ...RequestOption) (Response, error) {
113123
id := fmt.Sprintf("req_%d", b.nextID.Add(1))
114124
ch := make(chan Response, 1)
115125

@@ -158,8 +168,12 @@ func (b *Broker) Request(dest string, port int, timeout time.Duration) (Response
158168
ID: id,
159169
Destination: dest,
160170
Port: port,
171+
Protocol: protocol,
161172
CreatedAt: b.now(),
162173
}
174+
for _, opt := range opts {
175+
opt(&req)
176+
}
163177
b.waiters[id] = waiter{ch: ch, req: req}
164178
b.mu.Unlock()
165179

internal/channel/channel.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ type ApprovalRequest struct {
6060
ID string
6161
Destination string
6262
Port int
63+
Protocol string // detected protocol (e.g. "https", "ssh", "mcp")
64+
ToolArgs string // truncated tool arguments (MCP only)
6365
CreatedAt time.Time
6466
}
6567

internal/channel/channel_test.go

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ func TestBrokerBroadcastToAllChannels(t *testing.T) {
145145

146146
broker = NewBroker([]Channel{ch1, ch2})
147147

148-
resp, err := broker.Request("evil.com", 443, 5*time.Second)
148+
resp, err := broker.Request("evil.com", 443, "", 5*time.Second)
149149
if err != nil {
150150
t.Fatalf("request: %v", err)
151151
}
@@ -184,7 +184,7 @@ func TestBrokerFirstResponseWins(t *testing.T) {
184184

185185
broker = NewBroker([]Channel{ch1, ch2})
186186

187-
resp, err := broker.Request("evil.com", 443, 5*time.Second)
187+
resp, err := broker.Request("evil.com", 443, "", 5*time.Second)
188188
if err != nil {
189189
t.Fatalf("request: %v", err)
190190
}
@@ -209,7 +209,7 @@ func TestBrokerCancelOnOtherChannels(t *testing.T) {
209209

210210
broker = NewBroker([]Channel{ch1, ch2})
211211

212-
_, err := broker.Request("evil.com", 443, 5*time.Second)
212+
_, err := broker.Request("evil.com", 443, "", 5*time.Second)
213213
if err != nil {
214214
t.Fatalf("request: %v", err)
215215
}
@@ -261,7 +261,7 @@ func TestBrokerSimultaneousResolve(t *testing.T) {
261261
done := make(chan struct{})
262262
var resp Response
263263
go func() {
264-
resp, _ = broker.Request("evil.com", 443, 5*time.Second)
264+
resp, _ = broker.Request("evil.com", 443, "", 5*time.Second)
265265
close(done)
266266
}()
267267

@@ -305,7 +305,7 @@ func TestBrokerPendingLimitExceeded(t *testing.T) {
305305
wg.Add(1)
306306
go func() {
307307
defer wg.Done()
308-
_, _ = broker.Request("example.com", 443, 2*time.Second)
308+
_, _ = broker.Request("example.com", 443, "", 2*time.Second)
309309
}()
310310
}
311311
// Wait until all 3 are registered as waiters.
@@ -314,7 +314,7 @@ func TestBrokerPendingLimitExceeded(t *testing.T) {
314314
}
315315

316316
// The 4th request should be auto-denied.
317-
resp, err := broker.Request("example.com", 443, time.Second)
317+
resp, err := broker.Request("example.com", 443, "", time.Second)
318318
if resp != ResponseDeny {
319319
t.Errorf("expected Deny, got %v", resp)
320320
}
@@ -355,7 +355,7 @@ func TestBrokerPendingLimitZeroMeansUnlimited(t *testing.T) {
355355
wg.Add(1)
356356
go func() {
357357
defer wg.Done()
358-
_, err := broker.Request("example.com", 443, 5*time.Second)
358+
_, err := broker.Request("example.com", 443, "", 5*time.Second)
359359
if errors.Is(err, ErrPendingLimitExceeded) {
360360
t.Error("should not hit pending limit with MaxPending=0")
361361
}
@@ -384,14 +384,14 @@ func TestBrokerDestinationRateLimiting(t *testing.T) {
384384

385385
// First 3 requests to the same destination should succeed.
386386
for i := 0; i < 3; i++ {
387-
_, err := broker.Request("api.example.com", 443, time.Second)
387+
_, err := broker.Request("api.example.com", 443, "", time.Second)
388388
if err != nil {
389389
t.Fatalf("request %d: unexpected error: %v", i, err)
390390
}
391391
}
392392

393393
// 4th request within the same window should be rate limited.
394-
resp, err := broker.Request("api.example.com", 443, time.Second)
394+
resp, err := broker.Request("api.example.com", 443, "", time.Second)
395395
if resp != ResponseDeny {
396396
t.Errorf("expected Deny, got %v", resp)
397397
}
@@ -400,14 +400,14 @@ func TestBrokerDestinationRateLimiting(t *testing.T) {
400400
}
401401

402402
// A different destination should still work.
403-
_, err = broker.Request("other.example.com", 443, time.Second)
403+
_, err = broker.Request("other.example.com", 443, "", time.Second)
404404
if err != nil {
405405
t.Fatalf("different destination should not be rate limited: %v", err)
406406
}
407407

408408
// Advance time past the window. The original destination should work again.
409409
fakeNow = fakeNow.Add(61 * time.Second)
410-
_, err = broker.Request("api.example.com", 443, time.Second)
410+
_, err = broker.Request("api.example.com", 443, "", time.Second)
411411
if err != nil {
412412
t.Fatalf("after window expiry, request should succeed: %v", err)
413413
}
@@ -431,7 +431,7 @@ func TestBrokerDestinationRateLimitDisabled(t *testing.T) {
431431

432432
// Should accept many requests without rate limiting.
433433
for i := 0; i < 20; i++ {
434-
_, err := broker.Request("api.example.com", 443, time.Second)
434+
_, err := broker.Request("api.example.com", 443, "", time.Second)
435435
if err != nil {
436436
t.Fatalf("request %d: unexpected error: %v", i, err)
437437
}
@@ -454,7 +454,7 @@ func TestBrokerCancelAllDeniesAllPending(t *testing.T) {
454454
// Start n requests that will block waiting for approval.
455455
for i := 0; i < n; i++ {
456456
go func() {
457-
resp, err := broker.Request("cancel-test.com", 443, 5*time.Second)
457+
resp, err := broker.Request("cancel-test.com", 443, "", 5*time.Second)
458458
results <- result{resp, err}
459459
}()
460460
}
@@ -489,7 +489,7 @@ func TestBrokerCancelAllRejectsNewRequests(t *testing.T) {
489489
broker.CancelAll()
490490

491491
start := time.Now()
492-
resp, err := broker.Request("post-cancel.com", 443, 5*time.Second)
492+
resp, err := broker.Request("post-cancel.com", 443, "", 5*time.Second)
493493
elapsed := time.Since(start)
494494

495495
if resp != ResponseDeny {
@@ -511,7 +511,7 @@ func TestBrokerCancelAllCallsCancelOnChannels(t *testing.T) {
511511

512512
// Send a request that blocks.
513513
go func() {
514-
_, _ = broker.Request("test.com", 443, 5*time.Second)
514+
_, _ = broker.Request("test.com", 443, "", 5*time.Second)
515515
}()
516516

517517
// Wait for it to register.
@@ -539,7 +539,7 @@ func TestBrokerTimeout(t *testing.T) {
539539
ch1 := newMockChannel(ChannelTelegram)
540540
broker := NewBroker([]Channel{ch1})
541541

542-
resp, err := broker.Request("slow.com", 443, 50*time.Millisecond)
542+
resp, err := broker.Request("slow.com", 443, "", 50*time.Millisecond)
543543
if err == nil {
544544
t.Fatalf("expected timeout error, got response %v", resp)
545545
}
@@ -553,7 +553,7 @@ func TestBrokerTimeoutCallsCancelOnChannels(t *testing.T) {
553553
ch2 := newMockChannel(ChannelHTTP)
554554
broker := NewBroker([]Channel{ch1, ch2})
555555

556-
_, _ = broker.Request("slow.com", 443, 50*time.Millisecond)
556+
_, _ = broker.Request("slow.com", 443, "", 50*time.Millisecond)
557557

558558
// Give cancellations time to propagate.
559559
time.Sleep(20 * time.Millisecond)
@@ -571,7 +571,7 @@ func TestBrokerTimeoutCallsCancelOnChannels(t *testing.T) {
571571
func TestBrokerNoChannelsTimesOut(t *testing.T) {
572572
broker := NewBroker(nil)
573573

574-
resp, err := broker.Request("no-channels.com", 443, 50*time.Millisecond)
574+
resp, err := broker.Request("no-channels.com", 443, "", 50*time.Millisecond)
575575
if err == nil {
576576
t.Fatal("expected timeout error with no channels")
577577
}
@@ -593,7 +593,7 @@ func TestBrokerHasWaiterAndTimedOut(t *testing.T) {
593593

594594
done := make(chan struct{})
595595
go func() {
596-
_, _ = broker.Request("test.com", 443, 50*time.Millisecond)
596+
_, _ = broker.Request("test.com", 443, "", 50*time.Millisecond)
597597
close(done)
598598
}()
599599

@@ -647,7 +647,7 @@ func TestBrokerEmptyChannelSlice(t *testing.T) {
647647
// Empty slice (not nil) should behave the same as nil channels.
648648
broker := NewBroker([]Channel{})
649649

650-
resp, err := broker.Request("empty-slice.com", 443, 50*time.Millisecond)
650+
resp, err := broker.Request("empty-slice.com", 443, "", 50*time.Millisecond)
651651
if err == nil {
652652
t.Fatal("expected timeout error with empty channel slice")
653653
}
@@ -684,7 +684,7 @@ func TestBrokerChannelPanicRecovery(t *testing.T) {
684684
broker = NewBroker([]Channel{panicCh, goodCh})
685685

686686
// The panicking channel should not prevent the good channel from resolving.
687-
resp, err := broker.Request("panic-test.com", 443, 5*time.Second)
687+
resp, err := broker.Request("panic-test.com", 443, "", 5*time.Second)
688688
if err != nil {
689689
t.Fatalf("request: %v", err)
690690
}
@@ -705,7 +705,7 @@ func TestBrokerAllChannelsPanic(t *testing.T) {
705705
broker := NewBroker([]Channel{panicCh1, panicCh2})
706706

707707
// With all channels panicking, the request should time out.
708-
resp, err := broker.Request("all-panic.com", 443, 50*time.Millisecond)
708+
resp, err := broker.Request("all-panic.com", 443, "", 50*time.Millisecond)
709709
if err == nil {
710710
t.Fatal("expected timeout error when all channels panic")
711711
}
@@ -720,7 +720,7 @@ func TestBrokerPendingRequests(t *testing.T) {
720720

721721
// Start a request that blocks.
722722
go func() {
723-
_, _ = broker.Request("pending-test.com", 443, 5*time.Second)
723+
_, _ = broker.Request("pending-test.com", 443, "", 5*time.Second)
724724
}()
725725

726726
// Wait for it to register.
@@ -780,7 +780,7 @@ func TestBrokerChannelErrorDoesNotBlockOthers(t *testing.T) {
780780

781781
broker = NewBroker([]Channel{ch1, ch2})
782782

783-
resp, err := broker.Request("test.com", 443, 5*time.Second)
783+
resp, err := broker.Request("test.com", 443, "", 5*time.Second)
784784
if err != nil {
785785
t.Fatalf("request: %v", err)
786786
}

internal/channel/http/http.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ type WebhookPayload struct {
2929
Type string `json:"type"`
3030
Destination string `json:"destination"`
3131
Port int `json:"port"`
32-
Tool string `json:"tool"`
32+
Protocol string `json:"protocol,omitempty"`
33+
Tool string `json:"tool,omitempty"`
34+
ToolArgs string `json:"tool_args,omitempty"`
3335
Timestamp time.Time `json:"timestamp"`
3436
}
3537

@@ -129,9 +131,13 @@ func (h *HTTPChannel) deliverApproval(req channel.ApprovalRequest) {
129131
Type: "approval",
130132
Destination: req.Destination,
131133
Port: req.Port,
132-
Tool: "",
134+
Protocol: req.Protocol,
135+
ToolArgs: req.ToolArgs,
133136
Timestamp: req.CreatedAt,
134137
}
138+
if req.Protocol == "mcp" {
139+
payload.Tool = req.Destination
140+
}
135141

136142
body, err := json.Marshal(payload)
137143
if err != nil {

0 commit comments

Comments
 (0)