diff --git a/packages/orchestrator/internal/proxy/proxy.go b/packages/orchestrator/internal/proxy/proxy.go index c8d06d58f5..f621551d42 100644 --- a/packages/orchestrator/internal/proxy/proxy.go +++ b/packages/orchestrator/internal/proxy/proxy.go @@ -127,8 +127,8 @@ func (p *SandboxProxy) Close(ctx context.Context) error { return nil } -func (p *SandboxProxy) RemoveFromPool(connectionKey string) { - p.proxy.RemoveFromPool(connectionKey) +func (p *SandboxProxy) RemoveFromPool(connectionKey string) error { + return p.proxy.RemoveFromPool(connectionKey) } func (p *SandboxProxy) GetAddr() string { diff --git a/packages/orchestrator/internal/server/sandboxes.go b/packages/orchestrator/internal/server/sandboxes.go index bfec76bc1b..2a5f50570f 100644 --- a/packages/orchestrator/internal/server/sandboxes.go +++ b/packages/orchestrator/internal/server/sandboxes.go @@ -156,7 +156,11 @@ func (s *Server) Create(ctx context.Context, req *orchestrator.SandboxCreateRequ s.sandboxes.RemoveByExecutionID(req.GetSandbox().GetSandboxId(), sbx.Runtime.ExecutionID) // Remove the proxies assigned to the sandbox from the pool to prevent them from being reused. - s.proxy.RemoveFromPool(sbx.Runtime.ExecutionID) + closeErr := s.proxy.RemoveFromPool(sbx.Runtime.ExecutionID) + if closeErr != nil { + // Errors here will be from forcefully closing the connections, so we can ignore them—they will at worst timeout on their own. + sbxlogger.I(sbx).Warn("errors when manually closing connections to sandbox", zap.Error(closeErr)) + } sbxlogger.E(sbx).Info("Sandbox killed") }() diff --git a/packages/orchestrator/internal/template/build/layer/layer_executor.go b/packages/orchestrator/internal/template/build/layer/layer_executor.go index 13e57bc580..209177bbce 100644 --- a/packages/orchestrator/internal/template/build/layer/layer_executor.go +++ b/packages/orchestrator/internal/template/build/layer/layer_executor.go @@ -82,7 +82,12 @@ func (lb *LayerExecutor) BuildLayer( lb.sandboxes.Insert(sbx) defer func() { lb.sandboxes.Remove(sbx.Runtime.SandboxID) - lb.proxy.RemoveFromPool(sbx.Runtime.ExecutionID) + + closeErr := lb.proxy.RemoveFromPool(sbx.Runtime.ExecutionID) + if closeErr != nil { + // Errors here will be from forcefully closing the connections, so we can ignore them—they will at worst timeout on their own. + lb.logger.Warn("errors when manually closing connections to sandbox", zap.Error(closeErr)) + } }() // Update envd binary to the latest version diff --git a/packages/shared/pkg/proxy/pool/client.go b/packages/shared/pkg/proxy/pool/client.go index ba769c5a06..9df9769570 100644 --- a/packages/shared/pkg/proxy/pool/client.go +++ b/packages/shared/pkg/proxy/pool/client.go @@ -2,6 +2,7 @@ package pool import ( "context" + "errors" "log" "net" "net/http" @@ -13,12 +14,15 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/proxy/template" "github.com/e2b-dev/infra/packages/shared/pkg/proxy/tracking" + "github.com/e2b-dev/infra/packages/shared/pkg/smap" ) type ProxyClient struct { httputil.ReverseProxy transport *http.Transport + + activeConnections *smap.Map[*tracking.Connection] } func newProxyClient( @@ -30,6 +34,8 @@ func newProxyClient( currentConnsCounter *atomic.Int64, logger *log.Logger, ) *ProxyClient { + activeConnections := smap.New[*tracking.Connection]() + transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, // Limit the max connection per host to avoid exhausting the number of available ports to one host. @@ -56,7 +62,7 @@ func newProxyClient( if err == nil { totalConnsCounter.Add(1) - return tracking.NewConnection(conn, currentConnsCounter), nil + return tracking.NewConnection(conn, currentConnsCounter, activeConnections), nil } if ctx.Err() != nil { @@ -82,7 +88,8 @@ func newProxyClient( } return &ProxyClient{ - transport: transport, + transport: transport, + activeConnections: activeConnections, ReverseProxy: httputil.ReverseProxy{ Transport: transport, Rewrite: func(r *httputil.ProxyRequest) { @@ -167,3 +174,16 @@ func newProxyClient( func (p *ProxyClient) closeIdleConnections() { p.transport.CloseIdleConnections() } + +func (p *ProxyClient) resetAllConnections() error { + var errs []error + + for _, conn := range p.activeConnections.Items() { + err := conn.Reset() + if err != nil { + errs = append(errs, err) + } + } + + return errors.Join(errs...) +} diff --git a/packages/shared/pkg/proxy/pool/pool.go b/packages/shared/pkg/proxy/pool/pool.go index cbb042daf3..56b15ef8b0 100644 --- a/packages/shared/pkg/proxy/pool/pool.go +++ b/packages/shared/pkg/proxy/pool/pool.go @@ -69,14 +69,17 @@ func (p *ProxyPool) Get(d *Destination) *ProxyClient { }) } -func (p *ProxyPool) Close(connectionKey string) { +func (p *ProxyPool) Close(connectionKey string) (err error) { p.pool.RemoveCb(connectionKey, func(_ string, proxy *ProxyClient, _ bool) bool { if proxy != nil { proxy.closeIdleConnections() + err = proxy.resetAllConnections() } return true }) + + return err } func (p *ProxyPool) TotalConnections() uint64 { diff --git a/packages/shared/pkg/proxy/proxy.go b/packages/shared/pkg/proxy/proxy.go index 9e585899a4..fbc122e85e 100644 --- a/packages/shared/pkg/proxy/proxy.go +++ b/packages/shared/pkg/proxy/proxy.go @@ -58,10 +58,12 @@ func New( } } +// TotalPoolConnections returns the total number of connections that have been established across whole pool. func (p *Proxy) TotalPoolConnections() uint64 { return p.pool.TotalConnections() } +// CurrentServerConnections returns the current number of connections that are alive across whole pool. func (p *Proxy) CurrentServerConnections() int64 { return p.currentServerConnsCounter.Load() } @@ -74,8 +76,8 @@ func (p *Proxy) CurrentPoolConnections() int64 { return p.pool.CurrentConnections() } -func (p *Proxy) RemoveFromPool(connectionKey string) { - p.pool.Close(connectionKey) +func (p *Proxy) RemoveFromPool(connectionKey string) error { + return p.pool.Close(connectionKey) } func (p *Proxy) ListenAndServe(ctx context.Context) error { diff --git a/packages/shared/pkg/proxy/proxy_test.go b/packages/shared/pkg/proxy/proxy_test.go index 6baa17d866..9da5a839c1 100644 --- a/packages/shared/pkg/proxy/proxy_test.go +++ b/packages/shared/pkg/proxy/proxy_test.go @@ -34,6 +34,8 @@ func (b *testBackend) RequestCount() uint64 { return b.requestCount.Load() } +const bodyWriteDelayHeader = "body-write-delay" + // newTestBackend creates a new test backend server func newTestBackend(listener net.Listener, id string) (*testBackend, error) { var requestCount atomic.Uint64 @@ -42,7 +44,7 @@ func newTestBackend(listener net.Listener, id string) (*testBackend, error) { backend := &testBackend{ server: &http.Server{ - Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { select { case <-ctx.Done(): w.WriteHeader(http.StatusBadGateway) @@ -54,6 +56,21 @@ func newTestBackend(listener net.Listener, id string) (*testBackend, error) { requestCount.Add(1) w.WriteHeader(http.StatusOK) + + // Flush the headers, so we can read the headers and body separately after .Do() returns. + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + + // Check for "body-write-delay" header (interpreted as seconds) + delayHeader := r.Header.Get(bodyWriteDelayHeader) + + if delayHeader != "" { + if n, err := time.ParseDuration(delayHeader); err == nil { + time.Sleep(n) + } + } + w.Write([]byte(id)) }), }, @@ -101,11 +118,22 @@ func assertBackendOutput(t *testing.T, backend *testBackend, resp *http.Response t.Helper() assert.Equal(t, resp.StatusCode, http.StatusOK, "status code should be 200") + body, err := io.ReadAll(resp.Body) require.NoError(t, err) + assert.Equal(t, string(body), backend.id, "backend id should be the same") } +func assertStreamError(t *testing.T, resp *http.Response) { + t.Helper() + + assert.Equal(t, resp.StatusCode, http.StatusOK, "status code should be 200") + + _, err := io.ReadAll(resp.Body) + assert.ErrorType(t, err, io.ErrUnexpectedEOF) +} + // newTestProxy creates a new proxy server for testing func newTestProxy(t *testing.T, getDestination func(r *http.Request) (*pool.Destination, error)) (*Proxy, uint, error) { t.Helper() @@ -175,11 +203,29 @@ func TestProxyRoutesToTargetServer(t *testing.T) { func httpGet(t *testing.T, proxyURL string) (*http.Response, error) { t.Helper() + return httpGetWithHeaders(t, proxyURL, nil) +} + +func httpGetWithBodyWriteDelay(t *testing.T, proxyURL string, bodyWriteDelay time.Duration) (*http.Response, error) { + t.Helper() + + return httpGetWithHeaders(t, proxyURL, http.Header{bodyWriteDelayHeader: {bodyWriteDelay.String()}}) +} + +func httpGetWithHeaders(t *testing.T, proxyURL string, headers http.Header) (*http.Response, error) { + t.Helper() + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, proxyURL, nil) if err != nil { return nil, err } + for key, values := range headers { + for _, value := range values { + req.Header.Add(key, value) + } + } + rsp, err := (&http.Client{}).Do(req) if err != nil { return nil, err @@ -188,6 +234,51 @@ func httpGet(t *testing.T, proxyURL string) (*http.Response, error) { return rsp, nil } +type instrumentedConn struct { + net.Conn + + listener *instrumentedListener +} + +func (c *instrumentedConn) Read(b []byte) (int, error) { + n, err := c.Conn.Read(b) + if err != nil { + c.listener.AddReadError(err) + } + + return n, err +} + +func (l *instrumentedListener) AddReadError(err error) { + l.readErrsMutex.Lock() + defer l.readErrsMutex.Unlock() + + l.readErrs = append(l.readErrs, err) +} + +func (l *instrumentedListener) ReadErrors() []error { + l.readErrsMutex.Lock() + defer l.readErrsMutex.Unlock() + + return l.readErrs +} + +type instrumentedListener struct { + net.Listener + + readErrs []error + readErrsMutex sync.Mutex +} + +func (l *instrumentedListener) Accept() (net.Conn, error) { + conn, err := l.Listener.Accept() + if err != nil { + return nil, err + } + + return &instrumentedConn{Conn: conn, listener: l}, nil +} + func TestProxyReusesConnections(t *testing.T) { var lisCfg net.ListenConfig listener, err := lisCfg.Listen(t.Context(), "tcp", "127.0.0.1:0") @@ -233,6 +324,114 @@ func TestProxyReusesConnections(t *testing.T) { assert.Equal(t, proxy.TotalPoolConnections(), uint64(1), "proxy should have used one connection") } +func TestProxyCloseIdleConnectionsFromPool(t *testing.T) { + var lisCfg net.ListenConfig + listener, err := lisCfg.Listen(t.Context(), "tcp", "127.0.0.1:0") + require.NoError(t, err) + + backend, err := newTestBackend(listener, "backend-1") + require.NoError(t, err) + defer backend.Close() + + getDestination := func(*http.Request) (*pool.Destination, error) { + return &pool.Destination{ + Url: backend.url, + SandboxId: "test-sandbox", + RequestLogger: zap.NewNop(), + ConnectionKey: backend.id, + }, nil + } + + proxy, port, err := newTestProxy(t, getDestination) + require.NoError(t, err) + defer proxy.Close() + + // Make a request to the proxy + proxyURL := fmt.Sprintf("http://127.0.0.1:%d/hello", port) + resp, err := httpGet(t, proxyURL) + require.NoError(t, err) + defer resp.Body.Close() + + assertBackendOutput(t, backend, resp) + + assert.Equal(t, proxy.TotalPoolConnections(), uint64(1), "proxy should have established one connection") + assert.Equal(t, proxy.CurrentPoolConnections(), int64(1), "proxy should have established one connection that is still alive") + assert.Equal(t, backend.RequestCount(), uint64(1), "backend should have been called once") + + // Remove the connection from the pool + err = proxy.RemoveFromPool(backend.id) + require.NoError(t, err) + + assert.Equal(t, proxy.TotalPoolConnections(), uint64(1), "proxy should have still one connection in the pool") + assert.Equal(t, proxy.CurrentPoolConnections(), int64(0), "proxy should have removed the connection from the pool that is still alive") +} + +func TestProxyResetAliveConnectionsFromPool(t *testing.T) { + var lisCfg net.ListenConfig + + listener, err := lisCfg.Listen(t.Context(), "tcp", "127.0.0.1:0") + require.NoError(t, err) + + instrumentedListener := &instrumentedListener{Listener: listener} + + backend, err := newTestBackend(instrumentedListener, "backend-1") + require.NoError(t, err) + defer backend.Close() + + getDestination := func(*http.Request) (*pool.Destination, error) { + return &pool.Destination{ + Url: backend.url, + SandboxId: "test-sandbox", + RequestLogger: zap.NewNop(), + ConnectionKey: backend.id, + }, nil + } + + proxy, port, err := newTestProxy(t, getDestination) + require.NoError(t, err) + defer proxy.Close() + + requestEnded := make(chan struct{}, 1) + + go func() { + defer close(requestEnded) + + // Make a request to the proxy + proxyURL := fmt.Sprintf("http://127.0.0.1:%d/hello", port) + resp, err := httpGetWithBodyWriteDelay(t, proxyURL, 10*time.Second) + assert.NilError(t, err) + defer resp.Body.Close() + + assertStreamError(t, resp) + }() + + // Wait for the request to start being processed by the backend + time.Sleep(1 * time.Second) + + assert.Equal(t, proxy.TotalPoolConnections(), uint64(1), "proxy should have established one connection") + assert.Equal(t, proxy.CurrentPoolConnections(), int64(1), "proxy should have established one connection that is still alive") + assert.Equal(t, backend.RequestCount(), uint64(1), "backend should have been called once") + + // Remove the connection from the pool + err = proxy.RemoveFromPool(backend.id) + require.NoError(t, err) + + assert.Equal(t, proxy.TotalPoolConnections(), uint64(1), "proxy should have still one connection in the pool") + assert.Equal(t, proxy.CurrentPoolConnections(), int64(0), "proxy should have removed the connection from the pool that is still alive") + + select { + case <-requestEnded: + case <-t.Context().Done(): + t.Fatalf("request timed out: %v", t.Context().Err()) + } + + require.Len(t, instrumentedListener.ReadErrors(), 1, "server connection should have one read error") + // io.EOF is returned for the FIN packet. + require.NotErrorIs(t, instrumentedListener.ReadErrors()[0], io.EOF, "server connection should have read error other than EOF") + + require.ErrorContains(t, instrumentedListener.ReadErrors()[0], "connection reset by peer") +} + // This is a test that verify that the proxy reuse fails when the backend changes. func TestProxyReuseConnectionsWhenBackendChangesFails(t *testing.T) { // Create first backend diff --git a/packages/shared/pkg/proxy/tracking/connection.go b/packages/shared/pkg/proxy/tracking/connection.go index da88c453d1..7d7f9d1b5c 100644 --- a/packages/shared/pkg/proxy/tracking/connection.go +++ b/packages/shared/pkg/proxy/tracking/connection.go @@ -1,23 +1,57 @@ package tracking import ( + "errors" "net" "sync/atomic" + + "github.com/google/uuid" + + "github.com/e2b-dev/infra/packages/shared/pkg/smap" ) type Connection struct { net.Conn counter *atomic.Int64 + key string + + m *smap.Map[*Connection] } -func NewConnection(conn net.Conn, counter *atomic.Int64) *Connection { +func NewConnection(conn net.Conn, counter *atomic.Int64, m *smap.Map[*Connection]) *Connection { counter.Add(1) - return &Connection{ + c := &Connection{ Conn: conn, counter: counter, + m: m, + } + + if m != nil { + c.key = uuid.New().String() + + m.Insert(c.key, c) + } + + return c +} + +func (c *Connection) Reset() error { + var errs []error + + // This forces the connection to close with RST. + err := c.Conn.(*net.TCPConn).SetLinger(0) + if err != nil { + errs = append(errs, err) } + + err = c.Close() + if err != nil { + errs = append(errs, err) + } + + return errors.Join(errs...) } func (c *Connection) Close() error { @@ -28,5 +62,9 @@ func (c *Connection) Close() error { c.counter.Add(-1) + if c.m != nil { + c.m.Remove(c.key) + } + return nil } diff --git a/packages/shared/pkg/proxy/tracking/listener.go b/packages/shared/pkg/proxy/tracking/listener.go index e792732f78..985c920672 100644 --- a/packages/shared/pkg/proxy/tracking/listener.go +++ b/packages/shared/pkg/proxy/tracking/listener.go @@ -24,5 +24,5 @@ func (l *Listener) Accept() (net.Conn, error) { return nil, err } - return NewConnection(conn, l.counter), nil + return NewConnection(conn, l.counter, nil), nil }