diff --git a/packages/api/internal/middleware/otel/metrics/middleware.go b/packages/api/internal/middleware/otel/metrics/middleware.go index 0ee91b000c..8f2f27ee9d 100644 --- a/packages/api/internal/middleware/otel/metrics/middleware.go +++ b/packages/api/internal/middleware/otel/metrics/middleware.go @@ -11,6 +11,8 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" semconv "go.opentelemetry.io/otel/semconv/v1.7.0" + + sharedmiddleware "github.com/e2b-dev/infra/packages/shared/pkg/middleware" ) const MetricPrefix = "metric." @@ -76,9 +78,9 @@ func Middleware(meterProvider metric.MeterProvider, service string, options ...O ) code := ginCtx.Writer.Status() - if errors.Is(ctx.Err(), context.Canceled) { + if errors.Is(sharedmiddleware.CancelCause(ginCtx), context.Canceled) { // 499 is the nginx convention for "client closed request before server responded" - code = 499 + code = sharedmiddleware.StatusClientClosedRequest } groupedCode := code / 100 * 100 diff --git a/packages/api/internal/middleware/otel/tracing/middleware.go b/packages/api/internal/middleware/otel/tracing/middleware.go index daf210f879..3bb2ce0c1b 100644 --- a/packages/api/internal/middleware/otel/tracing/middleware.go +++ b/packages/api/internal/middleware/otel/tracing/middleware.go @@ -30,6 +30,7 @@ import ( oteltrace "go.opentelemetry.io/otel/trace" "github.com/e2b-dev/infra/packages/shared/pkg/logger" + sharedmiddleware "github.com/e2b-dev/infra/packages/shared/pkg/middleware" "github.com/e2b-dev/infra/packages/shared/pkg/telemetry" ) @@ -112,8 +113,11 @@ func Middleware(tracerProvider oteltrace.TracerProvider, service string) gin.Han c.Next() status := c.Writer.Status() - if errors.Is(ctx.Err(), context.Canceled) { - status = 499 + cause := sharedmiddleware.CancelCause(c) + if errors.Is(cause, sharedmiddleware.ErrRequestTimeout) { + span.SetAttributes(attribute.Bool("request.timeout", true)) + } else if errors.Is(cause, context.Canceled) { + status = sharedmiddleware.StatusClientClosedRequest span.SetAttributes(attribute.Bool("client.canceled", true)) } diff --git a/packages/api/internal/middleware/timeout.go b/packages/api/internal/middleware/timeout.go deleted file mode 100644 index 9963d93cf7..0000000000 --- a/packages/api/internal/middleware/timeout.go +++ /dev/null @@ -1,41 +0,0 @@ -package middleware - -import ( - "context" - "net/http" - "time" - - "github.com/gin-gonic/gin" -) - -// RequestTimeout returns a Gin middleware that sets a context deadline on each -// request. This is needed because http.Server.WriteTimeout does NOT cancel -// r.Context() (see https://github.com/golang/go/issues/59602), so without an -// explicit deadline, blocking calls like pgxpool.Acquire will wait indefinitely -// when the connection pool is saturated. -// -// If the deadline is exceeded and the handler has not yet written a response, -// the middleware responds with 408 Request Timeout. -// -// Routes matching any of the excludedRoutes patterns are skipped (useful for -// health checks and long-polling endpoints). -func RequestTimeout(timeout time.Duration, excludedRoutes ...string) gin.HandlerFunc { - return func(c *gin.Context) { - if shouldSkip(c.Request.URL.Path, excludedRoutes) { - c.Next() - - return - } - - ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) - defer cancel() - - c.Request = c.Request.WithContext(ctx) - c.Next() - - if ctx.Err() == context.DeadlineExceeded && !c.Writer.Written() { - c.String(http.StatusRequestTimeout, "request timed out") - c.Abort() - } - } -} diff --git a/packages/api/main.go b/packages/api/main.go index 7efdd5e19b..2c05bc0632 100644 --- a/packages/api/main.go +++ b/packages/api/main.go @@ -63,7 +63,7 @@ const ( // Must be less than maxWriteTimeout so the context cancels before the // server's write deadline kills the connection (WriteTimeout does NOT // cancel r.Context(); see https://github.com/golang/go/issues/59602). - // requestTimeout = 60 * time.Second + requestTimeout = 70 * time.Second // This timeout should be > 600 (GCP LB upstream idle timeout) to prevent race condition // https://cloud.google.com/load-balancing/docs/https#timeouts_and_retries%23:~:text=The%20load%20balancer%27s%20backend%20keepalive,is%20greater%20than%20600%20seconds @@ -133,7 +133,7 @@ func NewGinServer(ctx context.Context, config cfg.Config, tel *telemetry.Client, }, }), gin.Recovery(), - // customMiddleware.RequestTimeout(requestTimeout), //nolint:contextcheck // Gin middleware sets context via c.Request.WithContext + sharedmiddleware.RequestTimeout(requestTimeout), //nolint:contextcheck // Gin middleware sets context via c.Request.WithContext ) corsConfig := cors.DefaultConfig() diff --git a/packages/shared/pkg/middleware/logging.go b/packages/shared/pkg/middleware/logging.go index 3539d22893..5412e5e2cc 100644 --- a/packages/shared/pkg/middleware/logging.go +++ b/packages/shared/pkg/middleware/logging.go @@ -72,13 +72,7 @@ func LoggingMiddleware(logger logger.Logger, conf Config) gin.HandlerFunc { end = end.UTC() } - status := c.Writer.Status() - if errors.Is(ctx.Err(), context.Canceled) { - status = 499 - } - fields := []zapcore.Field{ - zap.Int("status", status), zap.String("method", c.Request.Method), zap.String("path", path), zap.String("query", query), @@ -86,6 +80,16 @@ func LoggingMiddleware(logger logger.Logger, conf Config) gin.HandlerFunc { zap.Duration("latency", latency), } + status := c.Writer.Status() + cause := CancelCause(c) + if errors.Is(cause, ErrRequestTimeout) { + fields = append(fields, zap.Bool("request_timeout", true)) + } else if errors.Is(cause, context.Canceled) { + status = StatusClientClosedRequest // 499 + fields = append(fields, zap.Bool("client_canceled", true)) + } + + fields = append(fields, zap.Int("status", status)) if conf.TimeFormat != "" { fields = append(fields, zap.String("time", end.Format(conf.TimeFormat))) } diff --git a/packages/shared/pkg/middleware/timeout.go b/packages/shared/pkg/middleware/timeout.go new file mode 100644 index 0000000000..7c8546567c --- /dev/null +++ b/packages/shared/pkg/middleware/timeout.go @@ -0,0 +1,67 @@ +package middleware + +import ( + "context" + "errors" + "time" + + "github.com/gin-gonic/gin" +) + +// ErrRequestTimeout is the cancel cause set when the per-request timeout fires. +// Callers can distinguish this from a client disconnection by checking: +// +// errors.Is(CancelCause(c), middleware.ErrRequestTimeout) +var ErrRequestTimeout = errors.New("request timeout exceeded") + +// cancelCauseKey is the gin context key where RequestTimeout snapshots the +// cancel cause before defer-cancel runs. +const cancelCauseKey = "middleware.cancelCause" + +// CancelCause returns the cancel cause captured by the timeout middleware. +// It returns nil for normal (non-canceled/non-timed-out) requests. +func CancelCause(c *gin.Context) error { + if val, exists := c.Get(cancelCauseKey); exists { + if err, ok := val.(error); ok { + return err + } + } + + return nil +} + +// StatusClientClosedRequest is the de-facto status code (used by nginx) for a +// client that closed the connection before the server could send a response. +// It is not an official IANA code but is widely recognised in logs and metrics. +const StatusClientClosedRequest = 499 + +// RequestTimeout returns a Gin middleware that sets a context deadline on each +// request. This is needed because http.Server.WriteTimeout does NOT cancel +// r.Context() (see https://github.com/golang/go/issues/59602), so without an +// explicit deadline, blocking calls like pgxpool.Acquire will wait indefinitely +// when the connection pool is saturated. +// +// After the handler returns, the middleware checks context.Cause to distinguish +// two cancellation scenarios and sets an appropriate status if nothing was +// written yet: +// - server-side timeout (cause is ErrRequestTimeout) → 408 Request Timeout +// - client disconnect (ctx.Err() == context.Canceled) → 499 Client Closed Request +// +// Routes matching any of the excludedRoutes patterns are skipped (useful for +// health checks and long-polling endpoints). +func RequestTimeout(timeout time.Duration) gin.HandlerFunc { + return func(c *gin.Context) { + ctx, cancel := context.WithTimeoutCause(c.Request.Context(), timeout, ErrRequestTimeout) + defer cancel() + + c.Request = c.Request.WithContext(ctx) + c.Next() + + // Snapshot the cause *before* defer-cancel fires so outer + // middlewares can distinguish timeout vs client-disconnect + // via CancelCause(c) without racing with the deferred cancel. + if err := context.Cause(ctx); err != nil { + c.Set(cancelCauseKey, err) + } + } +} diff --git a/packages/api/internal/middleware/timeout_test.go b/packages/shared/pkg/middleware/timeout_test.go similarity index 58% rename from packages/api/internal/middleware/timeout_test.go rename to packages/shared/pkg/middleware/timeout_test.go index 4e4b68adc7..09025d7d49 100644 --- a/packages/api/internal/middleware/timeout_test.go +++ b/packages/shared/pkg/middleware/timeout_test.go @@ -35,27 +35,6 @@ func TestRequestTimeout_SetsDeadline(t *testing.T) { require.Equal(t, http.StatusOK, w.Code) } -func TestRequestTimeout_Returns408WhenHandlerDoesNotWrite(t *testing.T) { - t.Parallel() - - r := gin.New() - r.Use(RequestTimeout(100 * time.Millisecond)) - r.GET("/slow", func(c *gin.Context) { - <-c.Request.Context().Done() - }) - - w := httptest.NewRecorder() - req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/slow", nil) - - start := time.Now() - r.ServeHTTP(w, req) - elapsed := time.Since(start) - - assert.Less(t, elapsed, 500*time.Millisecond, "should have timed out promptly") - require.Equal(t, http.StatusRequestTimeout, w.Code) - assert.Equal(t, "request timed out", w.Body.String()) -} - func TestRequestTimeout_CancelsBlockingHandler(t *testing.T) { t.Parallel() @@ -80,38 +59,53 @@ func TestRequestTimeout_CancelsBlockingHandler(t *testing.T) { assert.Less(t, elapsed, 500*time.Millisecond, "handler should have been unblocked by context timeout") } -func TestRequestTimeout_ExcludedRouteHasNoDeadline(t *testing.T) { +func TestRequestTimeout_NormalRequestContextNotCanceled(t *testing.T) { t.Parallel() + // Simulate an outer middleware that reads CancelCause after c.Next(). + // The context itself will be canceled by defer cancel(), but CancelCause + // should return nil for normal (non-timed-out) requests. + var outerCause error + outerMiddleware := func(c *gin.Context) { + c.Next() + outerCause = CancelCause(c) + } + r := gin.New() - r.Use(RequestTimeout(500*time.Millisecond, "/health")) - r.GET("/health", func(c *gin.Context) { - _, ok := c.Request.Context().Deadline() - assert.False(t, ok, "excluded route should not have a deadline") + r.Use(outerMiddleware) + r.Use(RequestTimeout(5 * time.Second)) + r.GET("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) w := httptest.NewRecorder() - req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/health", nil) + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/test", nil) r.ServeHTTP(w, req) require.Equal(t, http.StatusOK, w.Code) + assert.NoError(t, outerCause, "CancelCause should return nil for normal requests") } -func TestRequestTimeout_ExcludedRouteWithParam(t *testing.T) { +func TestRequestTimeout_TimeoutContextVisibleToOuterMiddleware(t *testing.T) { t.Parallel() + var outerCause error + outerMiddleware := func(c *gin.Context) { + c.Next() + outerCause = CancelCause(c) + } + r := gin.New() - r.Use(RequestTimeout(500*time.Millisecond, "/templates/:templateID/builds/:buildID/logs")) - r.GET("/templates/:templateID/builds/:buildID/logs", func(c *gin.Context) { - _, ok := c.Request.Context().Deadline() - assert.False(t, ok, "excluded parameterized route should not have a deadline") - c.Status(http.StatusOK) + r.Use(outerMiddleware) + r.Use(RequestTimeout(50 * time.Millisecond)) + r.GET("/slow", func(c *gin.Context) { + <-c.Request.Context().Done() }) w := httptest.NewRecorder() - req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/templates/abc123/builds/build456/logs", nil) + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/slow", nil) r.ServeHTTP(w, req) - require.Equal(t, http.StatusOK, w.Code) + require.ErrorIs(t, outerCause, ErrRequestTimeout, + "outer middleware should see ErrRequestTimeout as the cause when the timeout fires") }