Skip to content
6 changes: 4 additions & 2 deletions packages/api/internal/middleware/otel/metrics/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
semconv "go.opentelemetry.io/otel/semconv/v1.7.0"

sharedmiddleware "github.com/e2b-dev/infra/packages/shared/pkg/middleware"
)

const MetricPrefix = "metric."
Expand Down Expand Up @@ -76,9 +78,9 @@ func Middleware(meterProvider metric.MeterProvider, service string, options ...O
)

code := ginCtx.Writer.Status()
if errors.Is(ctx.Err(), context.Canceled) {
if errors.Is(sharedmiddleware.CancelCause(ginCtx), context.Canceled) {
// 499 is the nginx convention for "client closed request before server responded"
code = 499
code = sharedmiddleware.StatusClientClosedRequest
}

groupedCode := code / 100 * 100
Expand Down
8 changes: 6 additions & 2 deletions packages/api/internal/middleware/otel/tracing/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
oteltrace "go.opentelemetry.io/otel/trace"

"github.com/e2b-dev/infra/packages/shared/pkg/logger"
sharedmiddleware "github.com/e2b-dev/infra/packages/shared/pkg/middleware"
"github.com/e2b-dev/infra/packages/shared/pkg/telemetry"
)

Expand Down Expand Up @@ -112,8 +113,11 @@ func Middleware(tracerProvider oteltrace.TracerProvider, service string) gin.Han
c.Next()

status := c.Writer.Status()
if errors.Is(ctx.Err(), context.Canceled) {
status = 499
cause := sharedmiddleware.CancelCause(c)
if errors.Is(cause, sharedmiddleware.ErrRequestTimeout) {
span.SetAttributes(attribute.Bool("request.timeout", true))
} else if errors.Is(cause, context.Canceled) {
status = sharedmiddleware.StatusClientClosedRequest
span.SetAttributes(attribute.Bool("client.canceled", true))
}

Expand Down
41 changes: 0 additions & 41 deletions packages/api/internal/middleware/timeout.go

This file was deleted.

4 changes: 2 additions & 2 deletions packages/api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ const (
// Must be less than maxWriteTimeout so the context cancels before the
// server's write deadline kills the connection (WriteTimeout does NOT
// cancel r.Context(); see https://github.com/golang/go/issues/59602).
// requestTimeout = 60 * time.Second
requestTimeout = 70 * time.Second

// This timeout should be > 600 (GCP LB upstream idle timeout) to prevent race condition
// https://cloud.google.com/load-balancing/docs/https#timeouts_and_retries%23:~:text=The%20load%20balancer%27s%20backend%20keepalive,is%20greater%20than%20600%20seconds
Expand Down Expand Up @@ -133,7 +133,7 @@ func NewGinServer(ctx context.Context, config cfg.Config, tel *telemetry.Client,
},
}),
gin.Recovery(),
// customMiddleware.RequestTimeout(requestTimeout), //nolint:contextcheck // Gin middleware sets context via c.Request.WithContext
sharedmiddleware.RequestTimeout(requestTimeout), //nolint:contextcheck // Gin middleware sets context via c.Request.WithContext
Comment thread
cursor[bot] marked this conversation as resolved.
)

corsConfig := cors.DefaultConfig()
Expand Down
16 changes: 10 additions & 6 deletions packages/shared/pkg/middleware/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,24 @@ func LoggingMiddleware(logger logger.Logger, conf Config) gin.HandlerFunc {
end = end.UTC()
}

status := c.Writer.Status()
if errors.Is(ctx.Err(), context.Canceled) {
status = 499
}

fields := []zapcore.Field{
zap.Int("status", status),
zap.String("method", c.Request.Method),
zap.String("path", path),
zap.String("query", query),
zap.String("user-agent", c.Request.UserAgent()),
zap.Duration("latency", latency),
}

status := c.Writer.Status()
cause := CancelCause(c)
if errors.Is(cause, ErrRequestTimeout) {
fields = append(fields, zap.Bool("request_timeout", true))
} else if errors.Is(cause, context.Canceled) {
status = StatusClientClosedRequest // 499
fields = append(fields, zap.Bool("client_canceled", true))
}
Comment thread
cursor[bot] marked this conversation as resolved.

fields = append(fields, zap.Int("status", status))
if conf.TimeFormat != "" {
fields = append(fields, zap.String("time", end.Format(conf.TimeFormat)))
}
Expand Down
67 changes: 67 additions & 0 deletions packages/shared/pkg/middleware/timeout.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package middleware

import (
"context"
"errors"
"time"

"github.com/gin-gonic/gin"
)

// ErrRequestTimeout is the cancel cause set when the per-request timeout fires.
// Callers can distinguish this from a client disconnection by checking:
//
// errors.Is(CancelCause(c), middleware.ErrRequestTimeout)
var ErrRequestTimeout = errors.New("request timeout exceeded")

// cancelCauseKey is the gin context key under which RequestTimeout stores the
// cancel cause and from which CancelCause reads it back. The value is written
// after the handler chain returns, before RequestTimeout's deferred cancel
// runs.
const cancelCauseKey = "middleware.cancelCause"

// CancelCause reports the cancel cause that the timeout middleware stored on
// the gin context under cancelCauseKey.
//
// It returns nil when no value was stored (the request was neither canceled
// nor timed out) or when the stored value is not an error.
func CancelCause(c *gin.Context) error {
	stored, found := c.Get(cancelCauseKey)
	if !found {
		return nil
	}

	// A failed assertion leaves cause as the zero value, i.e. nil.
	cause, _ := stored.(error)

	return cause
}

// StatusClientClosedRequest is the de-facto status code (used by nginx) for a
// client that closed the connection before the server could send a response.
// It is not an official IANA code but is widely recognised in logs and metrics.
const StatusClientClosedRequest = 499

// RequestTimeout returns a Gin middleware that sets a context deadline on each
// request. This is needed because http.Server.WriteTimeout does NOT cancel
// r.Context() (see https://github.com/golang/go/issues/59602), so without an
// explicit deadline, blocking calls like pgxpool.Acquire will wait indefinitely
// when the connection pool is saturated.
//
// After the handler returns, the middleware checks context.Cause to distinguish
// two cancellation scenarios and sets an appropriate status if nothing was
// written yet:
// - server-side timeout (cause is ErrRequestTimeout) → 408 Request Timeout
// - client disconnect (ctx.Err() == context.Canceled) → 499 Client Closed Request
//
// Routes matching any of the excludedRoutes patterns are skipped (useful for
// health checks and long-polling endpoints).
func RequestTimeout(timeout time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
ctx, cancel := context.WithTimeoutCause(c.Request.Context(), timeout, ErrRequestTimeout)
defer cancel()

c.Request = c.Request.WithContext(ctx)
c.Next()
Comment thread
jakubno marked this conversation as resolved.
Comment thread
cursor[bot] marked this conversation as resolved.

// Snapshot the cause *before* defer-cancel fires so outer
// middlewares can distinguish timeout vs client-disconnect
// via CancelCause(c) without racing with the deferred cancel.
if err := context.Cause(ctx); err != nil {
c.Set(cancelCauseKey, err)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,27 +35,6 @@ func TestRequestTimeout_SetsDeadline(t *testing.T) {
require.Equal(t, http.StatusOK, w.Code)
}

// TestRequestTimeout_Returns408WhenHandlerDoesNotWrite installs RequestTimeout
// with a 100ms timeout in front of a handler that only waits for the request
// context to finish, then expects the call to return in under 500ms with
// status 408 and the body "request timed out".
//
// NOTE(review): the RequestTimeout visible alongside this test only records
// the cancel cause and writes no status or body, so this test appears to
// target an earlier implementation — confirm before relying on it.
func TestRequestTimeout_Returns408WhenHandlerDoesNotWrite(t *testing.T) {
t.Parallel()

r := gin.New()
r.Use(RequestTimeout(100 * time.Millisecond))
r.GET("/slow", func(c *gin.Context) {
// Block until the middleware's deadline cancels the request context.
<-c.Request.Context().Done()
})

w := httptest.NewRecorder()
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/slow", nil)

start := time.Now()
r.ServeHTTP(w, req)
elapsed := time.Since(start)

assert.Less(t, elapsed, 500*time.Millisecond, "should have timed out promptly")
require.Equal(t, http.StatusRequestTimeout, w.Code)
assert.Equal(t, "request timed out", w.Body.String())
}

func TestRequestTimeout_CancelsBlockingHandler(t *testing.T) {
t.Parallel()

Expand All @@ -80,38 +59,53 @@ func TestRequestTimeout_CancelsBlockingHandler(t *testing.T) {
assert.Less(t, elapsed, 500*time.Millisecond, "handler should have been unblocked by context timeout")
}

func TestRequestTimeout_ExcludedRouteHasNoDeadline(t *testing.T) {
func TestRequestTimeout_NormalRequestContextNotCanceled(t *testing.T) {
t.Parallel()

// Simulate an outer middleware that reads CancelCause after c.Next().
// The context itself will be canceled by defer cancel(), but CancelCause
// should return nil for normal (non-timed-out) requests.
var outerCause error
outerMiddleware := func(c *gin.Context) {
c.Next()
outerCause = CancelCause(c)
}

r := gin.New()
r.Use(RequestTimeout(500*time.Millisecond, "/health"))
r.GET("/health", func(c *gin.Context) {
_, ok := c.Request.Context().Deadline()
assert.False(t, ok, "excluded route should not have a deadline")
r.Use(outerMiddleware)
r.Use(RequestTimeout(5 * time.Second))
r.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})

w := httptest.NewRecorder()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/health", nil)
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/test", nil)
r.ServeHTTP(w, req)

require.Equal(t, http.StatusOK, w.Code)
assert.NoError(t, outerCause, "CancelCause should return nil for normal requests")
}

func TestRequestTimeout_ExcludedRouteWithParam(t *testing.T) {
func TestRequestTimeout_TimeoutContextVisibleToOuterMiddleware(t *testing.T) {
t.Parallel()

var outerCause error
outerMiddleware := func(c *gin.Context) {
c.Next()
outerCause = CancelCause(c)
}

r := gin.New()
r.Use(RequestTimeout(500*time.Millisecond, "/templates/:templateID/builds/:buildID/logs"))
r.GET("/templates/:templateID/builds/:buildID/logs", func(c *gin.Context) {
_, ok := c.Request.Context().Deadline()
assert.False(t, ok, "excluded parameterized route should not have a deadline")
c.Status(http.StatusOK)
r.Use(outerMiddleware)
r.Use(RequestTimeout(50 * time.Millisecond))
r.GET("/slow", func(c *gin.Context) {
<-c.Request.Context().Done()
})

w := httptest.NewRecorder()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/templates/abc123/builds/build456/logs", nil)
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/slow", nil)
r.ServeHTTP(w, req)

require.Equal(t, http.StatusOK, w.Code)
require.ErrorIs(t, outerCause, ErrRequestTimeout,
"outer middleware should see ErrRequestTimeout as the cause when the timeout fires")
}
Loading