Skip to content

Commit cfbb39a

Browse files
authored
some RPC client/server cleanup (#2709)
* announce and unannounce should propagate * route the sendmessage command * add timeouts for sending to OutputCh (appropriately) * add ctx for SendCancel function
1 parent 5549d43 commit cfbb39a

File tree

5 files changed

+86
-48
lines changed

5 files changed

+86
-48
lines changed

pkg/web/web.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package web
55

66
import (
77
"bytes"
8+
"context"
89
"encoding/base64"
910
"encoding/json"
1011
"fmt"
@@ -254,7 +255,7 @@ func handleRemoteStreamFile(w http.ResponseWriter, req *http.Request, conn strin
254255
return handleRemoteStreamFileFromCh(w, req, path, rtnCh, rpcOpts.StreamCancelFn, no404)
255256
}
256257

257-
func handleRemoteStreamFileFromCh(w http.ResponseWriter, req *http.Request, path string, rtnCh <-chan wshrpc.RespOrErrorUnion[wshrpc.FileData], streamCancelFn func(), no404 bool) error {
258+
func handleRemoteStreamFileFromCh(w http.ResponseWriter, req *http.Request, path string, rtnCh <-chan wshrpc.RespOrErrorUnion[wshrpc.FileData], streamCancelFn func(context.Context) error, no404 bool) error {
258259
firstPk := true
259260
var fileInfo *wshrpc.FileInfo
260261
loopDone := false
@@ -270,7 +271,9 @@ func handleRemoteStreamFileFromCh(w http.ResponseWriter, req *http.Request, path
270271
select {
271272
case <-ctx.Done():
272273
if streamCancelFn != nil {
273-
streamCancelFn()
274+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
275+
defer cancel()
276+
streamCancelFn(ctx)
274277
}
275278
return ctx.Err()
276279
case respUnion, ok := <-rtnCh:

pkg/wshrpc/wshclient/wshclientutil.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package wshclient
55

66
import (
7+
"context"
78
"errors"
89

910
"github.com/wavetermdev/waveterm/pkg/panichandler"
@@ -62,9 +63,8 @@ func sendRpcRequestResponseStreamHelper[T any](w *wshutil.WshRpc, command string
6263
rtnErr(respChan, err)
6364
return respChan
6465
}
65-
opts.StreamCancelFn = func() {
66-
// TODO coordinate the cancel with the for loop below
67-
reqHandler.SendCancel()
66+
opts.StreamCancelFn = func(ctx context.Context) error {
67+
return reqHandler.SendCancel(ctx)
6868
}
6969
go func() {
7070
defer func() {

pkg/wshutil/wshadapter.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ func serverImplAdapter(impl any) func(*RpcResponseHandler) bool {
9696
}
9797
rmethod := findCmdMethod(impl, cmd)
9898
if rmethod == nil {
99-
if !handler.NeedsResponse() {
99+
if !handler.NeedsResponse() && cmd != wshrpc.Command_Message {
100100
// we also send an out of band message here since this is likely unexpected and will require debugging
101101
handler.SendMessage(fmt.Sprintf("command %q method %q not found", handler.GetCommand(), methodDecl.MethodName))
102102
}

pkg/wshutil/wshrouter.go

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -169,26 +169,26 @@ func (router *WshRouter) getRouteInfo(rpcId string) *routeInfo {
169169
}
170170

171171
func (router *WshRouter) handleAnnounceMessage(msg RpcMessage, input msgAndRoute) {
172-
// if we have an upstream, send it there
173-
// if we don't (we are the terminal router), then add it to our announced route map
172+
if msg.Source != input.fromRouteId {
173+
router.Lock.Lock()
174+
router.AnnouncedRoutes[msg.Source] = input.fromRouteId
175+
router.Lock.Unlock()
176+
}
174177
upstream := router.GetUpstreamClient()
175178
if upstream != nil {
176179
upstream.SendRpcMessage(input.msgBytes, "announce-upstream")
177-
return
178180
}
179-
if msg.Source == input.fromRouteId {
180-
// not necessary to save the id mapping
181-
return
182-
}
183-
router.Lock.Lock()
184-
defer router.Lock.Unlock()
185-
router.AnnouncedRoutes[msg.Source] = input.fromRouteId
186181
}
187182

188-
func (router *WshRouter) handleUnannounceMessage(msg RpcMessage) {
183+
func (router *WshRouter) handleUnannounceMessage(msg RpcMessage, input msgAndRoute) {
189184
router.Lock.Lock()
190-
defer router.Lock.Unlock()
191185
delete(router.AnnouncedRoutes, msg.Source)
186+
router.Lock.Unlock()
187+
188+
upstream := router.GetUpstreamClient()
189+
if upstream != nil {
190+
upstream.SendRpcMessage(input.msgBytes, "unannounce-upstream")
191+
}
192192
}
193193

194194
func (router *WshRouter) getAnnouncedRoute(routeId string) string {
@@ -204,21 +204,21 @@ func (router *WshRouter) sendRoutedMessage(msgBytes []byte, routeId string) bool
204204
rpc.SendRpcMessage(msgBytes, "route")
205205
return true
206206
}
207+
localRouteId := router.getAnnouncedRoute(routeId)
208+
if localRouteId != "" {
209+
rpc := router.GetRpc(localRouteId)
210+
if rpc != nil {
211+
rpc.SendRpcMessage(msgBytes, "route-local")
212+
return true
213+
}
214+
}
207215
upstream := router.GetUpstreamClient()
208216
if upstream != nil {
209217
upstream.SendRpcMessage(msgBytes, "route-upstream")
210218
return true
211-
} else {
212-
// we are the upstream, so consult our announced routes map
213-
localRouteId := router.getAnnouncedRoute(routeId)
214-
rpc := router.GetRpc(localRouteId)
215-
if rpc == nil {
216-
log.Printf("[router] no rpc for route id %q\n", routeId)
217-
return false
218-
}
219-
rpc.SendRpcMessage(msgBytes, "route-local")
220-
return true
221219
}
220+
log.Printf("[router] no rpc for route id %q\n", routeId)
221+
return false
222222
}
223223

224224
func (router *WshRouter) runServer() {
@@ -236,7 +236,7 @@ func (router *WshRouter) runServer() {
236236
continue
237237
}
238238
if msg.Command == wshrpc.Command_RouteUnannounce {
239-
router.handleUnannounceMessage(msg)
239+
router.handleUnannounceMessage(msg, input)
240240
continue
241241
}
242242
if msg.Command != "" {
@@ -353,14 +353,22 @@ func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient, sh
353353
func (router *WshRouter) UnregisterRoute(routeId string) {
354354
log.Printf("[router] unregistering wsh route %q\n", routeId)
355355
router.Lock.Lock()
356-
defer router.Lock.Unlock()
357356
delete(router.RouteMap, routeId)
358357
// clear out announced routes
359-
for routeId, localRouteId := range router.AnnouncedRoutes {
358+
for announcedRouteId, localRouteId := range router.AnnouncedRoutes {
360359
if localRouteId == routeId {
361-
delete(router.AnnouncedRoutes, routeId)
360+
delete(router.AnnouncedRoutes, announcedRouteId)
362361
}
363362
}
363+
upstream := router.UpstreamClient
364+
router.Lock.Unlock()
365+
366+
if upstream != nil {
367+
unannounceMsg := RpcMessage{Command: wshrpc.Command_RouteUnannounce, Source: routeId}
368+
unannounceBytes, _ := json.Marshal(unannounceMsg)
369+
upstream.SendRpcMessage(unannounceBytes, "route-unannounce")
370+
}
371+
364372
go func() {
365373
defer func() {
366374
panichandler.PanicHandler("WshRouter:unregisterRoute:routegone", recover())

pkg/wshutil/wshrpc.go

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,9 @@ func (w *WshRpc) handleRequestInternal(req *RpcMessage, pprofCtx context.Context
313313
}
314314
respHandler.contextCancelFn.Store(&cancelFn)
315315
respHandler.ctx = withRespHandler(ctx, respHandler)
316-
w.registerResponseHandler(req.ReqId, respHandler)
316+
if req.ReqId != "" {
317+
w.registerResponseHandler(req.ReqId, respHandler)
318+
}
317319
isAsync := false
318320
defer func() {
319321
panicErr := panichandler.PanicHandler("handleRequest", recover())
@@ -502,7 +504,7 @@ func (handler *RpcRequestHandler) Context() context.Context {
502504
return handler.ctx
503505
}
504506

505-
func (handler *RpcRequestHandler) SendCancel() {
507+
func (handler *RpcRequestHandler) SendCancel(ctx context.Context) error {
506508
defer func() {
507509
panichandler.PanicHandler("SendCancel", recover())
508510
}()
@@ -512,8 +514,14 @@ func (handler *RpcRequestHandler) SendCancel() {
512514
AuthToken: handler.w.GetAuthToken(),
513515
}
514516
barr, _ := json.Marshal(msg) // will never fail
515-
handler.w.OutputCh <- barr
516-
handler.finalize()
517+
select {
518+
case handler.w.OutputCh <- barr:
519+
handler.finalize()
520+
return nil
521+
case <-ctx.Done():
522+
handler.finalize()
523+
return fmt.Errorf("timeout sending cancel")
524+
}
517525
}
518526

519527
func (handler *RpcRequestHandler) ResponseDone() bool {
@@ -607,24 +615,28 @@ func (handler *RpcResponseHandler) SendMessage(msg string) {
607615
Message: msg,
608616
},
609617
AuthToken: handler.w.GetAuthToken(),
618+
Route: handler.source, // send back to source
610619
}
611620
msgBytes, _ := json.Marshal(rpcMsg) // will never fail
612-
handler.w.OutputCh <- msgBytes
621+
select {
622+
case handler.w.OutputCh <- msgBytes:
623+
case <-handler.ctx.Done():
624+
}
613625
}
614626

615627
func (handler *RpcResponseHandler) SendResponse(data any, done bool) error {
616628
defer func() {
617629
panichandler.PanicHandler("SendResponse", recover())
618630
}()
619-
if handler.reqId == "" {
620-
return nil // no response expected
621-
}
622631
if handler.done.Load() {
623632
return fmt.Errorf("request already done, cannot send additional response")
624633
}
625634
if done {
626635
defer handler.close()
627636
}
637+
if handler.reqId == "" {
638+
return nil
639+
}
628640
msg := &RpcMessage{
629641
ResId: handler.reqId,
630642
Data: data,
@@ -635,25 +647,35 @@ func (handler *RpcResponseHandler) SendResponse(data any, done bool) error {
635647
if err != nil {
636648
return err
637649
}
638-
handler.w.OutputCh <- barr
639-
return nil
650+
select {
651+
case handler.w.OutputCh <- barr:
652+
return nil
653+
case <-handler.ctx.Done():
654+
return fmt.Errorf("timeout sending response")
655+
}
640656
}
641657

642658
func (handler *RpcResponseHandler) SendResponseError(err error) {
643659
defer func() {
644660
panichandler.PanicHandler("SendResponseError", recover())
645661
}()
646-
if handler.reqId == "" || handler.done.Load() {
662+
if handler.done.Load() {
647663
return
648664
}
649665
defer handler.close()
666+
if handler.reqId == "" {
667+
return
668+
}
650669
msg := &RpcMessage{
651670
ResId: handler.reqId,
652671
Error: err.Error(),
653672
AuthToken: handler.w.GetAuthToken(),
654673
}
655674
barr, _ := json.Marshal(msg) // will never fail
656-
handler.w.OutputCh <- barr
675+
select {
676+
case handler.w.OutputCh <- barr:
677+
case <-handler.ctx.Done():
678+
}
657679
}
658680

659681
func (handler *RpcResponseHandler) IsCanceled() bool {
@@ -675,11 +697,11 @@ func (handler *RpcResponseHandler) Finalize() {
675697
if handler.reqId != "" {
676698
handler.w.unregisterResponseHandler(handler.reqId)
677699
}
678-
if handler.reqId == "" || handler.done.Load() {
700+
if handler.done.Load() {
679701
return
680702
}
703+
// SendResponse with done=true will call close() via defer, even when reqId is empty
681704
handler.SendResponse(nil, true)
682-
handler.close()
683705
}
684706

685707
func (handler *RpcResponseHandler) IsDone() bool {
@@ -726,8 +748,13 @@ func (w *WshRpc) SendComplexRequest(command string, data any, opts *wshrpc.RpcOp
726748
return nil, err
727749
}
728750
handler.respCh = w.registerRpc(handler, command, opts.Route, handler.reqId)
729-
w.OutputCh <- barr
730-
return handler, nil
751+
select {
752+
case w.OutputCh <- barr:
753+
return handler, nil
754+
case <-handler.ctx.Done():
755+
handler.finalize()
756+
return nil, fmt.Errorf("timeout sending request")
757+
}
731758
}
732759

733760
func (w *WshRpc) IsServerDone() bool {

0 commit comments

Comments
 (0)