22package cmd
33
44import (
5+ "bytes"
56 "context"
67 "crypto/rand"
78 "encoding/base64"
@@ -44,7 +45,7 @@ func Tunnel(args []string) {
4445 opts , err := parseTunnelArgs (args )
4546 if err != nil {
4647 fmt .Fprintln (os .Stderr , "handoff tunnel:" , err )
47- fmt .Fprintln (os .Stderr , "usage: handoff tunnel <connect-token> [--local-port PORT] [--relay URL]" )
48+ fmt .Fprintln (os .Stderr , "usage: handoff tunnel <connect-token> [--local-port PORT] [--relay URL] [--http-host HOST] " )
4849 os .Exit (2 )
4950 }
5051
@@ -88,6 +89,13 @@ func Tunnel(args []string) {
8889 }
8990 fmt .Printf ("tunnel ready -- forwarding 127.0.0.1:%d -> host %s:%d\n " ,
9091 opts .localPort , ready .HostAddr , ready .HostPort )
92+ httpHost := opts .httpHost
93+ if httpHost == "" {
94+ httpHost = defaultTunnelHTTPHost (ready .HostAddr , ready .HostPort )
95+ }
96+ if httpHost != "" {
97+ fmt .Printf ("HTTP Host headers will use %s\n " , httpHost )
98+ }
9199
92100 listener , err := net .Listen ("tcp" , net .JoinHostPort ("127.0.0.1" , strconv .Itoa (opts .localPort )))
93101 if err != nil {
@@ -97,7 +105,7 @@ func Tunnel(args []string) {
97105 defer listener .Close ()
98106 fmt .Println ("press Ctrl+C to close the tunnel." )
99107
100- client := newTunnelClient (conn )
108+ client := newTunnelClient (conn , httpHost )
101109 defer client .shutdown ("operator close" )
102110
103111 // Lifecycle hook for foreground-window policy. Currently a no-op.
@@ -113,6 +121,7 @@ type tunnelOptions struct {
113121 token string
114122 localPort int
115123 relay string
124+ httpHost string
116125}
117126
118127func parseTunnelArgs (args []string ) (tunnelOptions , error ) {
@@ -137,6 +146,16 @@ func parseTunnelArgs(args []string) (tunnelOptions, error) {
137146 }
138147 i ++
139148 opts .relay = args [i ]
149+ case a == "--http-host" || a == "--host-header" :
150+ if i + 1 >= len (args ) {
151+ return opts , errors .New ("--http-host requires a value" )
152+ }
153+ i ++
154+ host , err := cleanHTTPHost (args [i ])
155+ if err != nil {
156+ return opts , err
157+ }
158+ opts .httpHost = host
140159 case strings .HasPrefix (a , "--local-port=" ):
141160 port , err := strconv .Atoi (strings .TrimPrefix (a , "--local-port=" ))
142161 if err != nil || port <= 0 || port > 65535 {
@@ -145,6 +164,18 @@ func parseTunnelArgs(args []string) (tunnelOptions, error) {
145164 opts .localPort = port
146165 case strings .HasPrefix (a , "--relay=" ):
147166 opts .relay = strings .TrimPrefix (a , "--relay=" )
167+ case strings .HasPrefix (a , "--http-host=" ):
168+ host , err := cleanHTTPHost (strings .TrimPrefix (a , "--http-host=" ))
169+ if err != nil {
170+ return opts , err
171+ }
172+ opts .httpHost = host
173+ case strings .HasPrefix (a , "--host-header=" ):
174+ host , err := cleanHTTPHost (strings .TrimPrefix (a , "--host-header=" ))
175+ if err != nil {
176+ return opts , err
177+ }
178+ opts .httpHost = host
148179 case strings .HasPrefix (a , "-" ):
149180 return opts , fmt .Errorf ("unknown flag %q" , a )
150181 default :
@@ -167,6 +198,17 @@ func parseTunnelArgs(args []string) (tunnelOptions, error) {
167198 return opts , nil
168199}
169200
201+ func cleanHTTPHost (host string ) (string , error ) {
202+ host = strings .TrimSpace (host )
203+ if host == "" {
204+ return "" , errors .New ("--http-host requires a non-empty value" )
205+ }
206+ if strings .ContainsAny (host , "\r \n " ) {
207+ return "" , errors .New ("--http-host must not contain newlines" )
208+ }
209+ return host , nil
210+ }
211+
170212func tunnelWsURL (relayBase , token string ) (string , error ) {
171213 if isTunnelURL (token ) {
172214 u , err := url .Parse (token )
@@ -249,7 +291,8 @@ func awaitTunnelReady(ctx context.Context, conn *websocket.Conn) (*tunnelReady,
249291}
250292
251293type tunnelClient struct {
252- conn * websocket.Conn
294+ conn * websocket.Conn
295+ httpHost string
253296
254297 writeMu sync.Mutex
255298
@@ -258,8 +301,8 @@ type tunnelClient struct {
258301 closed bool
259302}
260303
261- func newTunnelClient (conn * websocket.Conn ) * tunnelClient {
262- return & tunnelClient {conn : conn , streams : map [string ]net.Conn {}}
304+ func newTunnelClient (conn * websocket.Conn , httpHost string ) * tunnelClient {
305+ return & tunnelClient {conn : conn , httpHost : httpHost , streams : map [string ]net.Conn {}}
263306}
264307
265308func (t * tunnelClient ) acceptLoop (ctx context.Context , listener net.Listener ) {
@@ -319,6 +362,8 @@ func (t *tunnelClient) readLoop(ctx context.Context) error {
319362
320363func (t * tunnelClient ) copyToRelay (ctx context.Context , streamID string , conn net.Conn ) {
321364 buf := make ([]byte , 16 * 1024 )
365+ rewriteHTTPHost := t .httpHost != ""
366+ var pending []byte
322367 defer t .dropStream (streamID , "local end of stream" )
323368 for {
324369 if ctx .Err () != nil {
@@ -328,6 +373,20 @@ func (t *tunnelClient) copyToRelay(ctx context.Context, streamID string, conn ne
328373 if n > 0 {
329374 chunk := make ([]byte , n )
330375 copy (chunk , buf [:n ])
376+ if rewriteHTTPHost {
377+ pending = append (pending , chunk ... )
378+ rewritten , ready := rewriteHTTPHostHeader (pending , t .httpHost )
379+ if ! ready && err == nil {
380+ continue
381+ }
382+ if ready {
383+ chunk = rewritten
384+ } else {
385+ chunk = pending
386+ }
387+ pending = nil
388+ rewriteHTTPHost = false
389+ }
331390 frame := tunnelFrame {
332391 Kind : "data" ,
333392 StreamID : streamID ,
@@ -341,12 +400,135 @@ func (t *tunnelClient) copyToRelay(ctx context.Context, streamID string, conn ne
341400 if ! errors .Is (err , io .EOF ) {
342401 supportlog .Printf ("tunnel-client local read error stream=%s: %v" , streamID , err )
343402 }
403+ if rewriteHTTPHost && len (pending ) > 0 {
404+ _ = t .sendFrame (ctx , tunnelFrame {
405+ Kind : "data" ,
406+ StreamID : streamID ,
407+ DataB64 : base64 .StdEncoding .EncodeToString (pending ),
408+ })
409+ pending = nil
410+ rewriteHTTPHost = false
411+ }
344412 _ = t .sendFrame (ctx , tunnelFrame {Kind : "stream_close" , StreamID : streamID , Reason : "eof" })
345413 return
346414 }
347415 }
348416}
349417
418+ const httpHeaderRewriteLimit = 64 * 1024
419+
420+ var httpMethodPrefixes = [][]byte {
421+ []byte ("GET " ),
422+ []byte ("POST " ),
423+ []byte ("HEAD " ),
424+ []byte ("PUT " ),
425+ []byte ("DELETE " ),
426+ []byte ("OPTIONS " ),
427+ []byte ("PATCH " ),
428+ []byte ("TRACE " ),
429+ }
430+
431+ func defaultTunnelHTTPHost (host string , port int ) string {
432+ host = strings .TrimSpace (host )
433+ if host == "" || isLoopbackTunnelHost (host ) {
434+ return ""
435+ }
436+ if port <= 0 || port == 80 {
437+ cleaned , err := cleanHTTPHost (host )
438+ if err != nil {
439+ return ""
440+ }
441+ return cleaned
442+ }
443+ hostPort := host + ":" + strconv .Itoa (port )
444+ if ip := net .ParseIP (host ); ip != nil && strings .Contains (host , ":" ) {
445+ hostPort = net .JoinHostPort (host , strconv .Itoa (port ))
446+ }
447+ cleaned , err := cleanHTTPHost (hostPort )
448+ if err != nil {
449+ return ""
450+ }
451+ return cleaned
452+ }
453+
454+ func isLoopbackTunnelHost (host string ) bool {
455+ if strings .EqualFold (host , "localhost" ) {
456+ return true
457+ }
458+ ip := net .ParseIP (host )
459+ return ip != nil && ip .IsLoopback ()
460+ }
461+
462+ func rewriteHTTPHostHeader (data []byte , host string ) ([]byte , bool ) {
463+ if host == "" {
464+ return data , true
465+ }
466+ if ! hasHTTPMethodPrefix (data ) {
467+ if mayBecomeHTTPMethod (data ) {
468+ return nil , false
469+ }
470+ return data , true
471+ }
472+ headerEnd := bytes .Index (data , []byte ("\r \n \r \n " ))
473+ if headerEnd < 0 {
474+ if len (data ) < httpHeaderRewriteLimit {
475+ return nil , false
476+ }
477+ return data , true
478+ }
479+
480+ header := data [:headerEnd ]
481+ body := data [headerEnd + 4 :]
482+ lines := bytes .Split (header , []byte ("\r \n " ))
483+ if len (lines ) == 0 {
484+ return data , true
485+ }
486+
487+ hostLine := []byte ("Host: " + host )
488+ out := make ([][]byte , 0 , len (lines )+ 1 )
489+ out = append (out , lines [0 ])
490+ replaced := false
491+ for _ , line := range lines [1 :] {
492+ if bytes .HasPrefix (bytes .ToLower (line ), []byte ("host:" )) {
493+ if ! replaced {
494+ out = append (out , hostLine )
495+ replaced = true
496+ }
497+ continue
498+ }
499+ out = append (out , line )
500+ }
501+ if ! replaced {
502+ out = append (out [:1 ], append ([][]byte {hostLine }, out [1 :]... )... )
503+ }
504+
505+ rewritten := bytes .Join (out , []byte ("\r \n " ))
506+ rewritten = append (rewritten , []byte ("\r \n \r \n " )... )
507+ rewritten = append (rewritten , body ... )
508+ return rewritten , true
509+ }
510+
511+ func hasHTTPMethodPrefix (data []byte ) bool {
512+ for _ , method := range httpMethodPrefixes {
513+ if bytes .HasPrefix (data , method ) {
514+ return true
515+ }
516+ }
517+ return false
518+ }
519+
520+ func mayBecomeHTTPMethod (data []byte ) bool {
521+ if len (data ) == 0 || len (data ) > len ("OPTIONS " ) {
522+ return false
523+ }
524+ for _ , method := range httpMethodPrefixes {
525+ if len (data ) <= len (method ) && bytes .HasPrefix (method , data ) {
526+ return true
527+ }
528+ }
529+ return false
530+ }
531+
350532func (t * tunnelClient ) sendFrame (ctx context.Context , frame tunnelFrame ) error {
351533 data , err := json .Marshal (frame )
352534 if err != nil {
0 commit comments