Skip to content

Commit 97bbfaf

Browse files
committed
Fix callback race window after Close
1 parent b50a026 commit 97bbfaf

1 file changed

Lines changed: 68 additions & 16 deletions

File tree

watcher.go

Lines changed: 68 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ type Watcher struct {
2525
waitTimeout time.Duration
2626
logger logger.Logger
2727
watcher *fsnotify.Watcher
28+
29+
stateAccess sync.Mutex
30+
timerMap map[string]pendingCallback
31+
closed bool
32+
}
33+
34+
type pendingCallback struct {
35+
generation uint64
36+
timer *time.Timer
2837
}
2938

3039
type Options struct {
@@ -74,6 +83,7 @@ func NewWatcher(options Options) (*Watcher, error) {
7483
callback: options.Callback,
7584
waitTimeout: waitTimeout,
7685
logger: options.Logger,
86+
timerMap: make(map[string]pendingCallback),
7787
}, nil
7888
}
7989

@@ -94,12 +104,28 @@ func (w *Watcher) Start() error {
94104
}
95105

96106
func (w *Watcher) Close() error {
97-
return common.Close(common.PtrOrNil(w.watcher))
107+
w.stateAccess.Lock()
108+
if w.closed {
109+
watcher := w.watcher
110+
w.stateAccess.Unlock()
111+
return common.Close(common.PtrOrNil(watcher))
112+
}
113+
w.closed = true
114+
timers := make([]*time.Timer, 0, len(w.timerMap))
115+
for path, pending := range w.timerMap {
116+
timers = append(timers, pending.timer)
117+
delete(w.timerMap, path)
118+
}
119+
watcher := w.watcher
120+
w.stateAccess.Unlock()
121+
122+
for _, timer := range timers {
123+
timer.Stop()
124+
}
125+
return common.Close(common.PtrOrNil(watcher))
98126
}
99127

100128
func (w *Watcher) loopUpdate() {
101-
var timerAccess sync.Mutex
102-
timerMap := make(map[string]*time.Timer)
103129
for {
104130
select {
105131
case event, loaded := <-w.watcher.Events:
@@ -111,19 +137,7 @@ func (w *Watcher) loopUpdate() {
111137
w.logger.Error("fswatch: watcher removed: ", event.Name)
112138
}
113139
} else if common.Contains(w.watchPath, event.Name) && (event.Has(fsnotify.Create) || event.Has(fsnotify.Write)) {
114-
timerAccess.Lock()
115-
timer := timerMap[event.Name]
116-
if timer != nil {
117-
timer.Reset(w.waitTimeout)
118-
} else {
119-
timerMap[event.Name] = time.AfterFunc(w.waitTimeout, func() {
120-
w.callback(event.Name)
121-
timerAccess.Lock()
122-
delete(timerMap, event.Name)
123-
timerAccess.Unlock()
124-
})
125-
}
126-
timerAccess.Unlock()
140+
w.scheduleCallback(event.Name)
127141
}
128142
case err, loaded := <-w.watcher.Errors:
129143
if !loaded {
@@ -135,3 +149,41 @@ func (w *Watcher) loopUpdate() {
135149
}
136150
}
137151
}
152+
153+
func (w *Watcher) scheduleCallback(path string) {
154+
w.stateAccess.Lock()
155+
if w.closed {
156+
w.stateAccess.Unlock()
157+
return
158+
}
159+
160+
pending := w.timerMap[path]
161+
generation := pending.generation + 1
162+
if pending.timer != nil {
163+
pending.timer.Stop()
164+
}
165+
w.timerMap[path] = pendingCallback{
166+
generation: generation,
167+
timer: time.AfterFunc(w.waitTimeout, func() { w.fireCallback(path, generation) }),
168+
}
169+
w.stateAccess.Unlock()
170+
}
171+
172+
func (w *Watcher) fireCallback(path string, generation uint64) {
173+
w.stateAccess.Lock()
174+
if w.closed {
175+
w.stateAccess.Unlock()
176+
return
177+
}
178+
179+
pending, loaded := w.timerMap[path]
180+
if !loaded || pending.generation != generation {
181+
w.stateAccess.Unlock()
182+
return
183+
}
184+
delete(w.timerMap, path)
185+
callback := w.callback
186+
w.stateAccess.Unlock()
187+
188+
callback(path)
189+
}

0 commit comments

Comments
 (0)