diff --git a/pkg/arke/arke.go b/pkg/arke/arke.go index 2a88c4e..9716ff1 100644 --- a/pkg/arke/arke.go +++ b/pkg/arke/arke.go @@ -251,7 +251,11 @@ func (a Arke) listener() (net.Listener, error) { } tlsCfg, err := a.tlsConfig() - if tlsCfg != nil && err == nil { + if err != nil { + lis.Close() + return nil, err + } + if tlsCfg != nil { lis = tls.NewListener(lis, tlsCfg) } @@ -288,9 +292,15 @@ func (a *Arke) Serve(ctx context.Context) error { c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt) + defer signal.Stop(c) go func() { - for range c { - a.server.Stop() + for { + select { + case <-c: + a.server.Stop() + case <-ctx.Done(): + return + } } }() @@ -338,7 +348,7 @@ func (a *Arke) Serve(ctx context.Context) error { if a.ratelimiter != nil { go a.ratelimiter.StartClientCull(ctx) } - serveErrChan := make(chan error) + serveErrChan := make(chan error, 1) go func(as *Arke) { if err := as.mux.Serve(); err != nil { switch err.(type) { //nolint:gocritic @@ -347,6 +357,7 @@ func (a *Arke) Serve(ctx context.Context) error { return } serveErrChan <- err + return } serveErrChan <- nil }(a) diff --git a/pkg/arke/arke_test.go b/pkg/arke/arke_test.go index c3d2d67..cec5a69 100644 --- a/pkg/arke/arke_test.go +++ b/pkg/arke/arke_test.go @@ -5,9 +5,16 @@ package arke import ( "context" + "crypto/ed25519" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "fmt" "log" + "math/big" "os" + "path/filepath" "testing" "time" @@ -17,6 +24,41 @@ import ( "google.golang.org/grpc/credentials/insecure" ) +// writeTestCert generates a self-signed cert/key pair and writes them to the +// provided directory. Returns the cert and key file paths. +func writeTestCert(t *testing.T, dir string) (certPath, keyPath string) { + t.Helper() + pub, priv, err := ed25519.GenerateKey(rand.Reader) + assert.Nil(t, err) + + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "test"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + } + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, pub, priv) + assert.Nil(t, err) + + certPath = filepath.Join(dir, "cert.pem") + keyPath = filepath.Join(dir, "key.pem") + + certOut, err := os.Create(certPath) + assert.Nil(t, err) + assert.Nil(t, pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: der})) + assert.Nil(t, certOut.Close()) + + keyDER, err := x509.MarshalPKCS8PrivateKey(priv) + assert.Nil(t, err) + keyOut, err := os.Create(keyPath) + assert.Nil(t, err) + assert.Nil(t, pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: keyDER})) + assert.Nil(t, keyOut.Close()) + + return certPath, keyPath +} + func Test_DefaultArkeServer(t *testing.T) { a := DefaultArkeServer().WithCertFilePath("/cert").WithCertKeyPath("/key").WithTLSSkipVerify(true) @@ -164,6 +206,119 @@ func Test_Serve_muxClose(t *testing.T) { err := a.Serve(ctx) assert.Nil(t, err) } +func Test_WithPrometheus(t *testing.T) { + a := &Arke{} + assert.Empty(t, a.interceptors.chainUnary) + assert.Empty(t, a.interceptors.chainStream) + a = a.WithPrometheus() + assert.Len(t, a.interceptors.chainUnary, 1) + assert.Len(t, a.interceptors.chainStream, 1) +} + +func Test_WithRateLimit_nil(t *testing.T) { + a := &Arke{} + a = a.WithRateLimit(nil) + assert.Nil(t, a.ratelimiter) + assert.Empty(t, a.interceptors.chainUnary) + assert.Empty(t, a.interceptors.chainStream) +} + +func Test_WithRateLimit_invalidParams(t *testing.T) { + tests := []struct { + name string + rlp *RateLimitParameters + }{ + { + name: "zero bucket size", + rlp: &RateLimitParameters{BucketSize: 0, RefillInterval: time.Second, MaxAgeStaleClient: time.Second}, + }, + { + name: "zero refill interval", + rlp: &RateLimitParameters{BucketSize: 1, RefillInterval: 0, MaxAgeStaleClient: time.Second}, + }, + { + name: "zero max age", + rlp: &RateLimitParameters{BucketSize: 1, RefillInterval: time.Second, MaxAgeStaleClient: 0}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := (&Arke{}).WithRateLimit(tt.rlp) + assert.Nil(t, a.ratelimiter) + assert.Empty(t, a.interceptors.chainUnary) + assert.Empty(t, a.interceptors.chainStream) + }) + } +} + +func Test_WithRateLimit_valid(t *testing.T) { + rlp := &RateLimitParameters{ + BucketSize: 5, + RefillInterval: time.Second, + MaxAgeStaleClient: time.Minute, + Enforced: true, + } + a := (&Arke{}).WithRateLimit(rlp) + assert.NotNil(t, a.ratelimiter) + assert.Len(t, a.interceptors.chainUnary, 1) + assert.Len(t, a.interceptors.chainStream, 1) +} + +func Test_tlsConfig_success(t *testing.T) { + certPath, keyPath := writeTestCert(t, t.TempDir()) + a := Arke{certFile: certPath, certKey: keyPath} + cfg, err := a.tlsConfig() + assert.Nil(t, err) + assert.NotNil(t, cfg) + assert.Len(t, cfg.Certificates, 1) + assert.Contains(t, cfg.NextProtos, "h2") + assert.Contains(t, cfg.NextProtos, "http/1.1") +} + +func Test_listener_TLS(t *testing.T) { + certPath, keyPath := writeTestCert(t, t.TempDir()) + a := Arke{port: 0, certFile: certPath, certKey: keyPath} + lis, err := a.listener() + assert.Nil(t, err) + assert.NotNil(t, lis) + assert.Nil(t, lis.Close()) +} + +func Test_listener_tlsConfigError(t *testing.T) { + a := Arke{port: 0, certFile: "/nonexistent/cert", certKey: "/nonexistent/key"} + lis, err := a.listener() + assert.Nil(t, lis) + assert.NotNil(t, err) +} + +func Test_Serve_withRateLimiter(t *testing.T) { + rlp := &RateLimitParameters{ + BucketSize: 5, + RefillInterval: time.Second, + MaxAgeStaleClient: time.Minute, + } + a := DefaultArkeServer().WithPort(50062).WithRateLimit(rlp).Build() + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(500 * time.Millisecond) + err := testHealth(50062) + assert.Nil(t, err) + cancel() + }() + err := a.Serve(ctx) + assert.Nil(t, err) +} + +func Test_Serve_listenerError(t *testing.T) { + a := DefaultArkeServer(). + WithPort(50063). + WithCertFilePath("/nonexistent/cert"). + WithCertKeyPath("/nonexistent/key"). + Build() + err := a.Serve(context.Background()) + assert.NotNil(t, err) +} + func Test_GetRateLimitParameters(t *testing.T) { tests := []struct { name string