Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/arkd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func startAction(_ *cli.Context) error {
HeartbeatInterval: cfg.HeartbeatInterval,
EnablePprof: cfg.EnablePprof,
MaxConcurrentStreams: cfg.MaxConcurrentStreams,
StreamConnPoolSize: cfg.StreamConnPoolSize,
}

svc, err := grpcservice.NewService(Version, svcConfig, cfg)
Expand Down
9 changes: 9 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ type Config struct {
// SENSITIVE: must never be logged.
IndexerSigningKey string
MaxConcurrentStreams uint32
StreamConnPoolSize uint32

fee ports.FeeManager
repo ports.RepoManager
Expand Down Expand Up @@ -252,6 +253,7 @@ var (
// IndexerSigningKey is a hex-encoded private key. SENSITIVE: never log this value.
IndexerSigningKey = "INDEXER_SIGNING_PRIVKEY" // #nosec G101
MaxConcurrentStreams = "MAX_CONCURRENT_STREAMS"
StreamConnPoolSize = "STREAM_CONN_POOL_SIZE"

defaultDatadir = arklib.AppDataDir("arkd", false)
defaultSessionDuration = 30
Expand Down Expand Up @@ -292,6 +294,8 @@ var (
defaultIndexerExposure = "public"
defaultIndexerAuthTokenExpiry = 300 // 5 minutes in seconds
defaultMaxConcurrentStreams = uint32(1000)
defaultStreamConnPoolSize = uint32(4)
maxStreamConnPoolSize = uint32(64)
defaultMaxOpReturnOuts = uint32(3)
)

Expand Down Expand Up @@ -342,6 +346,7 @@ func LoadConfig() (*Config, error) {
viper.SetDefault(IndexerExposure, defaultIndexerExposure)
viper.SetDefault(IndexerAuthTokenExpiry, defaultIndexerAuthTokenExpiry)
viper.SetDefault(MaxConcurrentStreams, defaultMaxConcurrentStreams)
viper.SetDefault(StreamConnPoolSize, defaultStreamConnPoolSize)
viper.SetDefault(MaxOpReturnOutputs, defaultMaxOpReturnOuts)

if err := initDatadir(); err != nil {
Expand Down Expand Up @@ -458,6 +463,10 @@ func LoadConfig() (*Config, error) {
IndexerAuthTokenExpiry: viper.GetInt64(IndexerAuthTokenExpiry),
IndexerSigningKey: viper.GetString(IndexerSigningKey),
MaxConcurrentStreams: viper.GetUint32(MaxConcurrentStreams),
// Default to 1 or maxStreamConnPoolSize if out of bounds
StreamConnPoolSize: min(
maxStreamConnPoolSize, max(1, viper.GetUint32(StreamConnPoolSize)),
),
// Default to 1 if set to 0
MaxOpReturnOutputs: max(1, viper.GetUint32(MaxOpReturnOutputs)),
}, nil
Expand Down
1 change: 1 addition & 0 deletions internal/interface/grpc/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type Config struct {
HeartbeatInterval int64
EnablePprof bool
MaxConcurrentStreams uint32
StreamConnPoolSize uint32
}

func (c Config) Validate() error {
Expand Down
73 changes: 53 additions & 20 deletions internal/interface/grpc/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ type service struct {
otelShutdown func(context.Context) error
pyroscopeShutdown func() error
unaryConn *grpc.ClientConn
streamConn *grpc.ClientConn
streamConns []*grpc.ClientConn
adminConn *grpc.ClientConn
}

Expand Down Expand Up @@ -210,8 +210,8 @@ func (s *service) stop() {
log.Warn("failed to close unary transport connection")
}
}
if s.streamConn != nil {
if err := s.streamConn.Close(); err != nil {
for _, sc := range s.streamConns {
if err := sc.Close(); err != nil {
log.Warn("failed to close stream transport connection")
}
}
Expand Down Expand Up @@ -375,18 +375,41 @@ func (s *service) newServer(tlsConfig *tls.Config, withPprof bool) error {
if err != nil {
return err
}
streamConn, err := grpc.NewClient(
s.config.gatewayAddress(), gatewayOpts,
)
if err != nil {
if err := unaryConn.Close(); err != nil {
log.Warn("failed to close unary transport connection")
s.unaryConn = unaryConn

// Connection pool for streaming RPCs. Each connection has its own
// HTTP/2 MAX_CONCURRENT_STREAMS budget, so a pool of N multiplies
// the gateway's concurrent stream capacity by N.
poolSize := s.config.StreamConnPoolSize
if poolSize == 0 {
poolSize = 1
}
streamConns := make([]*grpc.ClientConn, 0, poolSize)
for i := uint32(0); i < poolSize; i++ {
sc, err := grpc.NewClient(
s.config.gatewayAddress(), gatewayOpts,
)
if err != nil {
if closeErr := unaryConn.Close(); closeErr != nil {
log.Warn("failed to close unary transport connection")
}
for _, prev := range streamConns {
if closeErr := prev.Close(); closeErr != nil {
log.Warn("failed to close stream transport connection")
}
}
return err
}
return err
streamConns = append(streamConns, sc)
}
s.unaryConn = unaryConn
s.streamConn = streamConn
conn := &splitConn{unary: unaryConn, stream: streamConn}
s.streamConns = streamConns
log.Infof("stream connection pool size: %d", poolSize)

streamPool := make([]grpc.ClientConnInterface, len(streamConns))
for i, sc := range streamConns {
streamPool[i] = sc
}
conn := &splitConn{unary: unaryConn, streamPool: streamPool}

customMatcher := func(key string) (string, bool) {
switch key {
Expand Down Expand Up @@ -461,10 +484,12 @@ func (s *service) newServer(tlsConfig *tls.Config, withPprof bool) error {
log.Warn("failed to close unary transport connection")
}
s.unaryConn = nil
if closeErr := s.streamConn.Close(); closeErr != nil {
log.Warn("failed to close stream transport connection")
for _, sc := range s.streamConns {
if closeErr := sc.Close(); closeErr != nil {
log.Warn("failed to close stream transport connection")
}
}
s.streamConn = nil
s.streamConns = nil
return err
}
s.adminConn = adminConn
Expand Down Expand Up @@ -647,10 +672,13 @@ func isHttpRequest(req *http.Request) bool {
strings.Contains(req.Header.Get("Content-Type"), "application/json")
}

// splitConn routes unary and streaming RPCs to separate grpc.ClientConn
// splitConn routes unary RPCs to a dedicated connection and round-robins
// streaming RPCs across a pool of connections. Each connection carries an
// independent HTTP/2 MAX_CONCURRENT_STREAMS budget.
type splitConn struct {
unary grpc.ClientConnInterface
stream grpc.ClientConnInterface
unary grpc.ClientConnInterface
streamPool []grpc.ClientConnInterface
streamIndex atomic.Uint64
}

func (c *splitConn) Invoke(
Expand All @@ -659,8 +687,13 @@ func (c *splitConn) Invoke(
return c.unary.Invoke(ctx, method, args, reply, opts...)
}

// Called by the meshapi gateway to create new streams. Uses a simple round-robin
// selection strategy to atomically increment the counter and wrap around
// the pool size to select the next connection.
func (c *splitConn) NewStream(
ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption,
) (grpc.ClientStream, error) {
return c.stream.NewStream(ctx, desc, method, opts...)
idx := c.streamIndex.Add(1) - 1
conn := c.streamPool[idx%uint64(len(c.streamPool))]
return conn.NewStream(ctx, desc, method, opts...)
Comment thread
altafan marked this conversation as resolved.
}
131 changes: 131 additions & 0 deletions internal/interface/grpc/service_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package grpcservice

import (
"context"
"sync"
"sync/atomic"
"testing"

"github.com/stretchr/testify/require"
"google.golang.org/grpc"
)

// Mock connection object implementing grpc.ClientConnInterface
// to count unary vs stream call counts.
type mockConn struct {
invokeCalls atomic.Int64
streamCalls atomic.Int64
}

func (m *mockConn) Invoke(
_ context.Context, _ string, _, _ any, _ ...grpc.CallOption,
) error {
m.invokeCalls.Add(1)
return nil
}

func (m *mockConn) NewStream(
_ context.Context, _ *grpc.StreamDesc, _ string, _ ...grpc.CallOption,
) (grpc.ClientStream, error) {
m.streamCalls.Add(1)
return nil, nil
}

func TestSplitConnInvoke(t *testing.T) {
t.Run("routes unary only", func(t *testing.T) {
unary := &mockConn{}
streams := []grpc.ClientConnInterface{&mockConn{}, &mockConn{}}
sc := &splitConn{unary: unary, streamPool: streams}

for i := 0; i < 10; i++ {
require.NoError(t, sc.Invoke(context.Background(), "/test", nil, nil))
}

require.Equal(t, int64(10), unary.invokeCalls.Load())
for i, s := range streams {
mock := s.(*mockConn)
require.Zero(t, mock.invokeCalls.Load(), "stream pool[%d] received invoke calls", i)
require.Zero(t, mock.streamCalls.Load(), "stream pool[%d] received stream calls", i)
}
})
}

func TestSplitConnNewStream(t *testing.T) {
t.Run("round robins across pool", func(t *testing.T) {
unary := &mockConn{}
poolSize := 4
streams := make([]grpc.ClientConnInterface, poolSize)
for i := range streams {
streams[i] = &mockConn{}
}
sc := &splitConn{unary: unary, streamPool: streams}

totalCalls := 100
for i := 0; i < totalCalls; i++ {
_, err := sc.NewStream(context.Background(), nil, "/test")
require.NoError(t, err)
}

expectedPerConn := int64(totalCalls / poolSize)
for i, s := range streams {
require.Equal(t, expectedPerConn, s.(*mockConn).streamCalls.Load(),
"stream pool[%d] call count", i)
}
require.Zero(t, unary.streamCalls.Load(), "unary conn received stream calls")
})

t.Run("pool size one", func(t *testing.T) {
unary := &mockConn{}
single := &mockConn{}
sc := &splitConn{unary: unary, streamPool: []grpc.ClientConnInterface{single}}

for i := 0; i < 50; i++ {
_, err := sc.NewStream(context.Background(), nil, "/test")
require.NoError(t, err)
}

require.Equal(t, int64(50), single.streamCalls.Load())
})

t.Run("concurrent creation safe and evenly distributed", func(t *testing.T) {
unary := &mockConn{}
poolSize := 4
streams := make([]grpc.ClientConnInterface, poolSize)
for i := range streams {
streams[i] = &mockConn{}
}
sc := &splitConn{unary: unary, streamPool: streams}

goroutines := 100
callsPerGoroutine := 100
totalCalls := goroutines * callsPerGoroutine

var wg sync.WaitGroup
wg.Add(goroutines)
for g := 0; g < goroutines; g++ {
go func() {
defer wg.Done()
for i := 0; i < callsPerGoroutine; i++ {
_, err := sc.NewStream(context.Background(), nil, "/test")
require.NoError(t, err)
}
}()
}
wg.Wait()

var totalObserved int64
for _, s := range streams {
totalObserved += s.(*mockConn).streamCalls.Load()
}
require.Equal(t, int64(totalCalls), totalObserved)

// Verify roughly even distribution (allow 20% deviation).
expected := int64(totalCalls / poolSize)
tolerance := expected / 5
for i, s := range streams {
got := s.(*mockConn).streamCalls.Load()
require.InDelta(t, expected, got, float64(tolerance),
"stream pool[%d] distribution", i)
}
})
}
Loading