diff --git a/go.work.sum b/go.work.sum index 59865b4..21cabc8 100644 --- a/go.work.sum +++ b/go.work.sum @@ -11,6 +11,7 @@ github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdko github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8= github.com/VictoriaMetrics/fastcache v1.6.0/go.mod h1:0qHz5QP0GMX4pfmMA/zt5RgfNuXJrTP0zS7DqpHGGTw= github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b/go.mod h1:1KcenG0jGWcpt8ov532z81sp/kMMUG485J2InIOyADM= +github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/aws/aws-sdk-go-v2 v1.2.0/go.mod h1:zEQs02YRBw1DjK0PoJv3ygDYOFTre1ejlJWl8FwAuQo= github.com/aws/aws-sdk-go-v2/config v1.1.1/go.mod h1:0XsVy9lBI/BCXm+2Tuvt39YmdHwS5unDQmxZOYe8F5Y= github.com/aws/aws-sdk-go-v2/credentials v1.1.1/go.mod h1:mM2iIjwl7LULWtS6JCACyInboHirisUUdkBPoTHMOUo= @@ -39,6 +40,7 @@ github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:ma github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/crate-crypto/go-ipa v0.0.0-20220523130400-f11357ae11c7/go.mod h1:gFnFS95y8HstDP6P9pPwzrxOOC5TRDkwbM+ao15ChAI= github.com/deepmap/oapi-codegen v1.8.2/go.mod h1:YLgSKSDv/bZQB7N4ws6luhozi3cEdRktEqrX88CvjIw= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dlclark/regexp2 v1.7.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/docker/docker v1.6.2/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/dop251/goja v0.0.0-20230122112309-96b1610dd4f7/go.mod h1:yRkwfj0CBpOGre+TwBsqPV0IH0Pk73e4PXJOeNDboGs= @@ -72,6 +74,7 @@ github.com/jedisct1/go-minisign v0.0.0-20190909160543-45766022959e/go.mod h1:G1C github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/karalabe/usb v0.0.2/go.mod h1:Od972xHfMJowv7NGVDiWVxk2zxnWgjLlJzE+F4F7AGU= github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4= +github.com/klauspost/compress v1.16.3/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= @@ -88,6 +91,7 @@ github.com/prometheus/client_golang v1.14.0/go.mod h1:8vpkKitgIVNcqrRBWh1C4TIUQg github.com/prometheus/client_model v0.3.0/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w= github.com/prometheus/common v0.39.0/go.mod h1:6XBZ7lYdLCbkAVhwRsWTZn+IN5AB9F/NXd5w0BbEX0Y= github.com/prometheus/procfs v0.9.0/go.mod h1:+pB4zwohETzFnmlpe6yd2lSc+0/46IYZRB/chUwxUZY= +github.com/redis/go-redis/v9 v9.0.4/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -99,6 +103,8 @@ github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45 github.com/tyler-smith/go-bip39 v1.1.0/go.mod h1:gUYDtqQw1JS3ZJ8UWVcGTGqqr6YIN3CWg+kkNaLt55U= github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI= github.com/urfave/cli/v2 v2.17.2-0.20221006022127-8f469abc00aa/go.mod h1:1CNUng3PtjQMtRzJO4FMXBQvkGtuYRxxiR9xMa7jMwI= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.47.0/go.mod h1:k2zXd82h/7UZc3VOdJ2WaUqt1uZ/XpXAfE9i+HBC3lA= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= diff --git a/node-registrar/cmds/main.go b/node-registrar/cmds/main.go index 9a5d0cb..59651cf 100644 --- a/node-registrar/cmds/main.go +++ b/node-registrar/cmds/main.go @@ -6,6 +6,7 @@ import ( "net" "os" "strings" + "time" "github.com/pkg/errors" "github.com/rs/zerolog" @@ -23,6 +24,11 @@ type flags struct { serverPort uint network string adminTwinID uint64 + + // Rate limiter configuration + rateLimitEnabled bool + rateLimitRequests uint64 + rateLimitPeriodSecs uint64 } // These variables are set during build time using ldflags @@ -62,6 +68,10 @@ func Run() error { flag.StringVar(&f.network, "network", "dev", "the registrar network") flag.Uint64Var(&f.adminTwinID, "admin-twin-id", 1, "admin twin ID") + flag.BoolVar(&f.rateLimitEnabled, "rate-limit-enabled", true, "enable rate limiting") + flag.Uint64Var(&f.rateLimitRequests, "rate-limit-requests", 100, "number of requests allowed per period") + flag.Uint64Var(&f.rateLimitPeriodSecs, "rate-limit-period", 1, "rate limit period in seconds") + flag.Parse() f.SqlLogLevel = logger.LogLevel(sqlLogLevel) @@ -92,7 +102,14 @@ func Run() error { } }() - s := server.NewServer(db, f.network, f.adminTwinID) + // Create rate limiter configuration from flags + rateLimiterConfig := server.RateLimiterConfig{ + Enabled: f.rateLimitEnabled, + Requests: f.rateLimitRequests, + Period: time.Duration(f.rateLimitPeriodSecs) * time.Second, + } + + s := server.NewServer(db, f.network, f.adminTwinID, rateLimiterConfig) log.Info().Msgf("server is running on port :%d", f.serverPort) @@ -120,6 +137,15 @@ func (f flags) validate() error { return errors.Errorf("invalid admin twin id %d, admin twin id should not be 0", f.adminTwinID) } + // Validate rate limiter configuration + if f.rateLimitRequests == 0 { + return errors.Errorf("invalid rate limit requests %d, rate limit requests should be greater than 0", f.rateLimitRequests) + } + + if f.rateLimitPeriodSecs == 0 { + return errors.Errorf("invalid rate limit period %d, rate limit period should be greater than 0", f.rateLimitPeriodSecs) + } + if _, err := net.LookupHost(f.domain); err != nil { return errors.Wrapf(err, "invalid domain %s", f.domain) } diff --git a/node-registrar/go.mod b/node-registrar/go.mod index 693faef..8b334f0 100644 --- a/node-registrar/go.mod +++ b/node-registrar/go.mod @@ -61,6 +61,7 @@ require ( github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect + github.com/ulule/limiter/v3 v3.11.2 // indirect golang.org/x/arch v0.12.0 // indirect golang.org/x/crypto v0.19.0 // indirect golang.org/x/net v0.19.0 // indirect diff --git a/node-registrar/go.sum b/node-registrar/go.sum index 9025429..c2c7e28 100644 --- a/node-registrar/go.sum +++ b/node-registrar/go.sum @@ -65,6 +65,7 @@ github.com/goccy/go-json v0.10.4/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PU github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/gtank/merlin v0.1.1-0.20191105220539-8318aed1a79f/go.mod h1:T86dnYJhcGOh5BjZFCJWTDeTK7XW8uE+E21Cy/bIQ+s= github.com/gtank/merlin v0.1.1 h1:eQ90iG7K9pOhtereWsmyRJ6RAwcP4tHTDBHXNg+u5is= @@ -156,6 +157,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/ulule/limiter/v3 v3.11.2 h1:P4yOrxoEMJbOTfRJR2OzjL90oflzYPPmWg+dvwN2tHA= +github.com/ulule/limiter/v3 v3.11.2/go.mod h1:QG5GnFOCV+k7lrL5Y8kgEeeflPH3+Cviqlqa8SVSQxI= github.com/vedhavyas/go-subkey/v2 v2.0.0 h1:LemDIsrVtRSOkp0FA8HxP6ynfKjeOj3BY2U9UNfeDMA= github.com/vedhavyas/go-subkey/v2 v2.0.0/go.mod h1:95aZ+XDCWAUUynjlmi7BtPExjXgXxByE0WfBwbmIRH4= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= @@ -210,6 +213,7 @@ golang.org/x/tools v0.16.1/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk= google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/node-registrar/pkg/server/middlewares.go b/node-registrar/pkg/server/middlewares.go index 0c65ffc..4e9f558 100644 --- a/node-registrar/pkg/server/middlewares.go +++ b/node-registrar/pkg/server/middlewares.go @@ -13,6 +13,9 @@ import ( "github.com/gin-gonic/gin" "github.com/rs/zerolog/log" "github.com/threefoldtech/tfgrid4-sdk-go/node-registrar/pkg/db" + "github.com/ulule/limiter/v3" + ginlimiter "github.com/ulule/limiter/v3/drivers/middleware/gin" + "github.com/ulule/limiter/v3/drivers/store/memory" ) // twinKeyID is where the twin key is stored @@ -23,6 +26,31 @@ const ( ChallengeValidity = 1 * time.Minute ) +type RateLimiterConfig struct { + Enabled bool + Requests uint64 + Period time.Duration +} + +// RateLimitMiddleware creates a rate limiting middleware with configuration +func (s *Server) RateLimitMiddleware(config RateLimiterConfig) gin.HandlerFunc { + if !config.Enabled { + return gin.HandlerFunc(func(c *gin.Context) { + c.Next() + }) + } + + store := memory.NewStore() + + rate := limiter.Rate{ + Period: config.Period, + Limit: (int64)(config.Requests), + } + + instance := limiter.New(store, rate) + return ginlimiter.NewMiddleware(instance) +} + // AuthMiddleware is a middleware function that authenticates incoming requests based on the X-Auth header. // It verifies the challenge and signature provided in the header against the account's public key stored in the database. // If the authentication fails, it aborts the request with an appropriate error status and message. diff --git a/node-registrar/pkg/server/routes.go b/node-registrar/pkg/server/routes.go index 0e4b547..ab7f245 100644 --- a/node-registrar/pkg/server/routes.go +++ b/node-registrar/pkg/server/routes.go @@ -12,6 +12,7 @@ import ( ) func (s *Server) SetupRoutes() { + s.router.Use(cors.New(cors.Config{ AllowOrigins: []string{"*"}, AllowMethods: []string{"POST", "OPTIONS", "GET", "PUT", "DELETE"}, @@ -22,7 +23,7 @@ func (s *Server) SetupRoutes() { })) s.router.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler)) - + s.router.Use(s.RateLimitMiddleware(s.rateLimiterConfig)) s.registerRoutes(s.router.Group("/api/v1")) s.registerRoutes(s.router.Group("/v1")) } diff --git a/node-registrar/pkg/server/server.go b/node-registrar/pkg/server/server.go index c7a2ca0..fc7370b 100644 --- a/node-registrar/pkg/server/server.go +++ b/node-registrar/pkg/server/server.go @@ -16,17 +16,18 @@ import ( ) type Server struct { - router *gin.Engine - db db.Database - network string - adminTwinID uint64 + router *gin.Engine + db db.Database + network string + adminTwinID uint64 + rateLimiterConfig RateLimiterConfig } -func NewServer(db db.Database, network string, adminTwinID uint64) Server { +func NewServer(db db.Database, network string, adminTwinID uint64, rateLimiterConfig RateLimiterConfig) Server { router := gin.Default() router.RedirectTrailingSlash = true - server := Server{router, db, network, adminTwinID} + server := Server{router, db, network, adminTwinID, rateLimiterConfig} server.SetupRoutes() return server