Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion packages/api/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ require (
github.com/gin-contrib/cors v1.7.6
github.com/gin-contrib/size v1.0.2
github.com/gin-gonic/gin v1.10.1
github.com/go-redis/redis_rate/v10 v10.0.1
github.com/gogo/status v1.1.1
github.com/golang-jwt/jwt/v5 v5.3.1
github.com/golang/protobuf v1.5.4
Expand All @@ -36,6 +37,7 @@ require (
github.com/hashicorp/nomad/api v0.0.0-20251216171439-1dee0671280e
github.com/jackc/pgx/v5 v5.7.5
github.com/launchdarkly/go-sdk-common/v3 v3.3.0
github.com/launchdarkly/go-server-sdk/v7 v7.13.0
github.com/oapi-codegen/gin-middleware v1.0.2
github.com/oapi-codegen/runtime v1.1.1
github.com/orcaman/concurrent-map/v2 v2.0.1
Expand Down Expand Up @@ -235,7 +237,6 @@ require (
github.com/launchdarkly/go-sdk-events/v3 v3.5.0 // indirect
github.com/launchdarkly/go-semver v1.0.3 // indirect
github.com/launchdarkly/go-server-sdk-evaluation/v3 v3.0.1 // indirect
github.com/launchdarkly/go-server-sdk/v7 v7.13.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lib/pq v1.11.2 // indirect
github.com/lufia/plan9stats v0.0.0-20240909124753-873cd0166683 // indirect
Expand Down
2 changes: 2 additions & 0 deletions packages/api/go.sum

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 1 addition & 25 deletions packages/api/internal/handlers/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
authdb "github.com/e2b-dev/infra/packages/db/pkg/auth"
"github.com/e2b-dev/infra/packages/db/pkg/pool"
"github.com/e2b-dev/infra/packages/shared/pkg/apierrors"
"github.com/e2b-dev/infra/packages/shared/pkg/factories"
"github.com/e2b-dev/infra/packages/shared/pkg/featureflags"
"github.com/e2b-dev/infra/packages/shared/pkg/logger"
"github.com/e2b-dev/infra/packages/shared/pkg/logs/loki"
Expand Down Expand Up @@ -61,7 +60,7 @@ type APIStore struct {
clusters *clusters.Pool
}

func NewAPIStore(ctx context.Context, tel *telemetry.Client, config cfg.Config, serviceName string) *APIStore {
func NewAPIStore(ctx context.Context, tel *telemetry.Client, redisClient redis.UniversalClient, featureFlags *featureflags.Client, config cfg.Config) *APIStore {
logger.L().Info(ctx, "Initializing API store and services")

sqlcDB, err := sqlcdb.NewClient(ctx, config.PostgresConnectionString, pool.WithMaxConnections(config.DBMaxOpenConnections), pool.WithMinIdle(config.DBMinIdleConnections))
Expand Down Expand Up @@ -108,15 +107,6 @@ func NewAPIStore(ctx context.Context, tel *telemetry.Client, config cfg.Config,
if err != nil {
logger.L().Fatal(ctx, "Initializing Nomad client", zap.Error(err))
}
redisClient, err := factories.NewRedisClient(ctx, factories.RedisConfig{
RedisURL: config.RedisURL,
RedisClusterURL: config.RedisClusterURL,
RedisTLSCABase64: config.RedisTLSCABase64,
PoolSize: config.RedisPoolSize,
})
if err != nil {
logger.L().Fatal(ctx, "Initializing Redis client", zap.Error(err))
}

queryLogsProvider, err := loki.NewLokiQueryProvider(config.LokiURL, config.LokiUser, config.LokiPassword)
if err != nil {
Expand All @@ -128,14 +118,6 @@ func NewAPIStore(ctx context.Context, tel *telemetry.Client, config cfg.Config,
logger.L().Fatal(ctx, "initializing edge clusters pool failed", zap.Error(err))
}

featureFlags, err := featureflags.NewClient()
if err != nil {
logger.L().Fatal(ctx, "failed to create feature flags client", zap.Error(err))
}

featureFlags.SetServiceName(serviceName)
featureFlags.SetDeploymentName(config.DomainName)

accessTokenGenerator, err := sandbox.NewAccessTokenGenerator(config.SandboxAccessTokenHashSeed)
if err != nil {
logger.L().Fatal(ctx, "Initializing access token generator failed", zap.Error(err))
Expand Down Expand Up @@ -248,12 +230,6 @@ func (a *APIStore) Close(ctx context.Context) error {
errs = append(errs, fmt.Errorf("closing snapshot cache: %w", err))
}

if a.redisClient != nil {
if err := a.redisClient.Close(); err != nil {
errs = append(errs, fmt.Errorf("closing redis client: %w", err))
}
}

return errors.Join(errs...)
}

Expand Down
167 changes: 167 additions & 0 deletions packages/api/internal/middleware/ratelimit/ratelimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package ratelimit

import (
"context"
"math"
"net/http"
"strconv"
"time"

"github.com/gin-gonic/gin"
"github.com/go-redis/redis_rate/v10"
"github.com/redis/go-redis/v9"
"go.uber.org/zap"

"github.com/e2b-dev/infra/packages/auth/pkg/auth"
"github.com/e2b-dev/infra/packages/shared/pkg/featureflags"
"github.com/e2b-dev/infra/packages/shared/pkg/logger"
redis_utils "github.com/e2b-dev/infra/packages/shared/pkg/redis"
)

const rateLimitPefix = "ratelimit"

// Config defines the rate limit parameters.
type Config struct {
// Rate is the number of requests allowed per Period.
Rate int
// Burst is the maximum number of requests allowed in a single burst.
Burst int
// Period is the time window for the rate.
Period time.Duration
// FailOpen allows requests through when Redis is unavailable.
FailOpen bool
}

// DefaultConfig returns a sensible default: 50 req/s with burst of 100.
func DefaultConfig() Config {
return Config{
Rate: 50,
Burst: 100,
Period: time.Second,
FailOpen: true,
}
}

// NewLimiter creates a redis_rate.Limiter from a Redis client.
func NewLimiter(redisClient redis.UniversalClient) *redis_rate.Limiter {
return redis_rate.NewLimiter(redisClient)
}

// Middleware returns a Gin middleware that enforces per-team rate limits
// using the GCRA algorithm backed by Redis (go-redis/redis_rate).
//
// The middleware is gated by the RateLimitEnabledFlag feature flag for
// gradual rollout. Unauthenticated requests are passed through.
// resolveLimit returns the rate limit for the current request, checking the
// RateLimitConfigFlag for per-team overrides. The flag JSON format is:
//
// {
// "/sandboxes/": {"rate": 50, "burst": 100},
// "/sandboxes/:sandboxID/pause": {"rate": 10, "burst": 20}
// }
//
// The route is the Gin route pattern (c.FullPath()). If no override exists
// for the route (or the flag is null), code defaults are used.
func resolveLimit(ctx context.Context, ff *featureflags.Client, cfg Config, route string) redis_rate.Limit {
rate := cfg.Rate
burst := cfg.Burst

flagValue := ff.JSONFlag(ctx, featureflags.RateLimitConfigFlag)
if !flagValue.IsNull() {
override := flagValue.GetByKey(route)
if !override.IsNull() {
if v := override.GetByKey("rate"); v.IsInt() {
rate = v.IntValue()
}

if v := override.GetByKey("burst"); v.IsInt() {
burst = v.IntValue()
}
}
}

return redis_rate.Limit{
Rate: rate,
Burst: burst,
Period: cfg.Period,
}
}

func Middleware(limiter *redis_rate.Limiter, cfg Config, ff *featureflags.Client) gin.HandlerFunc {
return func(c *gin.Context) {
ctx := c.Request.Context()

// Check feature flag — skip if rate limiting is disabled.
if !ff.BoolFlag(ctx, featureflags.RateLimitEnabledFlag) {
c.Next()

return
}

// Skip unauthenticated requests (they'll be rejected by auth middleware).
team, ok := auth.GetTeamInfo(c)
if !ok {
c.Next()

return
}

teamID := team.ID.String()
route := c.FullPath()
key := redis_utils.CreateKey(rateLimitPefix, teamID, route)

// Resolve per-team limit overrides from feature flag.
limit := resolveLimit(ctx, ff, cfg, route)
Comment thread
dobrac marked this conversation as resolved.

// Build a logger with rate limit context for reuse.
l := logger.L().With(
logger.WithTeamID(teamID),
zap.String("route", route),
zap.Int("rate_limit_rate", limit.Rate),
zap.Int("rate_limit_burst", limit.Burst),
)

res, err := limiter.Allow(ctx, key, limit)
if err != nil {
l.Warn(ctx, "rate limiter Redis error", zap.Error(err))

if cfg.FailOpen {
c.Next()

return
}

c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
"code": http.StatusInternalServerError,
"message": "Rate limiter unavailable",
})

return
}

// Set standard rate limit headers
c.Header("RateLimit-Limit", strconv.Itoa(limit.Burst))
c.Header("RateLimit-Remaining", strconv.Itoa(res.Remaining))
c.Header("RateLimit-Reset", strconv.FormatInt(int64(math.Ceil(res.ResetAfter.Seconds())), 10))

if res.Allowed > 0 {
c.Next()

return
}

// Denied — set Retry-After and return 429.
retryAfterSecs := max(int(res.RetryAfter.Seconds()), 1)
c.Header("Retry-After", strconv.Itoa(retryAfterSecs))

l.Warn(ctx, "rate limit exceeded",
zap.Int("remaining", res.Remaining),
zap.Int("retry_after_s", retryAfterSecs),

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets add also the rate limit fields, ideally set them on the logger early and reuse the instance

)

c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"code": http.StatusTooManyRequests,
"message": "Rate limit exceeded",
})
}
}
Loading
Loading