@@ -19,6 +19,7 @@ import (
1919 "github.com/wavetermdev/waveterm/pkg/baseds"
2020 "github.com/wavetermdev/waveterm/pkg/panichandler"
2121 "github.com/wavetermdev/waveterm/pkg/remote/fileshare/wshfs"
22+ "github.com/wavetermdev/waveterm/pkg/util/envutil"
2223 "github.com/wavetermdev/waveterm/pkg/util/packetparser"
2324 "github.com/wavetermdev/waveterm/pkg/util/sigutil"
2425 "github.com/wavetermdev/waveterm/pkg/wavebase"
@@ -42,6 +43,7 @@ var connServerRouterDomainSocket bool
4243var connServerConnName string
4344var connServerDev bool
4445var ConnServerWshRouter * wshutil.WshRouter
46+ var connServerInitialEnv map [string ]string
4547
4648func init () {
4749 serverCmd .Flags ().BoolVar (& connServerRouter , "router" , false , "run in local router mode (stdio upstream)" )
@@ -120,18 +122,18 @@ func runListener(listener net.Listener, router *wshutil.WshRouter) {
120122 }
121123}
122124
123- func setupConnServerRpcClientWithRouter (router * wshutil.WshRouter ) (* wshutil.WshRpc , error ) {
125+ func setupConnServerRpcClientWithRouter (router * wshutil.WshRouter , sockName string ) (* wshutil.WshRpc , error ) {
124126 routeId := wshutil .MakeConnectionRouteId (connServerConnName )
125127 rpcCtx := wshrpc.RpcContext {
126128 RouteId : routeId ,
127129 Conn : connServerConnName ,
128130 }
129-
131+
130132 bareRouteId := wshutil .MakeRandomProcRouteId ()
131133 bareClient := wshutil .MakeWshRpc (wshrpc.RpcContext {}, & wshclient.WshServer {}, bareRouteId )
132134 router .RegisterTrustedLeaf (bareClient , bareRouteId )
133-
134- connServerClient := wshutil .MakeWshRpc (rpcCtx , wshremote .MakeRemoteRpcServerImpl (os .Stdout , router , bareClient , false ), routeId )
135+
136+ connServerClient := wshutil .MakeWshRpc (rpcCtx , wshremote .MakeRemoteRpcServerImpl (os .Stdout , router , bareClient , false , connServerInitialEnv , sockName ), routeId )
135137 router .RegisterTrustedLeaf (connServerClient , routeId )
136138 return connServerClient , nil
137139}
@@ -170,8 +172,10 @@ func serverRunRouter() error {
170172 }()
171173 router .RegisterUpstream (termProxy )
172174
175+ sockName := getRemoteDomainSocketName ()
176+
173177 // setup the connserver rpc client first
174- client , err := setupConnServerRpcClientWithRouter (router )
178+ client , err := setupConnServerRpcClientWithRouter (router , sockName )
175179 if err != nil {
176180 return fmt .Errorf ("error setting up connserver rpc client: %v" , err )
177181 }
@@ -267,23 +271,19 @@ func serverRunRouterDomainSocket(jwtToken string) error {
267271 // register the domain socket connection as upstream
268272 router .RegisterUpstream (upstreamProxy )
269273
270- // setup the connserver rpc client (leaf)
271- client , err := setupConnServerRpcClientWithRouter (router )
272- if err != nil {
273- return fmt .Errorf ("error setting up connserver rpc client: %v" , err )
274- }
275- wshfs .RpcClient = client
274+ // use the router's control RPC to authenticate with upstream
275+ controlRpc := router .GetControlRpc ()
276276
277277 // authenticate with the upstream router using the JWT
278- _ , err = wshclient .AuthenticateCommand (client , jwtToken , & wshrpc.RpcOpts {Route : wshutil .ControlRoute })
278+ _ , err = wshclient .AuthenticateCommand (controlRpc , jwtToken , & wshrpc.RpcOpts {Route : wshutil .ControlRootRoute })
279279 if err != nil {
280280 return fmt .Errorf ("error authenticating with upstream: %v" , err )
281281 }
282282 log .Printf ("authenticated with upstream router" )
283283
284284 // fetch and set JWT public key
285285 log .Printf ("trying to get JWT public key" )
286- jwtPublicKeyB64 , err := wshclient .GetJwtPublicKeyCommand (client , nil )
286+ jwtPublicKeyB64 , err := wshclient .GetJwtPublicKeyCommand (controlRpc , nil )
287287 if err != nil {
288288 return fmt .Errorf ("error getting jwt public key: %v" , err )
289289 }
@@ -297,6 +297,13 @@ func serverRunRouterDomainSocket(jwtToken string) error {
297297 }
298298 log .Printf ("got JWT public key" )
299299
300+ // now setup the connserver rpc client
301+ client , err := setupConnServerRpcClientWithRouter (router , sockName )
302+ if err != nil {
303+ return fmt .Errorf ("error setting up connserver rpc client: %v" , err )
304+ }
305+ wshfs .RpcClient = client
306+
300307 // set up the local domain socket listener for local wsh commands
301308 unixListener , err := MakeRemoteUnixListener ()
302309 if err != nil {
@@ -323,7 +330,11 @@ func serverRunRouterDomainSocket(jwtToken string) error {
323330}
324331
325332func serverRunNormal (jwtToken string ) error {
326- err := setupRpcClient (wshremote .MakeRemoteRpcServerImpl (os .Stdout , nil , nil , false ), jwtToken )
333+ sockName , err := wshutil .ExtractUnverifiedSocketName (jwtToken )
334+ if err != nil {
335+ return fmt .Errorf ("error extracting socket name from JWT: %v" , err )
336+ }
337+ err = setupRpcClient (wshremote .MakeRemoteRpcServerImpl (os .Stdout , nil , nil , false , connServerInitialEnv , sockName ), jwtToken )
327338 if err != nil {
328339 return err
329340 }
@@ -359,6 +370,8 @@ func askForJwtToken() (string, error) {
359370}
360371
361372func serverRun (cmd * cobra.Command , args []string ) error {
373+ connServerInitialEnv = envutil .PruneInitialEnv (envutil .SliceToMap (os .Environ ()))
374+
362375 var logFile * os.File
363376 if connServerDev {
364377 var err error
0 commit comments