Skip to content

Commit 9b0adb2

Browse files
committed
Fix race conditions
1 parent 4b78281 commit 9b0adb2

4 files changed

Lines changed: 90 additions & 26 deletions

File tree

dnscrypt-proxy/plugin_block_name.go

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"io"
77
"net"
88
"strings"
9+
"sync"
910
"time"
1011

1112
"github.com/jedisct1/dlog"
@@ -21,7 +22,11 @@ type BlockedNames struct {
2122

2223
const aliasesLimit = 8
2324

24-
var blockedNames *BlockedNames
25+
var (
26+
// protects access to the blockedNames global variable
27+
blockedNamesLock sync.RWMutex
28+
blockedNames *BlockedNames
29+
)
2530

2631
func (blockedNames *BlockedNames) check(pluginsState *PluginsState, qName string, aliasFor *string) (bool, error) {
2732
reject, reason, xweeklyRanges := blockedNames.patternMatcher.Eval(qName)
@@ -127,12 +132,14 @@ func (plugin *PluginBlockName) Init(proxy *Proxy) error {
127132
continue
128133
}
129134
}
130-
blockedNames = &xBlockedNames
131-
if len(proxy.blockNameLogFile) == 0 {
132-
return nil
135+
if len(proxy.blockNameLogFile) > 0 {
136+
xBlockedNames.logger = Logger(proxy.logMaxSize, proxy.logMaxAge, proxy.logMaxBackups, proxy.blockNameLogFile)
137+
xBlockedNames.format = proxy.blockNameFormat
133138
}
134-
blockedNames.logger = Logger(proxy.logMaxSize, proxy.logMaxAge, proxy.logMaxBackups, proxy.blockNameLogFile)
135-
blockedNames.format = proxy.blockNameFormat
139+
140+
blockedNamesLock.Lock()
141+
blockedNames = &xBlockedNames
142+
blockedNamesLock.Unlock()
136143

137144
return nil
138145
}
@@ -146,10 +153,19 @@ func (plugin *PluginBlockName) Reload() error {
146153
}
147154

148155
func (plugin *PluginBlockName) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
149-
if blockedNames == nil || pluginsState.sessionData["whitelisted"] != nil {
156+
if pluginsState.sessionData["whitelisted"] != nil {
150157
return nil
151158
}
152-
_, err := blockedNames.check(pluginsState, pluginsState.qName, nil)
159+
160+
blockedNamesLock.RLock()
161+
localBlockedNames := blockedNames
162+
blockedNamesLock.RUnlock()
163+
164+
if localBlockedNames == nil {
165+
return nil
166+
}
167+
168+
_, err := localBlockedNames.check(pluginsState, pluginsState.qName, nil)
153169
return err
154170
}
155171

@@ -178,9 +194,18 @@ func (plugin *PluginBlockNameResponse) Reload() error {
178194
}
179195

180196
func (plugin *PluginBlockNameResponse) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
181-
if blockedNames == nil || pluginsState.sessionData["whitelisted"] != nil {
197+
if pluginsState.sessionData["whitelisted"] != nil {
182198
return nil
183199
}
200+
201+
blockedNamesLock.RLock()
202+
localBlockedNames := blockedNames
203+
blockedNamesLock.RUnlock()
204+
205+
if localBlockedNames == nil {
206+
return nil
207+
}
208+
184209
aliasFor := pluginsState.qName
185210
aliasesLeft := aliasesLimit
186211
answers := msg.Answer
@@ -203,7 +228,7 @@ func (plugin *PluginBlockNameResponse) Eval(pluginsState *PluginsState, msg *dns
203228
if err != nil {
204229
return err
205230
}
206-
if blocked, err := blockedNames.check(pluginsState, target, &aliasFor); blocked || err != nil {
231+
if blocked, err := localBlockedNames.check(pluginsState, target, &aliasFor); blocked || err != nil {
207232
return err
208233
}
209234
aliasesLeft--

dnscrypt-proxy/plugin_cache.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package main
33
import (
44
"crypto/sha512"
55
"encoding/binary"
6+
"fmt"
67
"sync"
78
"time"
89

@@ -149,11 +150,10 @@ func (plugin *PluginCacheResponse) Eval(pluginsState *PluginsState, msg *dns.Msg
149150
}
150151
cachedResponses.Lock()
151152
if cachedResponses.cache == nil {
152-
var err error
153153
cachedResponses.cache = sieve.New[[32]byte, CachedResponse](pluginsState.cacheSize)
154154
if cachedResponses.cache == nil {
155155
cachedResponses.Unlock()
156-
return err
156+
return fmt.Errorf("failed to initialize the cache")
157157
}
158158
}
159159
cachedResponses.cache.Add(cacheKey, cachedResponse)

dnscrypt-proxy/proxy.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -634,10 +634,16 @@ func (proxy *Proxy) clientsCountInc() bool {
634634

635635
func (proxy *Proxy) clientsCountDec() {
636636
for {
637-
if count := atomic.LoadUint32(&proxy.clientsCount); count == 0 ||
638-
atomic.CompareAndSwapUint32(&proxy.clientsCount, count, count-1) {
637+
count := atomic.LoadUint32(&proxy.clientsCount)
638+
if count == 0 {
639+
// Already at zero, nothing to do
640+
break
641+
}
642+
if atomic.CompareAndSwapUint32(&proxy.clientsCount, count, count-1) {
643+
dlog.Debugf("clients count: %d", count-1)
639644
break
640645
}
646+
// CAS failed, retry with updated count
641647
}
642648
}
643649

dnscrypt-proxy/sources.go

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"os"
99
"path/filepath"
1010
"strings"
11+
"sync"
1112
"time"
1213

1314
"github.com/dchest/safefile"
@@ -29,6 +30,7 @@ const (
2930
)
3031

3132
type Source struct {
33+
sync.RWMutex
3234
name string
3335
urls []*url.URL
3436
format SourceFormat
@@ -40,8 +42,20 @@ type Source struct {
4042
prefix string
4143
}
4244

43-
// timeNow() is replaced by tests to provide a static value
44-
var timeNow = time.Now
45+
// timeNow is a function variable that provides the current time
46+
// It's replaced by tests to provide a static value
47+
// Access to this variable is synchronized to prevent race conditions
48+
var (
49+
timeNowMutex sync.RWMutex
50+
timeNow = time.Now
51+
)
52+
53+
// getCurrentTime safely gets the current time using the timeNow function
54+
func getCurrentTime() time.Time {
55+
timeNowMutex.RLock()
56+
defer timeNowMutex.RUnlock()
57+
return timeNow()
58+
}
4559

4660
func (source *Source) checkSignature(bin, sig []byte) error {
4761
signature, err := minisign.DecodeSignature(string(sig))
@@ -51,7 +65,8 @@ func (source *Source) checkSignature(bin, sig []byte) error {
5165
return err
5266
}
5367

54-
func (source *Source) fetchFromCache(now time.Time) (time.Duration, error) {
68+
func (source *Source) fetchFromCache() (time.Duration, error) {
69+
now := getCurrentTime()
5570
var err error
5671
var bin, sig []byte
5772
if bin, err = os.ReadFile(source.cacheFile); err != nil {
@@ -63,7 +78,11 @@ func (source *Source) fetchFromCache(now time.Time) (time.Duration, error) {
6378
if err = source.checkSignature(bin, sig); err != nil {
6479
return 0, err
6580
}
81+
82+
source.Lock()
6683
source.bin = bin
84+
source.Unlock()
85+
6786
var fi os.FileInfo
6887
if fi, err = os.Stat(source.cacheFile); err != nil {
6988
return 0, err
@@ -101,14 +120,19 @@ func writeSource(f string, bin, sig []byte) error {
101120
return fSig.Commit()
102121
}
103122

104-
func (source *Source) updateCache(bin, sig []byte, now time.Time) {
123+
func (source *Source) updateCache(bin, sig []byte) {
124+
now := getCurrentTime()
105125
file := source.cacheFile
106126
absPath := file
107127
if resolved, err := filepath.Abs(file); err != nil {
108128
absPath = resolved
109129
}
110130

111-
if !bytes.Equal(source.bin, bin) {
131+
source.Lock()
132+
needsWrite := !bytes.Equal(source.bin, bin)
133+
source.Unlock()
134+
135+
if needsWrite {
112136
if err := writeSource(file, bin, sig); err != nil {
113137
dlog.Warnf("Couldn't write cache file [%s]: %s", absPath, err) // an error writing to the cache isn't fatal
114138
}
@@ -117,7 +141,9 @@ func (source *Source) updateCache(bin, sig []byte, now time.Time) {
117141
dlog.Warnf("Couldn't update cache file [%s]: %s", absPath, err)
118142
}
119143

144+
source.Lock()
120145
source.bin = bin
146+
source.Unlock()
121147
}
122148

123149
func (source *Source) parseURLs(urls []string) {
@@ -135,10 +161,11 @@ func fetchFromURL(xTransport *XTransport, u *url.URL) ([]byte, error) {
135161
return bin, err
136162
}
137163

138-
func (source *Source) fetchWithCache(xTransport *XTransport, now time.Time) (time.Duration, error) {
164+
func (source *Source) fetchWithCache(xTransport *XTransport) (time.Duration, error) {
165+
now := getCurrentTime()
139166
var err error
140167
var ttl time.Duration
141-
if ttl, err = source.fetchFromCache(now); err != nil {
168+
if ttl, err = source.fetchFromCache(); err != nil {
142169
if len(source.urls) == 0 {
143170
dlog.Errorf("Source [%s] cache file [%s] not present and no valid URL", source.name, source.cacheFile)
144171
return 0, err
@@ -179,7 +206,7 @@ func (source *Source) fetchWithCache(xTransport *XTransport, now time.Time) (tim
179206
if err != nil {
180207
return 0, err
181208
}
182-
source.updateCache(bin, sig, now)
209+
source.updateCache(bin, sig)
183210
ttl = source.prefetchDelay
184211
source.refresh = now.Add(ttl)
185212
return ttl, nil
@@ -218,7 +245,7 @@ func NewSource(
218245
return source, err
219246
}
220247
source.parseURLs(urls)
221-
_, err := source.fetchWithCache(xTransport, timeNow())
248+
_, err := source.fetchWithCache(xTransport)
222249
if err == nil {
223250
dlog.Noticef("Source [%s] loaded", name)
224251
}
@@ -227,14 +254,14 @@ func NewSource(
227254

228255
// PrefetchSources downloads latest versions of given sources, ensuring they have a valid signature before caching
229256
func PrefetchSources(xTransport *XTransport, sources []*Source) time.Duration {
230-
now := timeNow()
257+
now := getCurrentTime()
231258
interval := MinimumPrefetchInterval
232259
for _, source := range sources {
233260
if source.refresh.IsZero() || source.refresh.After(now) {
234261
continue
235262
}
236263
dlog.Debugf("Prefetching [%s]", source.name)
237-
if delay, err := source.fetchWithCache(xTransport, now); err != nil {
264+
if delay, err := source.fetchWithCache(xTransport); err != nil {
238265
dlog.Infof("Prefetching [%s] failed: %v, will retry in %v", source.name, err, interval)
239266
} else {
240267
dlog.Debugf("Prefetching [%s] succeeded, next update in %v min", source.name, delay)
@@ -262,7 +289,13 @@ func (source *Source) parseV2() ([]RegisteredServer, error) {
262289
stampErrs = append(stampErrs, stampErr)
263290
dlog.Warn(stampErr)
264291
}
265-
in := string(source.bin)
292+
293+
source.RLock()
294+
binCopy := make([]byte, len(source.bin))
295+
copy(binCopy, source.bin)
296+
source.RUnlock()
297+
298+
in := string(binCopy)
266299
parts := strings.Split(in, "## ")
267300
if len(parts) < 2 {
268301
return registeredServers, fmt.Errorf("Invalid format for source at [%v]", source.urls)

0 commit comments

Comments
 (0)