diff --git a/net/firewall/relay.go b/net/firewall/relay.go index 15205e4..f3bfef1 100644 --- a/net/firewall/relay.go +++ b/net/firewall/relay.go @@ -17,6 +17,12 @@ import ( "golang.org/x/sync/errgroup" ) +// maxFrameSize bounds the length prefix the relay will accept from either +// peer. Legitimate traffic at the topology MTU sits well under 64 KiB — +// this cap catches corrupt or malicious length prefixes that would +// otherwise trigger multi-GiB allocations in the subsequent make/ReadFull. +const maxFrameSize = 65535 + // InterceptAction tells the relay what to do with a DNS frame. type InterceptAction uint8 @@ -132,6 +138,11 @@ func (r *Relay) forward(ctx context.Context, src, dst net.Conn, dir Direction) e if frameLen == 0 { continue } + if frameLen > maxFrameSize { + return r.wrapError(ctx, fmt.Errorf( + "frame length %d exceeds maximum %d: peer protocol violation", + frameLen, maxFrameSize)) + } // Grow the frame buffer if needed, reuse otherwise. if uint32(cap(frameBuf)) < frameLen { @@ -173,11 +184,12 @@ func (r *Relay) forward(ctx context.Context, src, dst net.Conn, dir Direction) e } } - // When a DNS hook is active, drop non-IPv4 frames that are not - // ARP (EtherType 0x0806). This prevents IPv6 from bypassing the - // egress policy. Without a DNS hook, non-IPv4 frames pass through - // as before (needed for basic network bootstrapping). - if hdr == nil && r.dnsHook != nil { + // Non-IPv4 frames return hdr == nil from ParseHeaders. Under a + // DNS hook or a deny-default filter, drop them (except ARP) so + // that IPv6 and exotic EtherTypes cannot bypass the egress policy. + // With neither, non-IPv4 frames pass through as before (needed + // for basic network bootstrapping on allow-default setups). + if hdr == nil && (r.dnsHook != nil || r.filter.defaultAction == Deny) { if len(frameBuf) >= 14 { etherType := binary.BigEndian.Uint16(frameBuf[12:14]) if etherType != 0x0806 { // not ARP diff --git a/net/firewall/relay_test.go b/net/firewall/relay_test.go index 1a10285..242fa5d 100644 --- a/net/firewall/relay_test.go +++ b/net/firewall/relay_test.go @@ -98,6 +98,41 @@ func TestRelay_EndToEnd(t *testing.T) { <-errCh } +func TestRelay_RejectsOversizedLengthPrefix(t *testing.T) { + t.Parallel() + + filter := NewFilter(nil, Allow) + relay := NewRelay(filter) + + vmApp, vmRelay := net.Pipe() + netRelay, _ := net.Pipe() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errCh := make(chan error, 1) + go func() { + errCh <- relay.Run(ctx, vmRelay, netRelay) + }() + + // Write a 4-byte big-endian length prefix claiming a 2 MiB frame — + // well above maxFrameSize. Do not send any payload. + var lenBuf [4]byte + binary.BigEndian.PutUint32(lenBuf[:], 2*1024*1024) + _, err := vmApp.Write(lenBuf[:]) + require.NoError(t, err) + + // The relay must terminate with a protocol-violation error rather + // than attempt a multi-MiB allocation and hang on ReadFull. + select { + case err := <-errCh: + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeds maximum") + case <-time.After(2 * time.Second): + t.Fatal("relay did not terminate on oversized length prefix") + } +} + func TestRelay_DroppedFrame(t *testing.T) { t.Parallel() @@ -183,6 +218,52 @@ func TestRelay_ARPPassthroughWithDenyAll(t *testing.T) { <-errCh } +func TestRelay_DropsNonIPv4UnderDenyDefault(t *testing.T) { + t.Parallel() + + // Deny-default with no DNS hook. IPv6 (and any other non-IPv4, non-ARP + // EtherType) would previously pass through as "hdr == nil" without + // being checked against the filter. Callers who set FirewallDefault + // Deny expect a closed egress; honor that for v6 frames. + filter := NewFilter(nil, Deny) + relay := NewRelay(filter) + + vmApp, vmRelay := net.Pipe() + netRelay, netApp := net.Pipe() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errCh := make(chan error, 1) + go func() { + errCh <- relay.Run(ctx, vmRelay, netRelay) + }() + + // Build a minimal IPv6-tagged frame (EtherType 0x86DD). + v6Frame := make([]byte, 60) + binary.BigEndian.PutUint16(v6Frame[12:14], 0x86DD) + + // Send the v6 frame first; it must be dropped. + _, err := vmApp.Write(buildPrefixedFrame(v6Frame)) + require.NoError(t, err) + + // Follow with an ARP frame; it must still pass (existing guarantee). + arpFrame := make([]byte, 42) + binary.BigEndian.PutUint16(arpFrame[12:14], 0x0806) + _, err = vmApp.Write(buildPrefixedFrame(arpFrame)) + require.NoError(t, err) + + got := readPrefixedFrame(t, netApp) + assert.Equal(t, arpFrame, got, "ARP should still pass under deny-default") + + m := relay.Metrics() + assert.Equal(t, uint64(1), m.FramesForwarded.Load()) + assert.Equal(t, uint64(1), m.FramesDropped.Load(), "v6 frame must have been dropped") + + cancel() + <-errCh +} + func TestRelay_Metrics(t *testing.T) { t.Parallel() diff --git a/net/hosted/service.go b/net/hosted/service.go index 9127efd..f13b853 100644 --- a/net/hosted/service.go +++ b/net/hosted/service.go @@ -14,6 +14,19 @@ import ( "github.com/stacklok/go-microvm/net/topology" ) +// Default HTTP server timeouts for hosted services. These protect the +// host process from a misbehaving or hostile guest that opens +// connections but stalls the request — classic Slowloris / slow-body +// patterns — exhausting goroutines and file descriptors in the caller's +// process. Callers can override any of these per-Service if they ship a +// streaming handler that legitimately takes longer than the default. +const ( + defaultReadHeaderTimeout = 10 * time.Second + defaultReadTimeout = 30 * time.Second + defaultWriteTimeout = 30 * time.Second + defaultIdleTimeout = 60 * time.Second +) + // Service describes an HTTP service to expose inside the virtual network. // // Services always bind to the gateway IP ([topology.GatewayIP], 192.168.127.1) @@ -27,6 +40,43 @@ type Service struct { // Handler is the HTTP handler that serves requests. Handler http.Handler + + // ReadHeaderTimeout bounds the time the server will wait to finish + // reading request headers. Zero uses defaultReadHeaderTimeout. + ReadHeaderTimeout time.Duration + + // ReadTimeout bounds the total time reading a request including + // the body. Zero uses defaultReadTimeout. + ReadTimeout time.Duration + + // WriteTimeout bounds the total time writing the response. Zero + // uses defaultWriteTimeout. + WriteTimeout time.Duration + + // IdleTimeout bounds the time a keep-alive connection may remain + // idle between requests. Zero uses defaultIdleTimeout. + IdleTimeout time.Duration +} + +// timeoutOrDefault returns user if set, else the fallback default. +func timeoutOrDefault(user, fallback time.Duration) time.Duration { + if user > 0 { + return user + } + return fallback +} + +// newHTTPServer constructs an *http.Server for the given Service with +// Slowloris-bounding timeouts applied. Zero-valued timeout fields on +// svc fall back to defaults. +func newHTTPServer(svc Service) *http.Server { + return &http.Server{ + Handler: svc.Handler, + ReadHeaderTimeout: timeoutOrDefault(svc.ReadHeaderTimeout, defaultReadHeaderTimeout), + ReadTimeout: timeoutOrDefault(svc.ReadTimeout, defaultReadTimeout), + WriteTimeout: timeoutOrDefault(svc.WriteTimeout, defaultWriteTimeout), + IdleTimeout: timeoutOrDefault(svc.IdleTimeout, defaultIdleTimeout), + } } // runningService tracks a started service for graceful shutdown. @@ -63,9 +113,7 @@ func (p *Provider) startServices() error { return fmt.Errorf("listen on %s for service %d: %w", addr, i, err) } - srv := &http.Server{ - Handler: svc.Handler, - } + srv := newHTTPServer(svc) p.runningServices = append(p.runningServices, runningService{ server: srv, diff --git a/net/hosted/service_test.go b/net/hosted/service_test.go index bcbf1ce..3d62e97 100644 --- a/net/hosted/service_test.go +++ b/net/hosted/service_test.go @@ -7,6 +7,7 @@ import ( "context" "net/http" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -15,6 +16,39 @@ import ( propnet "github.com/stacklok/go-microvm/net" ) +func TestNewHTTPServer_AppliesDefaults(t *testing.T) { + t.Parallel() + + srv := newHTTPServer(Service{ + Port: 4483, + Handler: http.NotFoundHandler(), + }) + + assert.Equal(t, defaultReadHeaderTimeout, srv.ReadHeaderTimeout) + assert.Equal(t, defaultReadTimeout, srv.ReadTimeout) + assert.Equal(t, defaultWriteTimeout, srv.WriteTimeout) + assert.Equal(t, defaultIdleTimeout, srv.IdleTimeout) +} + +func TestNewHTTPServer_RespectsOverrides(t *testing.T) { + t.Parallel() + + override := 3 * time.Second + srv := newHTTPServer(Service{ + Port: 4483, + Handler: http.NotFoundHandler(), + ReadHeaderTimeout: override, + ReadTimeout: override, + WriteTimeout: override, + IdleTimeout: override, + }) + + assert.Equal(t, override, srv.ReadHeaderTimeout) + assert.Equal(t, override, srv.ReadTimeout) + assert.Equal(t, override, srv.WriteTimeout) + assert.Equal(t, override, srv.IdleTimeout) +} + func TestAddServiceBeforeStart(t *testing.T) { t.Parallel()