Skip to content

Commit 9f62b5d

Browse files
committed
Ensure thread safety
1 parent ab1dcbd commit 9f62b5d

1 file changed

Lines changed: 12 additions & 6 deletions

File tree

pkg/beholder/auth.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"crypto/rand"
88
"encoding/binary"
99
"fmt"
10+
"maps"
1011
"sync"
1112
"sync/atomic"
1213
"time"
@@ -96,8 +97,10 @@ func NewRotatingAuth(csaPubKey ed25519.PublicKey, signer Signer, ttl time.Durati
9697

9798
func (r *rotatingAuth) Headers(ctx context.Context) (map[string]string, error) {
9899

100+
// Return a copy of the headers to avoid concurrent read/write to the map by callers
101+
returnHeader := make(map[string]string)
99102
lastUpdated := time.Unix(0, r.lastUpdatedNanos.Load())
100-
// Check if we need to get the lock
103+
101104
if time.Since(lastUpdated) > r.ttl {
102105

103106
r.mu.Lock()
@@ -108,7 +111,8 @@ func (r *rotatingAuth) Headers(ctx context.Context) (map[string]string, error) {
108111
// updated the headers and lastUpdated while waiting for the lock.
109112
lastUpdated = time.Unix(0, r.lastUpdatedNanos.Load())
110113
if time.Since(lastUpdated) < r.ttl {
111-
return r.headers.Load().(map[string]string), nil
114+
maps.Copy(returnHeader, r.headers.Load().(map[string]string))
115+
return returnHeader, nil
112116
}
113117

114118
// Append the bytes of the public key with bytes of the timestamp to create the message to sign
@@ -126,14 +130,16 @@ func (r *rotatingAuth) Headers(ctx context.Context) (map[string]string, error) {
126130
return nil, fmt.Errorf("beholder: failed to sign auth header: %w", err)
127131
}
128132

129-
headers := r.headers.Load().(map[string]string)
130-
headers[authHeaderKey] = fmt.Sprintf("%s:%x:%d:%x", authHeaderV2, r.csaPubKey, ts.UnixNano(), signature)
133+
newHeaders := make(map[string]string)
134+
newHeaders[authHeaderKey] = fmt.Sprintf("%s:%x:%d:%x", authHeaderV2, r.csaPubKey, ts.UnixNano(), signature)
131135

132-
r.headers.Store(headers)
136+
r.headers.Store(newHeaders)
133137
r.lastUpdatedNanos.Store(ts.UnixNano())
134138
}
135139

136-
return r.headers.Load().(map[string]string), nil
140+
maps.Copy(returnHeader, r.headers.Load().(map[string]string))
141+
142+
return returnHeader, nil
137143
}
138144

139145
func (a *rotatingAuth) Credentials() credentials.PerRPCCredentials {

0 commit comments

Comments
 (0)