@@ -9,8 +9,10 @@ import (
99 "context"
1010 "fmt"
1111 "io"
12+ "net"
1213 "net/http"
1314 "net/http/httptest"
15+ "strings"
1416 "sync/atomic"
1517 "testing"
1618
@@ -221,3 +223,174 @@ func TestSSE405AllowHeader(t *testing.T) {
221223 })
222224 }
223225}
226+
227+ // TestSSELocalhostProtection verifies that DNS rebinding protection
228+ // is automatically enabled for localhost servers.
229+ func TestSSELocalhostProtection (t * testing.T ) {
230+ server := NewServer (testImpl , nil )
231+
232+ tests := []struct {
233+ name string
234+ listenAddr string
235+ hostHeader string
236+ disableProtection bool
237+ wantStatus int
238+ }{
239+ {
240+ name : "127.0.0.1 accepts 127.0.0.1" ,
241+ listenAddr : "127.0.0.1:0" ,
242+ hostHeader : "127.0.0.1:1234" ,
243+ wantStatus : http .StatusOK ,
244+ },
245+ {
246+ name : "127.0.0.1 accepts localhost" ,
247+ listenAddr : "127.0.0.1:0" ,
248+ hostHeader : "localhost:1234" ,
249+ wantStatus : http .StatusOK ,
250+ },
251+ {
252+ name : "127.0.0.1 rejects evil.com" ,
253+ listenAddr : "127.0.0.1:0" ,
254+ hostHeader : "evil.com" ,
255+ wantStatus : http .StatusForbidden ,
256+ },
257+ {
258+ name : "127.0.0.1 rejects evil.com:80" ,
259+ listenAddr : "127.0.0.1:0" ,
260+ hostHeader : "evil.com:80" ,
261+ wantStatus : http .StatusForbidden ,
262+ },
263+ {
264+ name : "127.0.0.1 rejects localhost.evil.com" ,
265+ listenAddr : "127.0.0.1:0" ,
266+ hostHeader : "localhost.evil.com" ,
267+ wantStatus : http .StatusForbidden ,
268+ },
269+ {
270+ name : "0.0.0.0 via localhost rejects evil.com" ,
271+ listenAddr : "0.0.0.0:0" ,
272+ hostHeader : "evil.com" ,
273+ wantStatus : http .StatusForbidden ,
274+ },
275+ {
276+ name : "disabled accepts evil.com" ,
277+ listenAddr : "127.0.0.1:0" ,
278+ hostHeader : "evil.com" ,
279+ disableProtection : true ,
280+ wantStatus : http .StatusOK ,
281+ },
282+ }
283+
284+ for _ , tt := range tests {
285+ t .Run (tt .name , func (t * testing.T ) {
286+ opts := & SSEOptions {
287+ DisableLocalhostProtection : tt .disableProtection ,
288+ }
289+ handler := NewSSEHandler (func (req * http.Request ) * Server { return server }, opts )
290+
291+ listener , err := net .Listen ("tcp" , tt .listenAddr )
292+ if err != nil {
293+ t .Fatalf ("Failed to listen on %s: %v" , tt .listenAddr , err )
294+ }
295+ defer listener .Close ()
296+
297+ srv := & http.Server {Handler : handler }
298+ go srv .Serve (listener )
299+ defer srv .Close ()
300+
301+ // Use a GET request since it's the entry point for SSE sessions.
302+ // For accepted requests, the response will be a hanging SSE stream,
303+ // but we only need to check the initial status code.
304+ req , err := http .NewRequest ("GET" , fmt .Sprintf ("http://%s" , listener .Addr ().String ()), nil )
305+ if err != nil {
306+ t .Fatal (err )
307+ }
308+ req .Host = tt .hostHeader
309+ req .Header .Set ("Accept" , "text/event-stream" )
310+
311+ resp , err := http .DefaultClient .Do (req )
312+ if err != nil {
313+ t .Fatal (err )
314+ }
315+ defer resp .Body .Close ()
316+
317+ if got := resp .StatusCode ; got != tt .wantStatus {
318+ t .Errorf ("Status code: got %d, want %d" , got , tt .wantStatus )
319+ }
320+ })
321+ }
322+ }
323+
324+ func TestSSEOriginProtection (t * testing.T ) {
325+ server := NewServer (testImpl , nil )
326+
327+ tests := []struct {
328+ name string
329+ protection * http.CrossOriginProtection
330+ requestOrigin string
331+ wantStatusCode int
332+ }{
333+ {
334+ name : "default protection with Origin header" ,
335+ protection : nil ,
336+ requestOrigin : "https://example.com" ,
337+ wantStatusCode : http .StatusForbidden ,
338+ },
339+ {
340+ name : "custom protection with trusted origin and same Origin" ,
341+ protection : func () * http.CrossOriginProtection {
342+ p := http .NewCrossOriginProtection ()
343+ if err := p .AddTrustedOrigin ("https://example.com" ); err != nil {
344+ t .Fatal (err )
345+ }
346+ return p
347+ }(),
348+ requestOrigin : "https://example.com" ,
349+ wantStatusCode : http .StatusNotFound , // origin accepted; session not found
350+ },
351+ {
352+ name : "custom protection with trusted origin and different Origin" ,
353+ protection : func () * http.CrossOriginProtection {
354+ p := http .NewCrossOriginProtection ()
355+ if err := p .AddTrustedOrigin ("https://example.com" ); err != nil {
356+ t .Fatal (err )
357+ }
358+ return p
359+ }(),
360+ requestOrigin : "https://malicious.com" ,
361+ wantStatusCode : http .StatusForbidden ,
362+ },
363+ }
364+
365+ for _ , tt := range tests {
366+ t .Run (tt .name , func (t * testing.T ) {
367+ opts := & SSEOptions {
368+ CrossOriginProtection : tt .protection ,
369+ }
370+ handler := NewSSEHandler (func (req * http.Request ) * Server { return server }, opts )
371+ httpServer := httptest .NewServer (handler )
372+ defer httpServer .Close ()
373+
374+ // Use POST with a valid session-like URL to test origin protection
375+ // without creating a hanging GET connection.
376+ reqReader := strings .NewReader (`{"jsonrpc":"2.0","id":1,"method":"ping"}` )
377+ req , err := http .NewRequest (http .MethodPost , httpServer .URL + "?sessionid=nonexistent" , reqReader )
378+ if err != nil {
379+ t .Fatal (err )
380+ }
381+ req .Header .Set ("Content-Type" , "application/json" )
382+ req .Header .Set ("Origin" , tt .requestOrigin )
383+
384+ resp , err := http .DefaultClient .Do (req )
385+ if err != nil {
386+ t .Fatal (err )
387+ }
388+ defer resp .Body .Close ()
389+
390+ if got := resp .StatusCode ; got != tt .wantStatusCode {
391+ body , _ := io .ReadAll (resp .Body )
392+ t .Errorf ("Status code: got %d, want %d (body: %s)" , got , tt .wantStatusCode , body )
393+ }
394+ })
395+ }
396+ }
0 commit comments