diff --git a/client.go b/client.go index ed7fdd3..3d6babb 100644 --- a/client.go +++ b/client.go @@ -61,6 +61,9 @@ func NewClient(opts ...ClientOption) (*Client, error) { PermitWithoutStream: true, }), } + if len(o.UnaryInterceptors) > 0 { + dialOpts = append(dialOpts, grpc.WithChainUnaryInterceptor(o.UnaryInterceptors...)) + } conn, err := grpc.NewClient(o.Address, dialOpts...) if err != nil { diff --git a/client_test.go b/client_test.go index 34a2645..caa277d 100644 --- a/client_test.go +++ b/client_test.go @@ -5,6 +5,8 @@ import ( "errors" "testing" "time" + + "google.golang.org/grpc" ) func TestNewClient_DefaultsApplied(t *testing.T) { @@ -102,6 +104,27 @@ func TestClient_AuthCtx_OmitsHeaderWhenUnset(t *testing.T) { } } +func TestNewClient_UnaryInterceptorOption_AdditiveAndPlumbedToOpts(t *testing.T) { + interceptor := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + return invoker(ctx, method, req, reply, cc, opts...) + } + + c, err := NewClient( + WithAddress("127.0.0.1:1"), + WithInsecure(), + WithUnaryInterceptor(interceptor), + WithUnaryInterceptor(interceptor), + ) + if err != nil { + t.Fatalf("NewClient: %v", err) + } + defer c.Close() + + if got, want := len(c.opts.UnaryInterceptors), 2; got != want { + t.Errorf("UnaryInterceptors len = %d, want %d (each WithUnaryInterceptor must append, not replace)", got, want) + } +} + func TestClient_CallsAfterCloseReturnErrClientClosed(t *testing.T) { c, err := NewClient( WithAddress("127.0.0.1:1"), diff --git a/clientoptions.go b/clientoptions.go index 0a86c91..7d8c2c9 100644 --- a/clientoptions.go +++ b/clientoptions.go @@ -1,6 +1,10 @@ package envector -import "time" +import ( + "time" + + "google.golang.org/grpc" +) const ( defaultDialTimeout = 3 * time.Second @@ -10,13 +14,14 @@ const ( ) type clientOptions struct { - Address string - AccessToken string - Insecure bool - DialTimeout time.Duration - KeepaliveTime time.Duration - KeepaliveTimeout time.Duration - MaxMsgSize int + Address string + AccessToken string + Insecure bool + DialTimeout time.Duration + KeepaliveTime time.Duration + KeepaliveTimeout time.Duration + MaxMsgSize int + UnaryInterceptors []grpc.UnaryClientInterceptor } func defaultClientOptions() clientOptions { @@ -58,3 +63,9 @@ func WithKeepaliveTimeout(d time.Duration) ClientOption { func WithMaxMsgSize(n int) ClientOption { return func(o *clientOptions) { o.MaxMsgSize = n } } + +func WithUnaryInterceptor(i grpc.UnaryClientInterceptor) ClientOption { + return func(o *clientOptions) { + o.UnaryInterceptors = append(o.UnaryInterceptors, i) // interceptors run in the order they were added + } +}