Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions pkg/arke/arke.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,11 @@ func (a Arke) listener() (net.Listener, error) {
}

tlsCfg, err := a.tlsConfig()
if tlsCfg != nil && err == nil {
if err != nil {
lis.Close()
return nil, err
}
if tlsCfg != nil {
lis = tls.NewListener(lis, tlsCfg)
}

Expand Down Expand Up @@ -288,9 +292,15 @@ func (a *Arke) Serve(ctx context.Context) error {

c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
defer signal.Stop(c)
go func() {
for range c {
a.server.Stop()
for {
select {
case <-c:
a.server.Stop()
case <-ctx.Done():
return
}
}
}()

Expand Down Expand Up @@ -338,7 +348,7 @@ func (a *Arke) Serve(ctx context.Context) error {
if a.ratelimiter != nil {
go a.ratelimiter.StartClientCull(ctx)
}
serveErrChan := make(chan error)
serveErrChan := make(chan error, 1)
go func(as *Arke) {
if err := as.mux.Serve(); err != nil {
switch err.(type) { //nolint:gocritic
Expand All @@ -347,6 +357,7 @@ func (a *Arke) Serve(ctx context.Context) error {
return
}
serveErrChan <- err
return
}
serveErrChan <- nil
}(a)
Expand Down
155 changes: 155 additions & 0 deletions pkg/arke/arke_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,16 @@ package arke

import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"log"
"math/big"
"os"
"path/filepath"
"testing"
"time"

Expand All @@ -17,6 +24,41 @@ import (
"google.golang.org/grpc/credentials/insecure"
)

// writeTestCert generates a self-signed cert/key pair and writes them to the
// provided directory. Returns the cert and key file paths.
func writeTestCert(t *testing.T, dir string) (certPath, keyPath string) {
t.Helper()
pub, priv, err := ed25519.GenerateKey(rand.Reader)
assert.Nil(t, err)

tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "test"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, pub, priv)
assert.Nil(t, err)

certPath = filepath.Join(dir, "cert.pem")
keyPath = filepath.Join(dir, "key.pem")

certOut, err := os.Create(certPath)
assert.Nil(t, err)
assert.Nil(t, pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: der}))
assert.Nil(t, certOut.Close())

keyDER, err := x509.MarshalPKCS8PrivateKey(priv)
assert.Nil(t, err)
keyOut, err := os.Create(keyPath)
assert.Nil(t, err)
assert.Nil(t, pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: keyDER}))
assert.Nil(t, keyOut.Close())

return certPath, keyPath
}

func Test_DefaultArkeServer(t *testing.T) {
a := DefaultArkeServer().WithCertFilePath("/cert").WithCertKeyPath("/key").WithTLSSkipVerify(true)

Expand Down Expand Up @@ -164,6 +206,119 @@ func Test_Serve_muxClose(t *testing.T) {
err := a.Serve(ctx)
assert.Nil(t, err)
}
func Test_WithPrometheus(t *testing.T) {
a := &Arke{}
assert.Empty(t, a.interceptors.chainUnary)
assert.Empty(t, a.interceptors.chainStream)
a = a.WithPrometheus()
assert.Len(t, a.interceptors.chainUnary, 1)
assert.Len(t, a.interceptors.chainStream, 1)
}

func Test_WithRateLimit_nil(t *testing.T) {
a := &Arke{}
a = a.WithRateLimit(nil)
assert.Nil(t, a.ratelimiter)
assert.Empty(t, a.interceptors.chainUnary)
assert.Empty(t, a.interceptors.chainStream)
}

func Test_WithRateLimit_invalidParams(t *testing.T) {
tests := []struct {
name string
rlp *RateLimitParameters
}{
{
name: "zero bucket size",
rlp: &RateLimitParameters{BucketSize: 0, RefillInterval: time.Second, MaxAgeStaleClient: time.Second},
},
{
name: "zero refill interval",
rlp: &RateLimitParameters{BucketSize: 1, RefillInterval: 0, MaxAgeStaleClient: time.Second},
},
{
name: "zero max age",
rlp: &RateLimitParameters{BucketSize: 1, RefillInterval: time.Second, MaxAgeStaleClient: 0},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := (&Arke{}).WithRateLimit(tt.rlp)
assert.Nil(t, a.ratelimiter)
assert.Empty(t, a.interceptors.chainUnary)
assert.Empty(t, a.interceptors.chainStream)
})
}
}

func Test_WithRateLimit_valid(t *testing.T) {
rlp := &RateLimitParameters{
BucketSize: 5,
RefillInterval: time.Second,
MaxAgeStaleClient: time.Minute,
Enforced: true,
}
a := (&Arke{}).WithRateLimit(rlp)
assert.NotNil(t, a.ratelimiter)
assert.Len(t, a.interceptors.chainUnary, 1)
assert.Len(t, a.interceptors.chainStream, 1)
}

func Test_tlsConfig_success(t *testing.T) {
certPath, keyPath := writeTestCert(t, t.TempDir())
a := Arke{certFile: certPath, certKey: keyPath}
cfg, err := a.tlsConfig()
assert.Nil(t, err)
assert.NotNil(t, cfg)
assert.Len(t, cfg.Certificates, 1)
assert.Contains(t, cfg.NextProtos, "h2")
assert.Contains(t, cfg.NextProtos, "http/1.1")
}

func Test_listener_TLS(t *testing.T) {
certPath, keyPath := writeTestCert(t, t.TempDir())
a := Arke{port: 0, certFile: certPath, certKey: keyPath}
lis, err := a.listener()
assert.Nil(t, err)
assert.NotNil(t, lis)
assert.Nil(t, lis.Close())
}

func Test_listener_tlsConfigError(t *testing.T) {
a := Arke{port: 0, certFile: "/nonexistent/cert", certKey: "/nonexistent/key"}
lis, err := a.listener()
assert.Nil(t, lis)
assert.NotNil(t, err)
}

func Test_Serve_withRateLimiter(t *testing.T) {
rlp := &RateLimitParameters{
BucketSize: 5,
RefillInterval: time.Second,
MaxAgeStaleClient: time.Minute,
}
a := DefaultArkeServer().WithPort(50062).WithRateLimit(rlp).Build()
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(500 * time.Millisecond)
err := testHealth(50062)
assert.Nil(t, err)
cancel()
}()
err := a.Serve(ctx)
assert.Nil(t, err)
}

func Test_Serve_listenerError(t *testing.T) {
a := DefaultArkeServer().
WithPort(50063).
WithCertFilePath("/nonexistent/cert").
WithCertKeyPath("/nonexistent/key").
Build()
err := a.Serve(context.Background())
assert.NotNil(t, err)
}

func Test_GetRateLimitParameters(t *testing.T) {
tests := []struct {
name string
Expand Down