Skip to content

Commit 48c9bf5

Browse files
authored
pkg/utils: deprecate StartStopOnce in favor of StateMachine (#203)
1 parent 7dc07d3 commit 48c9bf5

4 files changed

Lines changed: 203 additions & 175 deletions

File tree

pkg/loop/internal/plugin_service.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ type grpcPlugin interface {
2828
// pluginService is a [types.Service] wrapper that maintains an internal [types.Service] created from a [grpcPlugin]
2929
// client instance by launching and re-launching as necessary.
3030
type PluginService[P grpcPlugin, S services.Service] struct {
31-
utils.StartStopOnce
31+
services.StateMachine
3232

3333
pluginName string
3434

pkg/services/state.go

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313
const defaultErrorBufferCap = 50
1414

1515
type errNotStarted struct {
16-
state State
16+
state state
1717
}
1818

1919
func (e *errNotStarted) Error() string {
@@ -25,7 +25,7 @@ var (
2525
ErrCannotStopUnstarted = errors.New("cannot stop unstarted service")
2626
)
2727

28-
// StateMachine contains a State integer
28+
// StateMachine contains a state integer
2929
type StateMachine struct {
3030
state atomic.Int32
3131
sync.RWMutex // lock is held during startup/shutdown, RLock is held while executing functions dependent on a particular state
@@ -38,12 +38,12 @@ type StateMachine struct {
3838
SvcErrBuffer ErrorBuffer
3939
}
4040

41-
// State holds the state for StateMachine
42-
type State int32
41+
// state holds the state for StateMachine
42+
type state int32
4343

4444
// nolint
4545
const (
46-
stateUnstarted State = iota
46+
stateUnstarted state = iota
4747
stateStarted
4848
stateStarting
4949
stateStartFailed
@@ -52,7 +52,7 @@ const (
5252
stateStopFailed
5353
)
5454

55-
func (s State) String() string {
55+
func (s state) String() string {
5656
switch s {
5757
case stateUnstarted:
5858
return "Unstarted"
@@ -80,7 +80,7 @@ func (once *StateMachine) StartOnce(name string, fn func() error) error {
8080
success := once.state.CompareAndSwap(int32(stateUnstarted), int32(stateStarting))
8181

8282
if !success {
83-
return pkgerrors.Errorf("%v has already been started once; state=%v", name, State(once.state.Load()))
83+
return pkgerrors.Errorf("%v has already been started once; state=%v", name, state(once.state.Load()))
8484
}
8585

8686
once.Lock()
@@ -116,14 +116,14 @@ func (once *StateMachine) StopOnce(name string, fn func() error) error {
116116
success := once.state.CompareAndSwap(int32(stateStarted), int32(stateStopping))
117117

118118
if !success {
119-
state := once.state.Load()
120-
switch state {
121-
case int32(stateStopped):
119+
s := once.loadState()
120+
switch s {
121+
case stateStopped:
122122
return pkgerrors.Wrapf(ErrAlreadyStopped, "%s has already been stopped", name)
123-
case int32(stateUnstarted):
123+
case stateUnstarted:
124124
return pkgerrors.Wrapf(ErrCannotStopUnstarted, "%s has not been started", name)
125125
default:
126-
return pkgerrors.Errorf("%v cannot be stopped from this state; state=%v", name, State(state))
126+
return pkgerrors.Errorf("%v cannot be stopped from this state; state=%v", name, s)
127127
}
128128
}
129129

@@ -145,19 +145,20 @@ func (once *StateMachine) StopOnce(name string, fn func() error) error {
145145
}
146146

147147
// State retrieves the current state
148-
func (once *StateMachine) State() State {
149-
state := once.state.Load()
150-
return State(state)
148+
func (once *StateMachine) State() string {
149+
return once.loadState().String()
150+
}
151+
152+
func (once *StateMachine) loadState() state {
153+
return state(once.state.Load())
151154
}
152155

153156
// IfStarted runs the func and returns true only if started, otherwise returns false
154157
func (once *StateMachine) IfStarted(f func()) (ok bool) {
155158
once.RLock()
156159
defer once.RUnlock()
157160

158-
state := once.state.Load()
159-
160-
if State(state) == stateStarted {
161+
if once.loadState() == stateStarted {
161162
f()
162163
return true
163164
}
@@ -169,9 +170,7 @@ func (once *StateMachine) IfNotStopped(f func()) (ok bool) {
169170
once.RLock()
170171
defer once.RUnlock()
171172

172-
state := once.state.Load()
173-
174-
if State(state) == stateStopped {
173+
if once.loadState() == stateStopped {
175174
return false
176175
}
177176
f()
@@ -180,7 +179,7 @@ func (once *StateMachine) IfNotStopped(f func()) (ok bool) {
180179

181180
// Ready returns ErrNotStarted if the state is not started.
182181
func (once *StateMachine) Ready() error {
183-
state := once.State()
182+
state := once.loadState()
184183
if state == stateStarted {
185184
return nil
186185
}
@@ -190,7 +189,7 @@ func (once *StateMachine) Ready() error {
190189
// Healthy returns ErrNotStarted if the state is not started.
191190
// Override this per-service with more specific implementations.
192191
func (once *StateMachine) Healthy() error {
193-
state := once.State()
192+
state := once.loadState()
194193
if state == stateStarted {
195194
return once.SvcErrBuffer.Flush()
196195
}

pkg/services/state_test.go

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
package services
2+
3+
import (
4+
"errors"
5+
"sync/atomic"
6+
"testing"
7+
"time"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
func TestStateMachine_StartOnce_StopOnce(t *testing.T) {
14+
t.Parallel()
15+
16+
var sm StateMachine
17+
18+
ch := make(chan int, 3)
19+
20+
ready := make(chan bool)
21+
22+
go func() {
23+
assert.NoError(t, sm.StartOnce("slow service", func() (err error) {
24+
ch <- 1
25+
ready <- true
26+
<-time.After(time.Millisecond * 500) // wait for StopOnce to happen
27+
ch <- 2
28+
29+
return nil
30+
}))
31+
}()
32+
33+
go func() {
34+
<-ready // try stopping halfway through startup
35+
assert.NoError(t, sm.StopOnce("slow service", func() (err error) {
36+
ch <- 3
37+
38+
return nil
39+
}))
40+
}()
41+
42+
require.Equal(t, 1, <-ch)
43+
require.Equal(t, 2, <-ch)
44+
require.Equal(t, 3, <-ch)
45+
}
46+
47+
func TestStateMachine_MultipleStartNoBlock(t *testing.T) {
48+
t.Parallel()
49+
50+
var sm StateMachine
51+
52+
ch := make(chan int, 3)
53+
54+
ready := make(chan bool)
55+
next := make(chan bool)
56+
57+
go func() {
58+
ch <- 1
59+
assert.NoError(t, sm.StartOnce("slow service", func() (err error) {
60+
ready <- true
61+
<-next // continue after the other StartOnce call fails
62+
63+
return nil
64+
}))
65+
<-next
66+
ch <- 2
67+
68+
}()
69+
70+
go func() {
71+
<-ready // try starting halfway through startup
72+
assert.Error(t, sm.StartOnce("slow service", func() (err error) {
73+
return nil
74+
}))
75+
next <- true
76+
ch <- 3
77+
next <- true
78+
79+
}()
80+
81+
require.Equal(t, 1, <-ch)
82+
require.Equal(t, 3, <-ch) // 3 arrives before 2 because it returns immediately
83+
require.Equal(t, 2, <-ch)
84+
}
85+
86+
func TestStateMachine(t *testing.T) {
87+
t.Parallel()
88+
89+
var callsCount atomic.Int32
90+
incCount := func() {
91+
callsCount.Add(1)
92+
}
93+
94+
var s StateMachine
95+
ok := s.IfStarted(incCount)
96+
assert.False(t, ok)
97+
ok = s.IfNotStopped(incCount)
98+
assert.True(t, ok)
99+
assert.Equal(t, int32(1), callsCount.Load())
100+
101+
err := s.StartOnce("foo", func() error { return nil })
102+
assert.NoError(t, err)
103+
104+
assert.True(t, s.IfStarted(incCount))
105+
assert.Equal(t, int32(2), callsCount.Load())
106+
107+
err = s.StopOnce("foo", func() error { return nil })
108+
assert.NoError(t, err)
109+
ok = s.IfNotStopped(incCount)
110+
assert.False(t, ok)
111+
assert.Equal(t, int32(2), callsCount.Load())
112+
}
113+
114+
func TestStateMachine_StartErrors(t *testing.T) {
115+
var s StateMachine
116+
117+
err := s.StartOnce("foo", func() error { return errors.New("foo") })
118+
assert.Error(t, err)
119+
120+
var callsCount atomic.Int32
121+
incCount := func() {
122+
callsCount.Add(1)
123+
}
124+
125+
assert.False(t, s.IfStarted(incCount))
126+
assert.Equal(t, int32(0), callsCount.Load())
127+
128+
err = s.StartOnce("foo", func() error { return nil })
129+
require.Error(t, err)
130+
assert.Contains(t, err.Error(), "foo has already been started once")
131+
err = s.StopOnce("foo", func() error { return nil })
132+
require.Error(t, err)
133+
assert.Contains(t, err.Error(), "foo cannot be stopped from this state; state=StartFailed")
134+
135+
assert.Equal(t, stateStartFailed, s.loadState())
136+
}
137+
138+
func TestStateMachine_StopErrors(t *testing.T) {
139+
var s StateMachine
140+
141+
err := s.StartOnce("foo", func() error { return nil })
142+
require.NoError(t, err)
143+
144+
var callsCount atomic.Int32
145+
incCount := func() {
146+
callsCount.Add(1)
147+
}
148+
149+
err = s.StopOnce("foo", func() error { return errors.New("explodey mcsplode") })
150+
assert.Error(t, err)
151+
152+
assert.False(t, s.IfStarted(incCount))
153+
assert.Equal(t, int32(0), callsCount.Load())
154+
assert.True(t, s.IfNotStopped(incCount))
155+
assert.Equal(t, int32(1), callsCount.Load())
156+
157+
err = s.StartOnce("foo", func() error { return nil })
158+
require.Error(t, err)
159+
assert.Contains(t, err.Error(), "foo has already been started once")
160+
err = s.StopOnce("foo", func() error { return nil })
161+
require.Error(t, err)
162+
assert.Contains(t, err.Error(), "foo cannot be stopped from this state; state=StopFailed")
163+
164+
assert.Equal(t, stateStopFailed, s.loadState())
165+
}
166+
167+
func TestStateMachine_Ready_Healthy(t *testing.T) {
168+
t.Parallel()
169+
170+
var s StateMachine
171+
assert.Error(t, s.Ready())
172+
assert.Error(t, s.Healthy())
173+
174+
err := s.StartOnce("foo", func() error { return nil })
175+
assert.NoError(t, err)
176+
assert.NoError(t, s.Ready())
177+
assert.NoError(t, s.Healthy())
178+
}

0 commit comments

Comments
 (0)