Skip to content

Commit 64d588a

Browse files
committed
fix(sdk): wait for streams to close before shutting down server
1 parent b1e15ec commit 64d588a

5 files changed

Lines changed: 144 additions & 36 deletions

File tree

internal/ipc/ipc.go

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"time"
1212

1313
"github.com/hashicorp/yamux"
14+
"github.com/sirupsen/logrus"
1415
)
1516

1617
const (
@@ -56,7 +57,7 @@ type ipcServer struct {
5657
err error
5758
}
5859

59-
func newIpcServer(l net.Listener, handler http.Handler, onClose func(error) error) *ipcServer {
60+
func newIpcServer(l net.Listener, handler http.Handler, afterClose func(error) error) *ipcServer {
6061
result := &ipcServer{
6162
done: make(chan struct{}),
6263
server: &http.Server{
@@ -68,7 +69,7 @@ func newIpcServer(l net.Listener, handler http.Handler, onClose func(error) erro
6869
if errors.Is(err, http.ErrServerClosed) { // not an error, client closed the connection
6970
err = nil
7071
}
71-
result.err = errors.Join(filterEOF(err), onClose(err)) // EOF: only forward to the onClose handler, but filter out internal forwarding
72+
result.err = errors.Join(filterEOF(err), afterClose(err)) // EOF: only forward to the afterClose handler, but filter out internal forwarding
7273
close(result.done)
7374
}()
7475
return result
@@ -99,16 +100,36 @@ func newMuxedIPC(session *yamux.Session, handler http.Handler, onClose func(erro
99100
}
100101
return session.Close()
101102
})
103+
c := createYamuxedClient(session)
102104
return &ipcImpl{
103105
server: server,
104106
teardown: sync.OnceValue(func() error {
105-
ctx, cancel := context.WithTimeout(context.Background(), cfg.shutdownTimeout)
106-
defer cancel()
107-
err := server.server.Shutdown(ctx)
107+
_ = session.GoAway()
108+
c.CloseIdleConnections()
109+
waitForClientToDisconnect(session, cfg.shutdownTimeout)
110+
err := server.server.Close()
108111
<-server.done
109-
return errors.Join(err, session.Close(), server.err)
112+
return errors.Join(err, server.err)
110113
}),
111-
}, createYamuxedClient(session)
114+
}, c
115+
}
116+
117+
func waitForClientToDisconnect(s *yamux.Session, t time.Duration) {
118+
timeout := time.After(t)
119+
for {
120+
select {
121+
case <-time.After(50 * time.Millisecond):
122+
case <-timeout:
123+
logrus.Debugf("Timeout expired but %d streams still open, shutting down server...", s.NumStreams())
124+
return
125+
}
126+
streams := s.NumStreams()
127+
// 1 stream is the control stream (todo: verify)
128+
// TODO: https://github.com/docker/secrets-engine/issues/71
129+
if streams <= 1 {
130+
return
131+
}
132+
}
112133
}
113134

114135
func (i *ipcImpl) Close() error {

pkg/adaptation/plugin_test.go

Lines changed: 108 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,37 @@ const (
2121
mockSecretID = secrets.ID("mockSecretID")
2222
)
2323

24-
var (
25-
mockPlugin = &mockedPlugin{
26-
pattern: "*",
27-
id: mockSecretID,
28-
}
29-
)
30-
3124
type mockedPlugin struct {
3225
pattern string
3326
id secrets.ID
3427
configureErr error
3528
}
3629

30+
type MockedPluginOption func(*mockedPlugin)
31+
32+
func newMockedPlugin(options ...MockedPluginOption) *mockedPlugin {
33+
m := &mockedPlugin{
34+
pattern: "*",
35+
id: mockSecretID,
36+
}
37+
for _, opt := range options {
38+
opt(m)
39+
}
40+
return m
41+
}
42+
43+
func WithPattern(pattern string) MockedPluginOption {
44+
return func(mp *mockedPlugin) {
45+
mp.pattern = pattern
46+
}
47+
}
48+
49+
func WithID(id secrets.ID) MockedPluginOption {
50+
return func(mp *mockedPlugin) {
51+
mp.id = id
52+
}
53+
}
54+
3755
func (m mockedPlugin) GetSecret(context.Context, secrets.Request) (secrets.Envelope, error) {
3856
return secrets.Envelope{ID: m.id, Value: []byte(mockSecretValue)}, nil
3957
}
@@ -59,27 +77,90 @@ func Test_newExternalPlugin(t *testing.T) {
5977
}{
6078
{
6179
name: "create external plugin",
80+
test: func(t *testing.T, l net.Listener, conn net.Conn) {
81+
doneRuntime := make(chan struct{})
82+
go func() {
83+
p := mockExternalPluginRuntime(t, l)
84+
e, err := p.GetSecret(t.Context(), secrets.Request{ID: mockSecretID})
85+
assert.NoError(t, err)
86+
assert.Equal(t, mockSecretValue, string(e.Value))
87+
assert.NoError(t, p.close())
88+
close(doneRuntime)
89+
}()
90+
91+
s, err := p.New(newMockedPlugin(), p.WithPluginName("my-plugin"), p.WithConnection(conn))
92+
require.NoError(t, err)
93+
assert.NoError(t, s.Run(context.Background()))
94+
<-doneRuntime
95+
},
96+
},
97+
{
98+
name: "plugin returns error on GetSecret",
99+
test: func(t *testing.T, l net.Listener, conn net.Conn) {
100+
doneRuntime := make(chan struct{})
101+
go func() {
102+
p := mockExternalPluginRuntime(t, l)
103+
_, err := p.GetSecret(t.Context(), secrets.Request{ID: mockSecretID})
104+
assert.ErrorContains(t, err, "id mismatch")
105+
assert.NoError(t, p.close())
106+
close(doneRuntime)
107+
}()
108+
109+
s, err := p.New(newMockedPlugin(WithID("rewrite-id")), p.WithPluginName("my-plugin"), p.WithConnection(conn))
110+
require.NoError(t, err)
111+
assert.NoError(t, s.Run(context.Background()))
112+
<-doneRuntime
113+
},
114+
},
115+
{
116+
name: "cancelling plugin.run() shuts down the runtime",
117+
test: func(t *testing.T, l net.Listener, conn net.Conn) {
118+
doneRuntime := make(chan struct{})
119+
donePlugin := make(chan struct{})
120+
downRuntime := make(chan struct{})
121+
go func() {
122+
p := mockExternalPluginRuntime(t, l)
123+
e, err := p.GetSecret(t.Context(), secrets.Request{ID: mockSecretID})
124+
assert.NoError(t, err)
125+
assert.Equal(t, mockSecretValue, string(e.Value))
126+
close(doneRuntime)
127+
<-donePlugin
128+
assert.NoError(t, p.close())
129+
close(downRuntime)
130+
}()
131+
132+
s, err := p.New(newMockedPlugin(), p.WithPluginName("my-plugin"), p.WithConnection(conn))
133+
require.NoError(t, err)
134+
ctx, cancel := context.WithCancel(t.Context())
135+
136+
go func() {
137+
assert.NoError(t, s.Run(ctx))
138+
close(donePlugin)
139+
}()
140+
<-doneRuntime
141+
cancel()
142+
<-downRuntime
143+
},
144+
},
145+
{
146+
name: "plugins with invalid patterns are rejected",
62147
test: func(t *testing.T, l net.Listener, conn net.Conn) {
63148
doneRuntime := make(chan struct{})
64149
go func() {
65150
conn, err := l.Accept()
66151
require.NoError(t, err)
67152

68-
p, err := newExternalPlugin(conn, setupValidator{
153+
_, err = newExternalPlugin(conn, setupValidator{
69154
out: pluginCfgOut{engineName: "test-engine", engineVersion: "1.0.0", requestTimeout: 30 * time.Second},
70155
acceptPattern: func(secrets.Pattern) error { return nil },
71156
})
72-
require.NoError(t, err)
73-
e, err := p.GetSecret(t.Context(), secrets.Request{ID: mockSecretID})
74-
assert.NoError(t, err)
75-
assert.Equal(t, mockSecretValue, string(e.Value))
76-
assert.NoError(t, p.close())
157+
assert.ErrorContains(t, err, "invalid pattern")
77158
close(doneRuntime)
78159
}()
79160

80-
s, err := p.New(mockPlugin, p.WithPluginName("my-plugin"), p.WithConnection(conn))
161+
s, err := p.New(newMockedPlugin(WithPattern("a*a")), p.WithPluginName("my-plugin"), p.WithConnection(conn))
81162
require.NoError(t, err)
82-
assert.NoError(t, s.Run(context.Background()))
163+
assert.ErrorContains(t, s.Run(t.Context()), "invalid pattern")
83164
<-doneRuntime
84165
},
85166
},
@@ -101,3 +182,15 @@ func Test_newExternalPlugin(t *testing.T) {
101182
})
102183
}
103184
}
185+
186+
func mockExternalPluginRuntime(t *testing.T, l net.Listener) *plugin {
187+
conn, err := l.Accept()
188+
require.NoError(t, err)
189+
190+
p, err := newExternalPlugin(conn, setupValidator{
191+
out: pluginCfgOut{engineName: "test-engine", engineVersion: "1.0.0", requestTimeout: 30 * time.Second},
192+
acceptPattern: func(secrets.Pattern) error { return nil },
193+
})
194+
require.NoError(t, err)
195+
return p
196+
}

pkg/adaptation/registration.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,15 @@ type RegisterService struct {
4242

4343
func (r *RegisterService) RegisterPlugin(ctx context.Context, c *connect.Request[resolverv1.RegisterPluginRequest]) (*connect.Response[resolverv1.RegisterPluginResponse], error) {
4444
logrus.Infof("Reveived plugin registration request: %s@%s (pattern: %v)", c.Msg.GetName(), c.Msg.GetVersion(), c.Msg.GetPattern())
45-
pattern, err := secrets.ParsePattern(c.Msg.GetPattern())
46-
if err != nil {
47-
return nil, connect.NewError(connect.CodeInvalidArgument, err)
48-
}
4945
in := pluginCfgIn{
5046
name: c.Msg.GetName(),
5147
version: c.Msg.GetVersion(),
52-
pattern: pattern,
48+
pattern: secrets.Pattern(c.Msg.GetPattern()),
5349
}
5450
out, err := r.r.register(ctx, in)
51+
if errors.Is(err, secrets.ErrInvalidPattern) {
52+
return nil, connect.NewError(connect.CodeInvalidArgument, err)
53+
}
5554
if err != nil {
5655
return nil, connect.NewError(connect.CodeInternal, err)
5756
}

pkg/adaptation/registration_test.go

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,6 @@ func Test_RegisterPlugin(t *testing.T) {
148148
in pluginCfgIn
149149
test func(t *testing.T, resp *connect.Response[resolverv1.RegisterPluginResponse], err error)
150150
}{
151-
{
152-
name: "invalid pattern",
153-
r: mockPluginRegistratorOK(t),
154-
in: pluginCfgIn{pattern: "*a*"},
155-
test: func(t *testing.T, _ *connect.Response[resolverv1.RegisterPluginResponse], err error) {
156-
assert.ErrorContains(t, err, "invalid pattern")
157-
},
158-
},
159151
{
160152
name: "registration fails",
161153
r: mockPluginRegistratorErr(t),

pkg/adaptation/setup.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func setup(conn net.Conn, v setupValidator) (*setupResult, error) {
5353
case r := <-chRegistrationResult:
5454
if r.err != nil {
5555
i.Close()
56-
return nil, fmt.Errorf("failed to register plugin: %w", err)
56+
return nil, fmt.Errorf("failed to register plugin: %w", r.err)
5757
}
5858
out = r.cfg
5959
case err := <-chIpcErr:
@@ -72,6 +72,9 @@ func setup(conn net.Conn, v setupValidator) (*setupResult, error) {
7272
}
7373

7474
func (p setupValidator) Validate(in pluginCfgIn) (*pluginCfgOut, error) {
75+
if err := in.pattern.Valid(); err != nil {
76+
return nil, err
77+
}
7578
if p.name != "" && in.name != p.name {
7679
return nil, errors.New("plugin name cannot be changed when launched by engine")
7780
}

0 commit comments

Comments
 (0)