88 "os"
99 "path/filepath"
1010 "strings"
11+ "sync"
1112 "time"
1213
1314 "github.com/dchest/safefile"
@@ -29,6 +30,7 @@ const (
2930)
3031
3132type Source struct {
33+ sync.RWMutex
3234 name string
3335 urls []* url.URL
3436 format SourceFormat
@@ -40,8 +42,20 @@ type Source struct {
4042 prefix string
4143}
4244
43- // timeNow() is replaced by tests to provide a static value
44- var timeNow = time .Now
45+ // timeNow is a function variable that provides the current time
46+ // It's replaced by tests to provide a static value
47+ // Access to this variable is synchronized to prevent race conditions
48+ var (
49+ timeNowMutex sync.RWMutex
50+ timeNow = time .Now
51+ )
52+
53+ // getCurrentTime safely gets the current time using the timeNow function
54+ func getCurrentTime () time.Time {
55+ timeNowMutex .RLock ()
56+ defer timeNowMutex .RUnlock ()
57+ return timeNow ()
58+ }
4559
4660func (source * Source ) checkSignature (bin , sig []byte ) error {
4761 signature , err := minisign .DecodeSignature (string (sig ))
@@ -51,7 +65,8 @@ func (source *Source) checkSignature(bin, sig []byte) error {
5165 return err
5266}
5367
54- func (source * Source ) fetchFromCache (now time.Time ) (time.Duration , error ) {
68+ func (source * Source ) fetchFromCache () (time.Duration , error ) {
69+ now := getCurrentTime ()
5570 var err error
5671 var bin , sig []byte
5772 if bin , err = os .ReadFile (source .cacheFile ); err != nil {
@@ -63,7 +78,11 @@ func (source *Source) fetchFromCache(now time.Time) (time.Duration, error) {
6378 if err = source .checkSignature (bin , sig ); err != nil {
6479 return 0 , err
6580 }
81+
82+ source .Lock ()
6683 source .bin = bin
84+ source .Unlock ()
85+
6786 var fi os.FileInfo
6887 if fi , err = os .Stat (source .cacheFile ); err != nil {
6988 return 0 , err
@@ -101,14 +120,19 @@ func writeSource(f string, bin, sig []byte) error {
101120 return fSig .Commit ()
102121}
103122
104- func (source * Source ) updateCache (bin , sig []byte , now time.Time ) {
123+ func (source * Source ) updateCache (bin , sig []byte ) {
124+ now := getCurrentTime ()
105125 file := source .cacheFile
106126 absPath := file
107127 if resolved , err := filepath .Abs (file ); err != nil {
108128 absPath = resolved
109129 }
110130
111- if ! bytes .Equal (source .bin , bin ) {
131+ source .Lock ()
132+ needsWrite := ! bytes .Equal (source .bin , bin )
133+ source .Unlock ()
134+
135+ if needsWrite {
112136 if err := writeSource (file , bin , sig ); err != nil {
113137 dlog .Warnf ("Couldn't write cache file [%s]: %s" , absPath , err ) // an error writing to the cache isn't fatal
114138 }
@@ -117,7 +141,9 @@ func (source *Source) updateCache(bin, sig []byte, now time.Time) {
117141 dlog .Warnf ("Couldn't update cache file [%s]: %s" , absPath , err )
118142 }
119143
144+ source .Lock ()
120145 source .bin = bin
146+ source .Unlock ()
121147}
122148
123149func (source * Source ) parseURLs (urls []string ) {
@@ -135,10 +161,11 @@ func fetchFromURL(xTransport *XTransport, u *url.URL) ([]byte, error) {
135161 return bin , err
136162}
137163
138- func (source * Source ) fetchWithCache (xTransport * XTransport , now time.Time ) (time.Duration , error ) {
164+ func (source * Source ) fetchWithCache (xTransport * XTransport ) (time.Duration , error ) {
165+ now := getCurrentTime ()
139166 var err error
140167 var ttl time.Duration
141- if ttl , err = source .fetchFromCache (now ); err != nil {
168+ if ttl , err = source .fetchFromCache (); err != nil {
142169 if len (source .urls ) == 0 {
143170 dlog .Errorf ("Source [%s] cache file [%s] not present and no valid URL" , source .name , source .cacheFile )
144171 return 0 , err
@@ -179,7 +206,7 @@ func (source *Source) fetchWithCache(xTransport *XTransport, now time.Time) (tim
179206 if err != nil {
180207 return 0 , err
181208 }
182- source .updateCache (bin , sig , now )
209+ source .updateCache (bin , sig )
183210 ttl = source .prefetchDelay
184211 source .refresh = now .Add (ttl )
185212 return ttl , nil
@@ -218,7 +245,7 @@ func NewSource(
218245 return source , err
219246 }
220247 source .parseURLs (urls )
221- _ , err := source .fetchWithCache (xTransport , timeNow () )
248+ _ , err := source .fetchWithCache (xTransport )
222249 if err == nil {
223250 dlog .Noticef ("Source [%s] loaded" , name )
224251 }
@@ -227,14 +254,14 @@ func NewSource(
227254
228255// PrefetchSources downloads latest versions of given sources, ensuring they have a valid signature before caching
229256func PrefetchSources (xTransport * XTransport , sources []* Source ) time.Duration {
230- now := timeNow ()
257+ now := getCurrentTime ()
231258 interval := MinimumPrefetchInterval
232259 for _ , source := range sources {
233260 if source .refresh .IsZero () || source .refresh .After (now ) {
234261 continue
235262 }
236263 dlog .Debugf ("Prefetching [%s]" , source .name )
237- if delay , err := source .fetchWithCache (xTransport , now ); err != nil {
264+ if delay , err := source .fetchWithCache (xTransport ); err != nil {
238265 dlog .Infof ("Prefetching [%s] failed: %v, will retry in %v" , source .name , err , interval )
239266 } else {
240267 dlog .Debugf ("Prefetching [%s] succeeded, next update in %v min" , source .name , delay )
@@ -262,7 +289,13 @@ func (source *Source) parseV2() ([]RegisteredServer, error) {
262289 stampErrs = append (stampErrs , stampErr )
263290 dlog .Warn (stampErr )
264291 }
265- in := string (source .bin )
292+
293+ source .RLock ()
294+ binCopy := make ([]byte , len (source .bin ))
295+ copy (binCopy , source .bin )
296+ source .RUnlock ()
297+
298+ in := string (binCopy )
266299 parts := strings .Split (in , "## " )
267300 if len (parts ) < 2 {
268301 return registeredServers , fmt .Errorf ("Invalid format for source at [%v]" , source .urls )
0 commit comments