44 "context"
55 "fmt"
66 "net/http"
7+ "net/url"
78 "strings"
89
910 "github.com/gorilla/websocket"
@@ -16,8 +17,7 @@ import (
1617)
1718
1819var wsUpgrader = websocket.Upgrader {
19- // TODO: validate Origin header for production CSRF protection.
20- CheckOrigin : func (r * http.Request ) bool { return true },
20+ CheckOrigin : allowWebSocketOrigin ,
2121}
2222
2323var dockerBridgeIPv4Lookup = netutil .LookupInterfaceIPv4
@@ -69,6 +69,83 @@ func wsTokenAuth() *hook.Handler[*core.RequestEvent] {
6969 }
7070}
7171
72+ func allowWebSocketOrigin (r * http.Request ) bool {
73+ origin := strings .TrimSpace (r .Header .Get ("Origin" ))
74+ if origin == "" {
75+ return true
76+ }
77+ parsed , err := url .Parse (origin )
78+ if err != nil || parsed .Scheme == "" || parsed .Host == "" {
79+ return false
80+ }
81+ requestScheme := resolveWebSocketHTTPScheme (r )
82+ if ! strings .EqualFold (parsed .Scheme , requestScheme ) {
83+ return false
84+ }
85+ return sameWebSocketOriginHost (parsed .Host , resolveWebSocketHTTPHost (r ), requestScheme )
86+ }
87+
88+ func resolveWebSocketHTTPScheme (r * http.Request ) string {
89+ if strings .EqualFold (strings .TrimSpace (r .Header .Get ("X-Forwarded-Proto" )), "https" ) || r .TLS != nil {
90+ return "https"
91+ }
92+ return "http"
93+ }
94+
95+ func resolveWebSocketHTTPHost (r * http.Request ) string {
96+ host := firstForwardedHostValue (r .Host )
97+ forwardedHost := firstForwardedHostValue (r .Header .Get ("X-Forwarded-Host" ))
98+ if host == "" {
99+ host = forwardedHost
100+ }
101+ if forwardedHost != "" && forwardedHostCarriesPort (host , forwardedHost ) {
102+ host = forwardedHost
103+ }
104+ if ! hostHasExplicitPort (host ) {
105+ if forwardedPort := firstForwardedPortValue (r .Header .Get ("X-Forwarded-Port" )); forwardedPort != "" {
106+ host = appendPortIfMissing (host , forwardedPort )
107+ }
108+ }
109+ return host
110+ }
111+
112+ func sameWebSocketOriginHost (originHost string , requestHost string , scheme string ) bool {
113+ if ! strings .EqualFold (stripOptionalPort (originHost ), stripOptionalPort (requestHost )) {
114+ return false
115+ }
116+ return effectivePort (originHost , scheme ) == effectivePort (requestHost , scheme )
117+ }
118+
119+ func effectivePort (host string , scheme string ) string {
120+ if host == "" {
121+ return defaultPortForScheme (scheme )
122+ }
123+ if strings .HasPrefix (host , "[" ) {
124+ if idx := strings .LastIndex (host , "]:" ); idx >= 0 {
125+ return host [idx + 2 :]
126+ }
127+ return defaultPortForScheme (scheme )
128+ }
129+ idx := strings .LastIndex (host , ":" )
130+ if idx <= 0 || strings .Contains (host [:idx ], ":" ) {
131+ return defaultPortForScheme (scheme )
132+ }
133+ port := host [idx + 1 :]
134+ for _ , ch := range port {
135+ if ch < '0' || ch > '9' {
136+ return defaultPortForScheme (scheme )
137+ }
138+ }
139+ return port
140+ }
141+
142+ func defaultPortForScheme (scheme string ) string {
143+ if strings .EqualFold (strings .TrimSpace (scheme ), "https" ) {
144+ return "443"
145+ }
146+ return "80"
147+ }
148+
72149// registerServerRoutes registers server catalog/ops routes (non-terminal).
73150// These handle connectivity checks, power, ports, and systemd operations.
74151func registerServerRoutes (g * router.RouterGroup [* core.RequestEvent ]) {
0 commit comments