diff --git a/packages/api/go.mod b/packages/api/go.mod index aab03dd260..77878e362a 100644 --- a/packages/api/go.mod +++ b/packages/api/go.mod @@ -28,6 +28,7 @@ require ( github.com/gin-contrib/cors v1.7.6 github.com/gin-contrib/size v1.0.2 github.com/gin-gonic/gin v1.10.1 + github.com/go-redis/redis_rate/v10 v10.0.1 github.com/gogo/status v1.1.1 github.com/golang-jwt/jwt/v5 v5.3.1 github.com/golang/protobuf v1.5.4 @@ -36,6 +37,7 @@ require ( github.com/hashicorp/nomad/api v0.0.0-20251216171439-1dee0671280e github.com/jackc/pgx/v5 v5.7.5 github.com/launchdarkly/go-sdk-common/v3 v3.3.0 + github.com/launchdarkly/go-server-sdk/v7 v7.13.0 github.com/oapi-codegen/gin-middleware v1.0.2 github.com/oapi-codegen/runtime v1.1.1 github.com/orcaman/concurrent-map/v2 v2.0.1 @@ -235,7 +237,6 @@ require ( github.com/launchdarkly/go-sdk-events/v3 v3.5.0 // indirect github.com/launchdarkly/go-semver v1.0.3 // indirect github.com/launchdarkly/go-server-sdk-evaluation/v3 v3.0.1 // indirect - github.com/launchdarkly/go-server-sdk/v7 v7.13.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/lib/pq v1.11.2 // indirect github.com/lufia/plan9stats v0.0.0-20240909124753-873cd0166683 // indirect diff --git a/packages/api/go.sum b/packages/api/go.sum index bc90007e0c..7654c3d910 100644 --- a/packages/api/go.sum +++ b/packages/api/go.sum @@ -369,6 +369,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k= github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= +github.com/go-redis/redis_rate/v10 v10.0.1 h1:calPxi7tVlxojKunJwQ72kwfozdy25RjA0bCj1h0MUo= +github.com/go-redis/redis_rate/v10 v10.0.1/go.mod h1:EMiuO9+cjRkR7UvdvwMO7vbgqJkltQHtwbdIQvaBKIU= github.com/go-resty/resty/v2 v2.17.1 
h1:x3aMpHK1YM9e4va/TMDRlusDDoZiQ+ViDu/WpA6xTM4= github.com/go-resty/resty/v2 v2.17.1/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= diff --git a/packages/api/internal/handlers/store.go b/packages/api/internal/handlers/store.go index 36e7bd79fa..cfe7ce13a9 100644 --- a/packages/api/internal/handlers/store.go +++ b/packages/api/internal/handlers/store.go @@ -31,7 +31,6 @@ import ( authdb "github.com/e2b-dev/infra/packages/db/pkg/auth" "github.com/e2b-dev/infra/packages/db/pkg/pool" "github.com/e2b-dev/infra/packages/shared/pkg/apierrors" - "github.com/e2b-dev/infra/packages/shared/pkg/factories" "github.com/e2b-dev/infra/packages/shared/pkg/featureflags" "github.com/e2b-dev/infra/packages/shared/pkg/logger" "github.com/e2b-dev/infra/packages/shared/pkg/logs/loki" @@ -61,7 +60,7 @@ type APIStore struct { clusters *clusters.Pool } -func NewAPIStore(ctx context.Context, tel *telemetry.Client, config cfg.Config, serviceName string) *APIStore { +func NewAPIStore(ctx context.Context, tel *telemetry.Client, redisClient redis.UniversalClient, featureFlags *featureflags.Client, config cfg.Config) *APIStore { logger.L().Info(ctx, "Initializing API store and services") sqlcDB, err := sqlcdb.NewClient(ctx, config.PostgresConnectionString, pool.WithMaxConnections(config.DBMaxOpenConnections), pool.WithMinIdle(config.DBMinIdleConnections)) @@ -108,15 +107,6 @@ func NewAPIStore(ctx context.Context, tel *telemetry.Client, config cfg.Config, if err != nil { logger.L().Fatal(ctx, "Initializing Nomad client", zap.Error(err)) } - redisClient, err := factories.NewRedisClient(ctx, factories.RedisConfig{ - RedisURL: config.RedisURL, - RedisClusterURL: config.RedisClusterURL, - RedisTLSCABase64: config.RedisTLSCABase64, - PoolSize: config.RedisPoolSize, - }) - if err != nil { - logger.L().Fatal(ctx, "Initializing Redis client", zap.Error(err)) - } queryLogsProvider, err := 
loki.NewLokiQueryProvider(config.LokiURL, config.LokiUser, config.LokiPassword) if err != nil { @@ -128,14 +118,6 @@ func NewAPIStore(ctx context.Context, tel *telemetry.Client, config cfg.Config, logger.L().Fatal(ctx, "initializing edge clusters pool failed", zap.Error(err)) } - featureFlags, err := featureflags.NewClient() - if err != nil { - logger.L().Fatal(ctx, "failed to create feature flags client", zap.Error(err)) - } - - featureFlags.SetServiceName(serviceName) - featureFlags.SetDeploymentName(config.DomainName) - accessTokenGenerator, err := sandbox.NewAccessTokenGenerator(config.SandboxAccessTokenHashSeed) if err != nil { logger.L().Fatal(ctx, "Initializing access token generator failed", zap.Error(err)) @@ -248,12 +230,6 @@ func (a *APIStore) Close(ctx context.Context) error { errs = append(errs, fmt.Errorf("closing snapshot cache: %w", err)) } - if a.redisClient != nil { - if err := a.redisClient.Close(); err != nil { - errs = append(errs, fmt.Errorf("closing redis client: %w", err)) - } - } - return errors.Join(errs...) } diff --git a/packages/api/internal/middleware/ratelimit/ratelimit.go b/packages/api/internal/middleware/ratelimit/ratelimit.go new file mode 100644 index 0000000000..705dc43bf3 --- /dev/null +++ b/packages/api/internal/middleware/ratelimit/ratelimit.go @@ -0,0 +1,167 @@ +package ratelimit + +import ( + "context" + "math" + "net/http" + "strconv" + "time" + + "github.com/gin-gonic/gin" + "github.com/go-redis/redis_rate/v10" + "github.com/redis/go-redis/v9" + "go.uber.org/zap" + + "github.com/e2b-dev/infra/packages/auth/pkg/auth" + "github.com/e2b-dev/infra/packages/shared/pkg/featureflags" + "github.com/e2b-dev/infra/packages/shared/pkg/logger" + redis_utils "github.com/e2b-dev/infra/packages/shared/pkg/redis" +) + +const rateLimitPefix = "ratelimit" + +// Config defines the rate limit parameters. +type Config struct { + // Rate is the number of requests allowed per Period. 
+ Rate int + // Burst is the maximum number of requests allowed in a single burst. + Burst int + // Period is the time window for the rate. + Period time.Duration + // FailOpen allows requests through when Redis is unavailable. + FailOpen bool +} + +// DefaultConfig returns a sensible default: 50 req/s with burst of 100. +func DefaultConfig() Config { + return Config{ + Rate: 50, + Burst: 100, + Period: time.Second, + FailOpen: true, + } +} + +// NewLimiter creates a redis_rate.Limiter from a Redis client. +func NewLimiter(redisClient redis.UniversalClient) *redis_rate.Limiter { + return redis_rate.NewLimiter(redisClient) +} + +// resolveLimit returns the rate limit for the current request, checking the +// RateLimitConfigFlag for per-team overrides. It is called by Middleware, the +// Gin middleware that enforces per-team rate limits using the GCRA algorithm +// backed by Redis (go-redis/redis_rate). Middleware is gated by the +// RateLimitEnabledFlag feature flag for gradual rollout and passes +// unauthenticated requests through. +// The flag JSON format is: +// +// { +// "/sandboxes/": {"rate": 50, "burst": 100}, +// "/sandboxes/:sandboxID/pause": {"rate": 10, "burst": 20} +// } +// +// The route is the Gin route pattern (c.FullPath()). If no override exists +// for the route (or the flag is null), code defaults are used. 
+func resolveLimit(ctx context.Context, ff *featureflags.Client, cfg Config, route string) redis_rate.Limit { + rate := cfg.Rate + burst := cfg.Burst + + flagValue := ff.JSONFlag(ctx, featureflags.RateLimitConfigFlag) + if !flagValue.IsNull() { + override := flagValue.GetByKey(route) + if !override.IsNull() { + if v := override.GetByKey("rate"); v.IsInt() { + rate = v.IntValue() + } + + if v := override.GetByKey("burst"); v.IsInt() { + burst = v.IntValue() + } + } + } + + return redis_rate.Limit{ + Rate: rate, + Burst: burst, + Period: cfg.Period, + } +} + +func Middleware(limiter *redis_rate.Limiter, cfg Config, ff *featureflags.Client) gin.HandlerFunc { + return func(c *gin.Context) { + ctx := c.Request.Context() + + // Check feature flag — skip if rate limiting is disabled. + if !ff.BoolFlag(ctx, featureflags.RateLimitEnabledFlag) { + c.Next() + + return + } + + // Skip unauthenticated requests (they'll be rejected by auth middleware). + team, ok := auth.GetTeamInfo(c) + if !ok { + c.Next() + + return + } + + teamID := team.ID.String() + route := c.FullPath() + key := redis_utils.CreateKey(rateLimitPefix, teamID, route) + + // Resolve per-team limit overrides from feature flag. + limit := resolveLimit(ctx, ff, cfg, route) + + // Build a logger with rate limit context for reuse. 
+ l := logger.L().With( + logger.WithTeamID(teamID), + zap.String("route", route), + zap.Int("rate_limit_rate", limit.Rate), + zap.Int("rate_limit_burst", limit.Burst), + ) + + res, err := limiter.Allow(ctx, key, limit) + if err != nil { + l.Warn(ctx, "rate limiter Redis error", zap.Error(err)) + + if cfg.FailOpen { + c.Next() + + return + } + + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{ + "code": http.StatusInternalServerError, + "message": "Rate limiter unavailable", + }) + + return + } + + // Set standard rate limit headers + c.Header("RateLimit-Limit", strconv.Itoa(limit.Burst)) + c.Header("RateLimit-Remaining", strconv.Itoa(res.Remaining)) + c.Header("RateLimit-Reset", strconv.FormatInt(int64(math.Ceil(res.ResetAfter.Seconds())), 10)) + + if res.Allowed > 0 { + c.Next() + + return + } + + // Denied — set Retry-After and return 429. + retryAfterSecs := max(int(res.RetryAfter.Seconds()), 1) + c.Header("Retry-After", strconv.Itoa(retryAfterSecs)) + + l.Warn(ctx, "rate limit exceeded", + zap.Int("remaining", res.Remaining), + zap.Int("retry_after_s", retryAfterSecs), + ) + + c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{ + "code": http.StatusTooManyRequests, + "message": "Rate limit exceeded", + }) + } +} diff --git a/packages/api/internal/middleware/ratelimit/ratelimit_test.go b/packages/api/internal/middleware/ratelimit/ratelimit_test.go new file mode 100644 index 0000000000..43ee8f1f45 --- /dev/null +++ b/packages/api/internal/middleware/ratelimit/ratelimit_test.go @@ -0,0 +1,346 @@ +package ratelimit + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/go-redis/redis_rate/v10" + "github.com/google/uuid" + "github.com/launchdarkly/go-server-sdk/v7/testhelpers/ldtestdata" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/e2b-dev/infra/packages/auth/pkg/auth" + 
"github.com/e2b-dev/infra/packages/auth/pkg/types" + authqueries "github.com/e2b-dev/infra/packages/db/pkg/auth/queries" + "github.com/e2b-dev/infra/packages/shared/pkg/featureflags" + redis_utils "github.com/e2b-dev/infra/packages/shared/pkg/redis" +) + +func TestMain(m *testing.M) { + gin.SetMode(gin.TestMode) + m.Run() +} + +// newTestFF creates a feature flags client with rate limiting enabled or disabled. +func newTestFF(t *testing.T, enabled bool) *featureflags.Client { + t.Helper() + + td := ldtestdata.DataSource() + td.Update(td.Flag(featureflags.RateLimitEnabledFlag.Key()).VariationForAll(enabled)) + + ff, err := featureflags.NewClientWithDatasource(td) + require.NoError(t, err) + + t.Cleanup(func() { + _ = ff.Close(t.Context()) + }) + + return ff +} + +// doRequest performs a POST /sandboxes/test-sbx/connect. +func doRequest(r *gin.Engine) *httptest.ResponseRecorder { + w := httptest.NewRecorder() + req, _ := http.NewRequestWithContext(context.Background(), http.MethodPost, "/sandboxes/test-sbx/connect", nil) + r.ServeHTTP(w, req) + + return w +} + +// newRouterWithTeam creates a Gin engine that injects a team then applies rate limiting. +func newRouterWithTeam(limiter *redis_rate.Limiter, cfg Config, ff *featureflags.Client, teamID uuid.UUID) *gin.Engine { + r := gin.New() + r.Use(func(c *gin.Context) { + auth.SetTeamInfo(c, &types.Team{ + Team: &authqueries.Team{ID: teamID}, + }) + c.Next() + }) + r.Use(Middleware(limiter, cfg, ff)) + r.POST("/sandboxes/:sandboxID/connect", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"status": "ok"}) + }) + + return r +} + +// --- Unit tests --- + +func TestMiddleware_SkipsWhenFlagDisabled(t *testing.T) { + t.Parallel() + + ff := newTestFF(t, false) + // Unreachable Redis — shouldn't matter since flag is off. 
+ badClient := redis.NewClient(&redis.Options{Addr: "localhost:1"}) + defer badClient.Close() + + limiter := redis_rate.NewLimiter(badClient) + r := newRouterWithTeam(limiter, DefaultConfig(), ff, uuid.New()) + + w := doRequest(r) + assert.Equal(t, http.StatusOK, w.Code) + // No rate limit headers should be set when flag is off. + assert.Empty(t, w.Header().Get("X-RateLimit-Limit")) +} + +func TestMiddleware_SkipsUnauthenticated(t *testing.T) { + t.Parallel() + + ff := newTestFF(t, true) + // Unreachable Redis — shouldn't matter since no team is set. + badClient := redis.NewClient(&redis.Options{Addr: "localhost:1"}) + defer badClient.Close() + + limiter := redis_rate.NewLimiter(badClient) + + r := gin.New() + // No team set — unauthenticated. + r.Use(Middleware(limiter, DefaultConfig(), ff)) + r.POST("/sandboxes/:sandboxID/connect", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"status": "ok"}) + }) + + w := doRequest(r) + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestMiddleware_FailOpen(t *testing.T) { + t.Parallel() + + ff := newTestFF(t, true) + // Unreachable Redis. 
+ badClient := redis.NewClient(&redis.Options{ + Addr: "localhost:1", + DialTimeout: 10 * time.Millisecond, + }) + defer badClient.Close() + + limiter := redis_rate.NewLimiter(badClient) + r := newRouterWithTeam(limiter, Config{ + Rate: 10, + Burst: 10, + Period: time.Second, + FailOpen: true, + }, ff, uuid.New()) + + w := doRequest(r) + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestMiddleware_FailClosed(t *testing.T) { + t.Parallel() + + ff := newTestFF(t, true) + badClient := redis.NewClient(&redis.Options{ + Addr: "localhost:1", + DialTimeout: 10 * time.Millisecond, + }) + defer badClient.Close() + + limiter := redis_rate.NewLimiter(badClient) + r := newRouterWithTeam(limiter, Config{ + Rate: 10, + Burst: 10, + Period: time.Second, + FailOpen: false, + }, ff, uuid.New()) + + w := doRequest(r) + assert.Equal(t, http.StatusInternalServerError, w.Code) +} + +// --- Integration tests (real Redis) --- + +func TestIntegration_AllowedRequestSetsHeaders(t *testing.T) { + t.Parallel() + + if testing.Short() { + t.Skip("skipping integration test") + } + + redisClient := redis_utils.SetupInstance(t) + limiter := redis_rate.NewLimiter(redisClient) + ff := newTestFF(t, true) + + r := newRouterWithTeam(limiter, Config{ + Rate: 10, + Burst: 20, + Period: time.Second, + FailOpen: true, + }, ff, uuid.New()) + + w := doRequest(r) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "20", w.Header().Get("RateLimit-Limit")) + assert.NotEmpty(t, w.Header().Get("RateLimit-Remaining")) + assert.NotEmpty(t, w.Header().Get("RateLimit-Reset")) +} + +func TestIntegration_BurstThenDeny(t *testing.T) { + t.Parallel() + + if testing.Short() { + t.Skip("skipping integration test") + } + + redisClient := redis_utils.SetupInstance(t) + limiter := redis_rate.NewLimiter(redisClient) + ff := newTestFF(t, true) + + r := newRouterWithTeam(limiter, Config{ + Rate: 1, + Burst: 3, + Period: time.Second, + FailOpen: true, + }, ff, uuid.New()) + + // First 3 requests should succeed 
(burst). + for i := range 3 { + w := doRequest(r) + assert.Equal(t, http.StatusOK, w.Code, "request %d should be allowed", i+1) + } + + // 4th should be denied. + w := doRequest(r) + assert.Equal(t, http.StatusTooManyRequests, w.Code) + assert.NotEmpty(t, w.Header().Get("Retry-After")) + + var body struct { + Code int `json:"code"` + Message string `json:"message"` + } + err := json.NewDecoder(w.Body).Decode(&body) + require.NoError(t, err) + assert.Equal(t, http.StatusTooManyRequests, body.Code) + assert.Equal(t, "Rate limit exceeded", body.Message) +} + +func TestIntegration_Refill(t *testing.T) { + t.Parallel() + + if testing.Short() { + t.Skip("skipping integration test") + } + + redisClient := redis_utils.SetupInstance(t) + limiter := redis_rate.NewLimiter(redisClient) + ff := newTestFF(t, true) + + r := newRouterWithTeam(limiter, Config{ + Rate: 10, + Burst: 2, + Period: time.Second, + FailOpen: true, + }, ff, uuid.New()) + + // Exhaust burst. + for range 2 { + w := doRequest(r) + assert.Equal(t, http.StatusOK, w.Code) + } + w := doRequest(r) + assert.Equal(t, http.StatusTooManyRequests, w.Code) + + // Wait for refill (rate=10/s → one token every 100ms). + time.Sleep(200 * time.Millisecond) + + w = doRequest(r) + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestIntegration_IndependentTeams(t *testing.T) { + t.Parallel() + + if testing.Short() { + t.Skip("skipping integration test") + } + + redisClient := redis_utils.SetupInstance(t) + limiter := redis_rate.NewLimiter(redisClient) + ff := newTestFF(t, true) + + cfg := Config{ + Rate: 1, + Burst: 1, + Period: time.Second, + FailOpen: true, + } + + teamA := uuid.New() + teamB := uuid.New() + + rA := newRouterWithTeam(limiter, cfg, ff, teamA) + rB := newRouterWithTeam(limiter, cfg, ff, teamB) + + // Team A uses its quota. + w := doRequest(rA) + assert.Equal(t, http.StatusOK, w.Code) + w = doRequest(rA) + assert.Equal(t, http.StatusTooManyRequests, w.Code) + + // Team B should still have quota. 
+ w = doRequest(rB) + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestIntegration_ConcurrentAccess(t *testing.T) { + t.Parallel() + + if testing.Short() { + t.Skip("skipping integration test") + } + + redisClient := redis_utils.SetupInstance(t) + limiter := redis_rate.NewLimiter(redisClient) + ff := newTestFF(t, true) + + burst := 10 + cfg := Config{ + Rate: 1, + Burst: burst, + Period: time.Minute, // slow refill so burst is the effective limit + FailOpen: true, + } + + r := newRouterWithTeam(limiter, cfg, ff, uuid.New()) + + // Fire 20 concurrent requests; only `burst` should be allowed. + total := 20 + results := make([]int, total) + + var wg sync.WaitGroup + for i := range total { + wg.Add(1) + go func(idx int) { + defer wg.Done() + w := doRequest(r) + results[idx] = w.Code + }(i) + } + wg.Wait() + + allowed := 0 + denied := 0 + for _, code := range results { + switch code { + case http.StatusOK: + allowed++ + case http.StatusTooManyRequests: + denied++ + default: + t.Errorf("unexpected status code: %d", code) + } + } + + assert.Equal(t, burst, allowed, "exactly burst requests should be allowed") + assert.Equal(t, total-burst, denied, "remaining requests should be denied") +} diff --git a/packages/api/main.go b/packages/api/main.go index 9b41c80e45..5db9055a91 100644 --- a/packages/api/main.go +++ b/packages/api/main.go @@ -23,6 +23,7 @@ import ( "github.com/gin-gonic/gin" "github.com/google/uuid" middleware "github.com/oapi-codegen/gin-middleware" + "github.com/redis/go-redis/v9" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -32,10 +33,13 @@ import ( customMiddleware "github.com/e2b-dev/infra/packages/api/internal/middleware" metricsMiddleware "github.com/e2b-dev/infra/packages/api/internal/middleware/otel/metrics" tracingMiddleware "github.com/e2b-dev/infra/packages/api/internal/middleware/otel/tracing" + "github.com/e2b-dev/infra/packages/api/internal/middleware/ratelimit" "github.com/e2b-dev/infra/packages/api/internal/utils" 
"github.com/e2b-dev/infra/packages/auth/pkg/auth" sqlcdb "github.com/e2b-dev/infra/packages/db/client" "github.com/e2b-dev/infra/packages/shared/pkg/env" + "github.com/e2b-dev/infra/packages/shared/pkg/factories" + "github.com/e2b-dev/infra/packages/shared/pkg/featureflags" e2bgrpc "github.com/e2b-dev/infra/packages/shared/pkg/grpc" proxygrpc "github.com/e2b-dev/infra/packages/shared/pkg/grpc/proxy" "github.com/e2b-dev/infra/packages/shared/pkg/logger" @@ -73,7 +77,7 @@ var ( expectedMigrationTimestamp string ) -func NewGinServer(ctx context.Context, config cfg.Config, tel *telemetry.Client, l logger.Logger, apiStore *handlers.APIStore, swagger *openapi3.T, port int) *http.Server { +func NewGinServer(ctx context.Context, config cfg.Config, tel *telemetry.Client, l logger.Logger, apiStore *handlers.APIStore, redisClient redis.UniversalClient, ff *featureflags.Client, swagger *openapi3.T, port int) *http.Server { // Clear out the servers array in the swagger spec, that skips validating // that server names match. We don't know how this thing will be run. swagger.Servers = nil @@ -196,6 +200,11 @@ func NewGinServer(ctx context.Context, config cfg.Config, tel *telemetry.Client, r.Use(customMiddleware.InitLaunchDarklyContext) + // Per-team rate limiting (after auth + LD context, before handlers). + // Applied to all routes registered after this point. Gated by feature flag. 
+ limiter := ratelimit.NewLimiter(redisClient) + r.Use(ratelimit.Middleware(limiter, ratelimit.DefaultConfig(), ff)) //nolint:contextcheck // Gin middleware sets context via c.Request.WithContext + // We now register our store above as the handler for the interface api.RegisterHandlersWithOptions(r, apiStore, api.GinServerOptions{ ErrorHandler: func(c *gin.Context, err error, statusCode int) { @@ -373,10 +382,32 @@ func run() int { cleanup := func() { cleanupOnce.Do(cleanupOp) } defer cleanup() + redisClient, err := factories.NewRedisClient(ctx, factories.RedisConfig{ + RedisURL: config.RedisURL, + RedisClusterURL: config.RedisClusterURL, + RedisTLSCABase64: config.RedisTLSCABase64, + PoolSize: config.RedisPoolSize, + }) + if err != nil { + logger.L().Fatal(ctx, "Initializing Redis client", zap.Error(err)) + } + cleanupFns = append(cleanupFns, func(_ context.Context) error { + return redisClient.Close() + }) + + featureFlags, err := featureflags.NewClient() + if err != nil { + logger.L().Fatal(ctx, "failed to create feature flags client", zap.Error(err)) + } + cleanupFns = append(cleanupFns, featureFlags.Close) + + featureFlags.SetServiceName(serviceName) + featureFlags.SetDeploymentName(config.DomainName) + // Create an instance of our handler which satisfies the generated interface // (use the outer context rather than the signal handling // context so it doesn't exit first.) - apiStore := handlers.NewAPIStore(ctx, tel, config, serviceName) + apiStore := handlers.NewAPIStore(ctx, tel, redisClient, featureFlags, config) cleanupFns = append(cleanupFns, apiStore.Close) grpcAddr := fmt.Sprintf("0.0.0.0:%d", config.APIGrpcPort) @@ -389,7 +420,7 @@ func run() int { proxygrpc.RegisterSandboxServiceServer(grpcServer, handlers.NewSandboxService(apiStore)) // pass the signal context so that handlers know when shutdown is happening. 
- s := NewGinServer(ctx, config, tel, l, apiStore, swagger, port) + s := NewGinServer(ctx, config, tel, l, apiStore, redisClient, featureFlags, swagger, port) // //////////////////////// // diff --git a/packages/shared/pkg/featureflags/flags.go b/packages/shared/pkg/featureflags/flags.go index 6f01025401..fe498080f0 100644 --- a/packages/shared/pkg/featureflags/flags.go +++ b/packages/shared/pkg/featureflags/flags.go @@ -60,6 +60,17 @@ func newJSONFlag(name string, fallback ldvalue.Value) JSONFlag { var CleanNFSCache = newJSONFlag("clean-nfs-cache", ldvalue.Null()) +// RateLimitConfigFlag provides per-team rate limit overrides. +// JSON format: +// +// { +// "/sandboxes/": {"rate": 50, "burst": 100}, +// "/sandboxes/:sandboxID/pause": {"rate": 10, "burst": 20} +// } +// +// When non-null, values override the code defaults. Target specific teams in LaunchDarkly. +var RateLimitConfigFlag = newJSONFlag("rate-limit-config", ldvalue.Null()) + type BoolFlag struct { name string fallback bool @@ -109,6 +120,10 @@ var ( PersistentVolumesFlag = newBoolFlag("can-use-persistent-volumes", env.IsDevelopment()) ExecutionMetricsOnWebhooksFlag = newBoolFlag("execution-metrics-on-webhooks", false) // TODO: Remove NLT 20250315 SandboxLabelBasedSchedulingFlag = newBoolFlag("sandbox-label-based-scheduling", false) + + // RateLimitEnabledFlag gates the per-team rate limiting middleware. + // Evaluated per-request with the team's LD context. Roll out by targeting tiers/teams. + RateLimitEnabledFlag = newBoolFlag("rate-limit-enabled", false) ) type IntFlag struct {