Skip to content

Commit df2b44e

Browse files
committed
fix(sdk): race condition issues on stub resolver implementation
1 parent f6646b9 commit df2b44e

3 files changed

Lines changed: 166 additions & 2 deletions

File tree

plugin/resolver.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"time"
78

89
"connectrpc.com/connect"
10+
"github.com/sirupsen/logrus"
911
"google.golang.org/protobuf/proto"
1012

1113
resolverv1 "github.com/docker/secrets-engine/pkg/api/resolver/v1"
@@ -16,10 +18,20 @@ import (
1618
var _ resolverv1connect.ResolverServiceHandler = &resolverService{}
1719

1820
type resolverService struct {
19-
resolver secrets.Resolver
21+
resolver secrets.Resolver
22+
setupCompleted chan struct{}
23+
registrationTimeout time.Duration
2024
}
2125

2226
func (r *resolverService) GetSecret(ctx context.Context, c *connect.Request[resolverv1.GetSecretRequest]) (*connect.Response[resolverv1.GetSecretResponse], error) {
27+
logrus.Infof("GetSecret request (ID %q)", c.Msg.GetSecretId())
28+
select {
29+
case <-r.setupCompleted:
30+
case <-ctx.Done():
31+
return nil, connect.NewError(connect.CodeInternal, errors.New("context cancelled while waiting for registration"))
32+
case <-time.After(r.registrationTimeout):
33+
return nil, connect.NewError(connect.CodeDeadlineExceeded, fmt.Errorf("registration incomplete (timeout after %s)", r.registrationTimeout))
34+
}
2335
msgID := c.Msg.GetSecretId()
2436
id, err := secrets.ParseID(msgID)
2537
if err != nil {
@@ -33,6 +45,9 @@ func (r *resolverService) GetSecret(ctx context.Context, c *connect.Request[reso
3345
}
3446
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to get secret %q: %w", msgID, err))
3547
}
48+
if envelope.ID != id {
49+
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("secret ID mismatch: expected %q, got %q", id, envelope.ID))
50+
}
3651
return connect.NewResponse(resolverv1.GetSecretResponse_builder{
3752
SecretId: proto.String(envelope.ID.String()),
3853
SecretValue: proto.String(string(envelope.Value)),

plugin/resolver_test.go

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
package plugin
2+
3+
import (
4+
"context"
5+
"errors"
6+
"testing"
7+
"time"
8+
9+
"connectrpc.com/connect"
10+
"github.com/stretchr/testify/assert"
11+
"google.golang.org/protobuf/proto"
12+
13+
resolverv1 "github.com/docker/secrets-engine/pkg/api/resolver/v1"
14+
"github.com/docker/secrets-engine/pkg/secrets"
15+
)
16+
17+
const (
18+
mockSecretValue = "mockSecretValue"
19+
mockSecretID = secrets.ID("mockSecretID")
20+
)
21+
22+
type mockResolver struct {
23+
t *testing.T
24+
id secrets.ID
25+
value string
26+
err error
27+
}
28+
29+
func newMockResolver(t *testing.T, options ...mockResolverOption) *mockResolver {
30+
resolver := &mockResolver{
31+
t: t,
32+
id: mockSecretID,
33+
value: mockSecretValue,
34+
}
35+
for _, opt := range options {
36+
resolver = opt(resolver)
37+
}
38+
return resolver
39+
}
40+
41+
type mockResolverOption func(*mockResolver) *mockResolver
42+
43+
func withMockResolverID(id secrets.ID) mockResolverOption {
44+
return func(m *mockResolver) *mockResolver {
45+
m.id = id
46+
return m
47+
}
48+
}
49+
50+
func withMockResolverError(err error) mockResolverOption {
51+
return func(m *mockResolver) *mockResolver {
52+
m.err = err
53+
return m
54+
}
55+
}
56+
57+
func (m mockResolver) GetSecret(_ context.Context, request secrets.Request) (secrets.Envelope, error) {
58+
assert.Equal(m.t, m.id, request.ID)
59+
if m.err != nil {
60+
return secrets.Envelope{}, m.err
61+
}
62+
return secrets.Envelope{
63+
ID: m.id,
64+
Value: []byte(m.value),
65+
}, nil
66+
}
67+
68+
func TestResolverService_GetSecret(t *testing.T) {
69+
t.Parallel()
70+
tests := []struct {
71+
name string
72+
test func(t *testing.T)
73+
}{
74+
{
75+
name: "does not resolve secrets before setup completed",
76+
test: func(t *testing.T) {
77+
s := &resolverService{resolver: newMockResolver(t)}
78+
_, err := s.GetSecret(t.Context(), newGetSecretRequest(mockSecretID))
79+
assert.ErrorContains(t, err, "registration incomplete (timeout ")
80+
},
81+
},
82+
{
83+
name: "returns an error if request secret ID is invalid",
84+
test: func(t *testing.T) {
85+
done := make(chan struct{})
86+
close(done)
87+
s := &resolverService{resolver: newMockResolver(t), setupCompleted: done, registrationTimeout: 10 * time.Second}
88+
_, err := s.GetSecret(t.Context(), newGetSecretRequest("/"))
89+
assert.ErrorContains(t, err, "invalid secret ID")
90+
},
91+
},
92+
{
93+
name: "secret not found",
94+
test: func(t *testing.T) {
95+
done := make(chan struct{})
96+
close(done)
97+
s := &resolverService{resolver: newMockResolver(t, withMockResolverError(secrets.ErrNotFound)), setupCompleted: done, registrationTimeout: 10 * time.Second}
98+
_, err := s.GetSecret(t.Context(), newGetSecretRequest(mockSecretID))
99+
assert.ErrorIs(t, err, secrets.ErrNotFound)
100+
},
101+
},
102+
{
103+
name: "error fetching secret",
104+
test: func(t *testing.T) {
105+
done := make(chan struct{})
106+
close(done)
107+
s := &resolverService{resolver: newMockResolver(t, withMockResolverError(errors.New("foo"))), setupCompleted: done, registrationTimeout: 10 * time.Second}
108+
_, err := s.GetSecret(t.Context(), newGetSecretRequest(mockSecretID))
109+
assert.ErrorContains(t, err, "foo")
110+
},
111+
},
112+
{
113+
name: "returns wrong ID",
114+
test: func(t *testing.T) {
115+
done := make(chan struct{})
116+
close(done)
117+
s := &resolverService{resolver: newMockResolver(t, withMockResolverID("wrongID")), setupCompleted: done, registrationTimeout: 10 * time.Second}
118+
_, err := s.GetSecret(t.Context(), newGetSecretRequest(mockSecretID))
119+
assert.ErrorContains(t, err, "secret ID mismatch")
120+
},
121+
},
122+
{
123+
name: "returns secret value",
124+
test: func(t *testing.T) {
125+
done := make(chan struct{})
126+
close(done)
127+
s := &resolverService{resolver: newMockResolver(t), setupCompleted: done, registrationTimeout: 10 * time.Second}
128+
resp, err := s.GetSecret(t.Context(), newGetSecretRequest(mockSecretID))
129+
assert.NoError(t, err)
130+
assert.Equal(t, mockSecretID.String(), resp.Msg.GetSecretId())
131+
assert.Equal(t, mockSecretValue, resp.Msg.GetSecretValue())
132+
},
133+
},
134+
}
135+
for _, tt := range tests {
136+
t.Run(tt.name, func(t *testing.T) {
137+
138+
tt.test(t)
139+
})
140+
}
141+
}
142+
143+
func newGetSecretRequest(secretID secrets.ID) *connect.Request[resolverv1.GetSecretRequest] {
144+
return connect.NewRequest(resolverv1.GetSecretRequest_builder{
145+
SecretId: proto.String(string(secretID)),
146+
}.Build())
147+
}

plugin/stub.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ func setup(ctx context.Context, conn net.Conn, name string, p Plugin, timeout ti
9393
_, _ = w.Write([]byte("ok"))
9494
})
9595
httpMux.Handle(resolverv1connect.NewPluginServiceHandler(&pluginService{p.Shutdown}))
96-
httpMux.Handle(resolverv1connect.NewResolverServiceHandler(&resolverService{p}))
96+
setupCompleted := make(chan struct{})
97+
httpMux.Handle(resolverv1connect.NewResolverServiceHandler(&resolverService{p, setupCompleted, timeout}))
9798
ipc, c, err := ipc.NewPluginIPC(conn, httpMux, func(err error) {
9899
if errors.Is(err, io.EOF) {
99100
logrus.Infof("Plugin runtime stopped, plugin %s is shutting down...", name)
@@ -114,6 +115,7 @@ func setup(ctx context.Context, conn net.Conn, name string, p Plugin, timeout ti
114115
return nil, fmt.Errorf("failed to configure plugin %q: %w", name, err)
115116
}
116117
logrus.Infof("Started plugin %s...", name)
118+
close(setupCompleted)
117119
return ipc, nil
118120
}
119121

0 commit comments

Comments
 (0)