Skip to content

Commit f731a12

Browse files
Add locks to channel maps
1 parent 6440450 commit f731a12

1 file changed

Lines changed: 14 additions & 4 deletions

File tree

internal/libvirt/libvirt.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,11 @@ type LibVirt struct {
4646
version string
4747

4848
// Event channels for domains by their libvirt event id.
49-
domEventChs map[libvirt.DomainEventID]<-chan any
49+
domEventChs map[libvirt.DomainEventID]<-chan any
50+
domEventChsLock sync.Mutex
5051
// Event listeners for domain events by their own identifier.
51-
domEventChangeHandlers map[libvirt.DomainEventID]map[string]func(context.Context, any)
52+
domEventChangeHandlers map[libvirt.DomainEventID]map[string]func(context.Context, any)
53+
domEventChangeHandlersLock sync.Mutex
5254

5355
// Client that connects to libvirt and fetches capabilities of the
5456
// hypervisor. The capabilities client abstracts the xml parsing away.
@@ -79,8 +81,8 @@ func NewLibVirt(k client.Client) *LibVirt {
7981
make(map[string]context.CancelFunc),
8082
sync.Mutex{},
8183
"N/A",
82-
make(map[libvirt.DomainEventID]<-chan any),
83-
make(map[libvirt.DomainEventID]map[string]func(context.Context, any)),
84+
make(map[libvirt.DomainEventID]<-chan any), sync.Mutex{},
85+
make(map[libvirt.DomainEventID]map[string]func(context.Context, any)), sync.Mutex{},
8486
capabilities.NewClient(),
8587
domcapabilities.NewClient(),
8688
dominfo.NewClient(),
@@ -149,13 +151,15 @@ func (l *LibVirt) runEventLoop(ctx context.Context) {
149151
// a dynamic set of channels.
150152
var cases []reflect.SelectCase
151153
var eventIds []libvirt.DomainEventID
154+
l.domEventChsLock.Lock()
152155
for eventId, ch := range l.domEventChs {
153156
cases = append(cases, reflect.SelectCase{
154157
Dir: reflect.SelectRecv,
155158
Chan: reflect.ValueOf(ch),
156159
})
157160
eventIds = append(eventIds, eventId)
158161
}
162+
l.domEventChsLock.Unlock()
159163

160164
cases = append(cases, reflect.SelectCase{
161165
Dir: reflect.SelectRecv,
@@ -181,7 +185,9 @@ func (l *LibVirt) runEventLoop(ctx context.Context) {
181185

182186
// Distribute the event to all registered handlers.
183187
eventId := eventIds[chosen] // safe as chosen < len(eventIds)
188+
l.domEventChangeHandlersLock.Lock()
184189
handlers, exists := l.domEventChangeHandlers[eventId]
190+
l.domEventChangeHandlersLock.Unlock()
185191
if !exists {
186192
continue
187193
}
@@ -206,13 +212,17 @@ func (l *LibVirt) WatchDomainChanges(
206212

207213
// Register the handler so that it is called when an event with the provided
208214
// eventId is received.
215+
l.domEventChangeHandlersLock.Lock()
216+
defer l.domEventChangeHandlersLock.Unlock()
209217
if _, exists := l.domEventChangeHandlers[eventId]; !exists {
210218
l.domEventChangeHandlers[eventId] = make(map[string]func(context.Context, any))
211219
}
212220
l.domEventChangeHandlers[eventId][handlerId] = handler
213221

214222
// If we are already subscribed to this eventId, nothing more to do.
215223
// Note: subscribing more than once will be blocked by the libvirt client.
224+
l.domEventChsLock.Lock()
225+
defer l.domEventChsLock.Unlock()
216226
if _, exists := l.domEventChs[eventId]; exists {
217227
return
218228
}

0 commit comments

Comments
 (0)