Skip to content

Commit a80f073

Browse files
committed
fix(sdk): race condition issues on stub resolver implementation
1 parent f6646b9 commit a80f073

3 files changed

Lines changed: 168 additions & 2 deletions

File tree

plugin/resolver.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"time"
78

89
"connectrpc.com/connect"
10+
"github.com/sirupsen/logrus"
911
"google.golang.org/protobuf/proto"
1012

1113
resolverv1 "github.com/docker/secrets-engine/pkg/api/resolver/v1"
@@ -16,10 +18,20 @@ import (
1618
var _ resolverv1connect.ResolverServiceHandler = &resolverService{}
1719

1820
type resolverService struct {
19-
resolver secrets.Resolver
21+
resolver secrets.Resolver
22+
setupCompleted chan struct{}
23+
registrationTimeout time.Duration
2024
}
2125

2226
func (r *resolverService) GetSecret(ctx context.Context, c *connect.Request[resolverv1.GetSecretRequest]) (*connect.Response[resolverv1.GetSecretResponse], error) {
27+
logrus.Debugf("GetSecret request (ID %q)", c.Msg.GetSecretId())
28+
select {
29+
case <-r.setupCompleted:
30+
case <-ctx.Done():
31+
return nil, connect.NewError(connect.CodeInternal, errors.New("context cancelled while waiting for registration"))
32+
case <-time.After(r.registrationTimeout):
33+
return nil, connect.NewError(connect.CodeDeadlineExceeded, fmt.Errorf("registration incomplete (timeout after %s)", r.registrationTimeout))
34+
}
2335
msgID := c.Msg.GetSecretId()
2436
id, err := secrets.ParseID(msgID)
2537
if err != nil {
@@ -33,6 +45,9 @@ func (r *resolverService) GetSecret(ctx context.Context, c *connect.Request[reso
3345
}
3446
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to get secret %q: %w", msgID, err))
3547
}
48+
if envelope.ID != id {
49+
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("secret ID mismatch: expected %q, got %q", id, envelope.ID))
50+
}
3651
return connect.NewResponse(resolverv1.GetSecretResponse_builder{
3752
SecretId: proto.String(envelope.ID.String()),
3853
SecretValue: proto.String(string(envelope.Value)),

plugin/resolver_test.go

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
package plugin
2+
3+
import (
4+
"context"
5+
"errors"
6+
"testing"
7+
"time"
8+
9+
"connectrpc.com/connect"
10+
"github.com/stretchr/testify/assert"
11+
"google.golang.org/protobuf/proto"
12+
13+
resolverv1 "github.com/docker/secrets-engine/pkg/api/resolver/v1"
14+
"github.com/docker/secrets-engine/pkg/secrets"
15+
)
16+
17+
const (
18+
mockSecretValue = "mockSecretValue"
19+
mockSecretID = secrets.ID("mockSecretID")
20+
)
21+
22+
type mockResolver struct {
23+
t *testing.T
24+
id secrets.ID
25+
value string
26+
err error
27+
}
28+
29+
func newMockResolver(t *testing.T, options ...mockResolverOption) *mockResolver {
30+
resolver := &mockResolver{
31+
t: t,
32+
id: mockSecretID,
33+
value: mockSecretValue,
34+
}
35+
for _, opt := range options {
36+
resolver = opt(resolver)
37+
}
38+
return resolver
39+
}
40+
41+
type mockResolverOption func(*mockResolver) *mockResolver
42+
43+
func withMockResolverID(id secrets.ID) mockResolverOption {
44+
return func(m *mockResolver) *mockResolver {
45+
m.id = id
46+
return m
47+
}
48+
}
49+
50+
func withMockResolverError(err error) mockResolverOption {
51+
return func(m *mockResolver) *mockResolver {
52+
m.err = err
53+
return m
54+
}
55+
}
56+
57+
func (m mockResolver) GetSecret(_ context.Context, request secrets.Request) (secrets.Envelope, error) {
58+
if request.ID != m.id {
59+
return secrets.Envelope{}, errors.New("unexpected secret ID")
60+
}
61+
if m.err != nil {
62+
return secrets.Envelope{}, m.err
63+
}
64+
return secrets.Envelope{
65+
ID: m.id,
66+
Value: []byte(m.value),
67+
}, nil
68+
}
69+
70+
func TestResolverService_GetSecret(t *testing.T) {
71+
t.Parallel()
72+
tests := []struct {
73+
name string
74+
test func(t *testing.T)
75+
}{
76+
{
77+
name: "does not resolve secrets before setup completed",
78+
test: func(t *testing.T) {
79+
s := &resolverService{resolver: newMockResolver(t)}
80+
_, err := s.GetSecret(t.Context(), newGetSecretRequest(mockSecretID))
81+
assert.ErrorContains(t, err, "registration incomplete (timeout ")
82+
},
83+
},
84+
{
85+
name: "returns an error if request secret ID is invalid",
86+
test: func(t *testing.T) {
87+
done := make(chan struct{})
88+
close(done)
89+
s := &resolverService{resolver: newMockResolver(t), setupCompleted: done, registrationTimeout: 10 * time.Second}
90+
_, err := s.GetSecret(t.Context(), newGetSecretRequest("/"))
91+
assert.ErrorContains(t, err, "invalid secret ID")
92+
},
93+
},
94+
{
95+
name: "secret not found",
96+
test: func(t *testing.T) {
97+
done := make(chan struct{})
98+
close(done)
99+
s := &resolverService{resolver: newMockResolver(t, withMockResolverError(secrets.ErrNotFound)), setupCompleted: done, registrationTimeout: 10 * time.Second}
100+
_, err := s.GetSecret(t.Context(), newGetSecretRequest(mockSecretID))
101+
assert.ErrorIs(t, err, secrets.ErrNotFound)
102+
},
103+
},
104+
{
105+
name: "error fetching secret",
106+
test: func(t *testing.T) {
107+
done := make(chan struct{})
108+
close(done)
109+
s := &resolverService{resolver: newMockResolver(t, withMockResolverError(errors.New("foo"))), setupCompleted: done, registrationTimeout: 10 * time.Second}
110+
_, err := s.GetSecret(t.Context(), newGetSecretRequest(mockSecretID))
111+
assert.ErrorContains(t, err, "foo")
112+
},
113+
},
114+
{
115+
name: "returns wrong ID",
116+
test: func(t *testing.T) {
117+
done := make(chan struct{})
118+
close(done)
119+
s := &resolverService{resolver: newMockResolver(t, withMockResolverID("wrongID")), setupCompleted: done, registrationTimeout: 10 * time.Second}
120+
_, err := s.GetSecret(t.Context(), newGetSecretRequest(mockSecretID))
121+
assert.ErrorContains(t, err, "secret ID mismatch")
122+
},
123+
},
124+
{
125+
name: "returns secret value",
126+
test: func(t *testing.T) {
127+
done := make(chan struct{})
128+
close(done)
129+
s := &resolverService{resolver: newMockResolver(t), setupCompleted: done, registrationTimeout: 10 * time.Second}
130+
resp, err := s.GetSecret(t.Context(), newGetSecretRequest(mockSecretID))
131+
assert.NoError(t, err)
132+
assert.Equal(t, mockSecretID.String(), resp.Msg.GetSecretId())
133+
assert.Equal(t, mockSecretValue, resp.Msg.GetSecretValue())
134+
},
135+
},
136+
}
137+
for _, tt := range tests {
138+
t.Run(tt.name, func(t *testing.T) {
139+
140+
tt.test(t)
141+
})
142+
}
143+
}
144+
145+
func newGetSecretRequest(secretID secrets.ID) *connect.Request[resolverv1.GetSecretRequest] {
146+
return connect.NewRequest(resolverv1.GetSecretRequest_builder{
147+
SecretId: proto.String(string(secretID)),
148+
}.Build())
149+
}

plugin/stub.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ func setup(ctx context.Context, conn net.Conn, name string, p Plugin, timeout ti
9393
_, _ = w.Write([]byte("ok"))
9494
})
9595
httpMux.Handle(resolverv1connect.NewPluginServiceHandler(&pluginService{p.Shutdown}))
96-
httpMux.Handle(resolverv1connect.NewResolverServiceHandler(&resolverService{p}))
96+
setupCompleted := make(chan struct{})
97+
httpMux.Handle(resolverv1connect.NewResolverServiceHandler(&resolverService{p, setupCompleted, timeout}))
9798
ipc, c, err := ipc.NewPluginIPC(conn, httpMux, func(err error) {
9899
if errors.Is(err, io.EOF) {
99100
logrus.Infof("Plugin runtime stopped, plugin %s is shutting down...", name)
@@ -114,6 +115,7 @@ func setup(ctx context.Context, conn net.Conn, name string, p Plugin, timeout ti
114115
return nil, fmt.Errorf("failed to configure plugin %q: %w", name, err)
115116
}
116117
logrus.Infof("Started plugin %s...", name)
118+
close(setupCompleted)
117119
return ipc, nil
118120
}
119121

0 commit comments

Comments
 (0)