|
6 | 6 | "context" |
7 | 7 | "errors" |
8 | 8 | "fmt" |
| 9 | + "sync" |
| 10 | + "sync/atomic" |
9 | 11 |
|
10 | 12 | "github.com/RoaringBitmap/roaring/v2" |
11 | 13 | "golang.org/x/sys/unix" |
@@ -73,26 +75,219 @@ func NewCacheFromMemfd( |
73 | 75 | if err != nil { |
74 | 76 | return nil, errors.Join(err, memfd.Close()) |
75 | 77 | } |
| 78 | + if err := copyFromMemfd(ctx, cache, memfd, dirty, blockSize); err != nil { |
| 79 | + return nil, errors.Join(err, memfd.Close(), cache.Close()) |
| 80 | + } |
| 81 | + if err := memfd.Close(); err != nil { |
| 82 | + return nil, errors.Join(fmt.Errorf("close memfd: %w", err), cache.Close()) |
| 83 | + } |
76 | 84 |
|
| 85 | + return cache, nil |
| 86 | +} |
| 87 | + |
| 88 | +func copyFromMemfd(ctx context.Context, cache *Cache, memfd *Memfd, dirty *roaring.Bitmap, blockSize int64) error { |
77 | 89 | var cacheOff int64 |
78 | 90 | for r := range BitsetRanges(dirty, blockSize) { |
79 | 91 | if err := ctx.Err(); err != nil { |
80 | | - return nil, errors.Join(err, memfd.Close(), cache.Close()) |
| 92 | + return err |
81 | 93 | } |
82 | 94 |
|
83 | 95 | src, err := memfd.Slice(r.Start, r.Size) |
84 | 96 | if err != nil { |
85 | | - return nil, errors.Join(fmt.Errorf("memfd slice [%d,%d): %w", r.Start, r.Start+r.Size, err), memfd.Close(), cache.Close()) |
| 97 | + return fmt.Errorf("memfd slice [%d,%d): %w", r.Start, r.Start+r.Size, err) |
86 | 98 | } |
87 | 99 |
|
88 | 100 | copy((*cache.mmap)[cacheOff:cacheOff+r.Size], src) |
89 | 101 | cache.setIsCached(cacheOff, r.Size) |
90 | 102 | cacheOff += r.Size |
91 | 103 | } |
92 | 104 |
|
93 | | - if err := memfd.Close(); err != nil { |
94 | | - return nil, errors.Join(fmt.Errorf("close memfd: %w", err), cache.Close()) |
| 105 | + return nil |
| 106 | +} |
| 107 | + |
| 108 | +// MemfdCache wraps a Cache that is being populated from a memfd on a |
| 109 | +// background goroutine. While the copy is in flight, reads are served |
| 110 | +// directly from the memfd via memfdSource; afterwards they delegate to the |
| 111 | +// embedded Cache and the memfd is closed. |
| 112 | +type MemfdCache struct { |
| 113 | + *Cache |
| 114 | + |
| 115 | + mu sync.RWMutex // guards src |
| 116 | + src *memfdSource // non-nil while the background copy is in flight |
| 117 | + cancel context.CancelFunc |
| 118 | + done chan struct{} |
| 119 | + err atomic.Pointer[error] |
| 120 | +} |
| 121 | + |
| 122 | +// NewCacheFromMemfdAsync starts the memfd→cache copy on a goroutine so gRPC |
| 123 | +// Pause can return as soon as the snapshot file and diff metadata are |
| 124 | +// written; the FC stop and the memfd copy then run in parallel after the |
| 125 | +// response. The returned wrapper takes ownership of memfd; Close cancels |
| 126 | +// and joins the copy goroutine. |
| 127 | +func NewCacheFromMemfdAsync( |
| 128 | + ctx context.Context, |
| 129 | + blockSize int64, |
| 130 | + filePath string, |
| 131 | + memfd *Memfd, |
| 132 | + dirty *roaring.Bitmap, |
| 133 | +) (*MemfdCache, error) { |
| 134 | + cache, err := NewCache(int64(dirty.GetCardinality())*blockSize, blockSize, filePath, false) |
| 135 | + if err != nil { |
| 136 | + return nil, errors.Join(err, memfd.Close()) |
95 | 137 | } |
| 138 | + if dirty.IsEmpty() { |
| 139 | + if closeErr := memfd.Close(); closeErr != nil { |
| 140 | + return nil, errors.Join(fmt.Errorf("close memfd: %w", closeErr), cache.Close()) |
| 141 | + } |
96 | 142 |
|
97 | | - return cache, nil |
| 143 | + return &MemfdCache{Cache: cache}, nil |
| 144 | + } |
| 145 | + |
| 146 | + // Detach from the request context so the copy can outlive Pause; Close |
| 147 | + // drives cancellation. |
| 148 | + copyCtx, cancel := context.WithCancel(context.WithoutCancel(ctx)) |
| 149 | + m := &MemfdCache{ |
| 150 | + Cache: cache, |
| 151 | + src: newMemfdSource(memfd, dirty, blockSize), |
| 152 | + cancel: cancel, |
| 153 | + done: make(chan struct{}), |
| 154 | + } |
| 155 | + |
| 156 | + go m.runCopy(copyCtx, dirty, blockSize) |
| 157 | + |
| 158 | + return m, nil |
| 159 | +} |
| 160 | + |
| 161 | +func (m *MemfdCache) runCopy(ctx context.Context, dirty *roaring.Bitmap, blockSize int64) { |
| 162 | + defer close(m.done) |
| 163 | + |
| 164 | + err := copyFromMemfd(ctx, m.Cache, m.src.memfd, dirty, blockSize) |
| 165 | + if err != nil { |
| 166 | + m.err.Store(&err) |
| 167 | + } |
| 168 | + |
| 169 | + m.mu.Lock() |
| 170 | + src := m.src |
| 171 | + m.src = nil |
| 172 | + m.mu.Unlock() |
| 173 | + |
| 174 | + if closeErr := src.memfd.Close(); closeErr != nil { |
| 175 | + joined := errors.Join(err, fmt.Errorf("close memfd: %w", closeErr)) |
| 176 | + m.err.Store(&joined) |
| 177 | + } |
| 178 | +} |
| 179 | + |
| 180 | +// Wait blocks until the background copy completes (or ctx is cancelled). |
| 181 | +func (m *MemfdCache) Wait(ctx context.Context) error { |
| 182 | + if m.done == nil { |
| 183 | + return nil |
| 184 | + } |
| 185 | + select { |
| 186 | + case <-ctx.Done(): |
| 187 | + return ctx.Err() |
| 188 | + case <-m.done: |
| 189 | + } |
| 190 | + if errPtr := m.err.Load(); errPtr != nil { |
| 191 | + return *errPtr |
| 192 | + } |
| 193 | + |
| 194 | + return nil |
| 195 | +} |
| 196 | + |
| 197 | +func (m *MemfdCache) ReadAt(b []byte, off int64) (int, error) { |
| 198 | + m.mu.RLock() |
| 199 | + if m.src != nil { |
| 200 | + defer m.mu.RUnlock() |
| 201 | + |
| 202 | + return m.src.readAt(b, off) |
| 203 | + } |
| 204 | + m.mu.RUnlock() |
| 205 | + |
| 206 | + return m.Cache.ReadAt(b, off) |
| 207 | +} |
| 208 | + |
| 209 | +// Slice returns BytesNotAvailableError while the copy is in flight: the |
| 210 | +// memfd-backed slice would outlive the RLock and could be Munmap'd |
| 211 | +// asynchronously. Callers fall back to ReadAt or Wait first. |
| 212 | +func (m *MemfdCache) Slice(off, length int64) ([]byte, error) { |
| 213 | + m.mu.RLock() |
| 214 | + defer m.mu.RUnlock() |
| 215 | + |
| 216 | + if m.src != nil { |
| 217 | + return nil, BytesNotAvailableError{} |
| 218 | + } |
| 219 | + |
| 220 | + return m.Cache.Slice(off, length) |
| 221 | +} |
| 222 | + |
| 223 | +func (m *MemfdCache) Close() error { |
| 224 | + if m.cancel != nil { |
| 225 | + m.cancel() |
| 226 | + <-m.done |
| 227 | + } |
| 228 | + |
| 229 | + return m.Cache.Close() |
| 230 | +} |
| 231 | + |
| 232 | +// memfdSource indexes the memfd-backed ranges by cache offset so reads can |
| 233 | +// be served from the memfd while the background copy is still in flight. |
| 234 | +type memfdSource struct { |
| 235 | + memfd *Memfd |
| 236 | + entries []memfdRange |
| 237 | +} |
| 238 | + |
| 239 | +type memfdRange struct { |
| 240 | + cacheStart int64 |
| 241 | + srcStart int64 |
| 242 | + size int64 |
| 243 | +} |
| 244 | + |
| 245 | +func newMemfdSource(memfd *Memfd, dirty *roaring.Bitmap, blockSize int64) *memfdSource { |
| 246 | + var entries []memfdRange |
| 247 | + var cacheOff int64 |
| 248 | + for r := range BitsetRanges(dirty, blockSize) { |
| 249 | + entries = append(entries, memfdRange{cacheStart: cacheOff, srcStart: r.Start, size: r.Size}) |
| 250 | + cacheOff += r.Size |
| 251 | + } |
| 252 | + |
| 253 | + return &memfdSource{memfd: memfd, entries: entries} |
| 254 | +} |
| 255 | + |
| 256 | +func (s *memfdSource) findEntry(cacheOff int64) int { |
| 257 | + lo, hi := 0, len(s.entries) |
| 258 | + for lo < hi { |
| 259 | + mid := (lo + hi) / 2 |
| 260 | + if s.entries[mid].cacheStart > cacheOff { |
| 261 | + hi = mid |
| 262 | + } else { |
| 263 | + lo = mid + 1 |
| 264 | + } |
| 265 | + } |
| 266 | + i := lo - 1 |
| 267 | + if i < 0 || cacheOff >= s.entries[i].cacheStart+s.entries[i].size { |
| 268 | + return -1 |
| 269 | + } |
| 270 | + |
| 271 | + return i |
| 272 | +} |
| 273 | + |
| 274 | +func (s *memfdSource) readAt(b []byte, cacheOff int64) (int, error) { |
| 275 | + n := 0 |
| 276 | + for n < len(b) { |
| 277 | + i := s.findEntry(cacheOff + int64(n)) |
| 278 | + if i < 0 { |
| 279 | + return n, nil |
| 280 | + } |
| 281 | + e := s.entries[i] |
| 282 | + offsetInEntry := cacheOff + int64(n) - e.cacheStart |
| 283 | + toCopy := min(int64(len(b)-n), e.size-offsetInEntry) |
| 284 | + src, err := s.memfd.Slice(e.srcStart+offsetInEntry, toCopy) |
| 285 | + if err != nil { |
| 286 | + return n, fmt.Errorf("memfd slice: %w", err) |
| 287 | + } |
| 288 | + copy(b[n:n+int(toCopy)], src) |
| 289 | + n += int(toCopy) |
| 290 | + } |
| 291 | + |
| 292 | + return n, nil |
98 | 293 | } |
0 commit comments