Skip to content

Commit 402c402

Browse files
committed
feat(memfd): async background copy via NewCacheFromMemfdAsync
Introduce the MemfdCache wrapper (which embeds *Cache) and the NewCacheFromMemfdAsync constructor. The memfd-to-cache copy now runs on a goroutine, so the gRPC Pause call can return as soon as the snapshot file and diff metadata are written. MemfdBackgroundCopyFlag gates the dispatch in fc.ExportMemory; with the flag off, the existing synchronous NewCacheFromMemfd path is used and its behaviour is unchanged. While the copy is in flight, reads are routed through a memfdSource indexed by cache offset; once the copy finishes, the memfd is closed and reads delegate to the embedded Cache. Slice returns BytesNotAvailableError while the copy is in flight, because a memfd-backed slice could be used after the asynchronous Munmap (a use-after-free); callers must either fall back to ReadAt or call Wait before calling Slice. localDiff now takes a block.DiffSource, and CachePath type-asserts for a Wait method and blocks on it, so existing upload paths that read the cache file from the filesystem see complete data.
1 parent 18ee751 commit 402c402

7 files changed

Lines changed: 293 additions & 10 deletions

File tree

packages/orchestrator/pkg/sandbox/block/device.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,16 @@ type Device interface {
4646
io.WriterAt
4747
WriteZeroesAt(off, length int64) (int, error)
4848
}
49+
// DiffSource is the read-side contract the diff/upload layer depends on.
// *Cache implements it as-is; on the async memfd path *Cache is wrapped in
// *MemfdCache, which overrides the read methods for as long as the
// background copy is in flight.
type DiffSource interface {
	io.ReaderAt
	io.Closer

	// Slice returns length bytes starting at off.
	Slice(off, length int64) ([]byte, error)

	Size() (int64, error)
	FileSize() (int64, error)
	BlockSize() int64

	// Path returns the location of the backing cache file.
	Path() string
}

packages/orchestrator/pkg/sandbox/block/memfd.go

Lines changed: 200 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"context"
77
"errors"
88
"fmt"
9+
"sync"
10+
"sync/atomic"
911

1012
"github.com/RoaringBitmap/roaring/v2"
1113
"golang.org/x/sys/unix"
@@ -73,26 +75,219 @@ func NewCacheFromMemfd(
7375
if err != nil {
7476
return nil, errors.Join(err, memfd.Close())
7577
}
78+
if err := copyFromMemfd(ctx, cache, memfd, dirty, blockSize); err != nil {
79+
return nil, errors.Join(err, memfd.Close(), cache.Close())
80+
}
81+
if err := memfd.Close(); err != nil {
82+
return nil, errors.Join(fmt.Errorf("close memfd: %w", err), cache.Close())
83+
}
7684

85+
return cache, nil
86+
}
87+
88+
func copyFromMemfd(ctx context.Context, cache *Cache, memfd *Memfd, dirty *roaring.Bitmap, blockSize int64) error {
7789
var cacheOff int64
7890
for r := range BitsetRanges(dirty, blockSize) {
7991
if err := ctx.Err(); err != nil {
80-
return nil, errors.Join(err, memfd.Close(), cache.Close())
92+
return err
8193
}
8294

8395
src, err := memfd.Slice(r.Start, r.Size)
8496
if err != nil {
85-
return nil, errors.Join(fmt.Errorf("memfd slice [%d,%d): %w", r.Start, r.Start+r.Size, err), memfd.Close(), cache.Close())
97+
return fmt.Errorf("memfd slice [%d,%d): %w", r.Start, r.Start+r.Size, err)
8698
}
8799

88100
copy((*cache.mmap)[cacheOff:cacheOff+r.Size], src)
89101
cache.setIsCached(cacheOff, r.Size)
90102
cacheOff += r.Size
91103
}
92104

93-
if err := memfd.Close(); err != nil {
94-
return nil, errors.Join(fmt.Errorf("close memfd: %w", err), cache.Close())
105+
return nil
106+
}
107+
108+
// MemfdCache wraps a Cache that is being populated from a memfd on a
109+
// background goroutine. While the copy is in flight, reads are served
110+
// directly from the memfd via memfdSource; afterwards they delegate to the
111+
// embedded Cache and the memfd is closed.
112+
type MemfdCache struct {
113+
*Cache
114+
115+
mu sync.RWMutex // guards src
116+
src *memfdSource // non-nil while the background copy is in flight
117+
cancel context.CancelFunc
118+
done chan struct{}
119+
err atomic.Pointer[error]
120+
}
121+
122+
// NewCacheFromMemfdAsync starts the memfd→cache copy on a goroutine so gRPC
123+
// Pause can return as soon as the snapshot file and diff metadata are
124+
// written; the FC stop and the memfd copy then run in parallel after the
125+
// response. The returned wrapper takes ownership of memfd; Close cancels
126+
// and joins the copy goroutine.
127+
func NewCacheFromMemfdAsync(
128+
ctx context.Context,
129+
blockSize int64,
130+
filePath string,
131+
memfd *Memfd,
132+
dirty *roaring.Bitmap,
133+
) (*MemfdCache, error) {
134+
cache, err := NewCache(int64(dirty.GetCardinality())*blockSize, blockSize, filePath, false)
135+
if err != nil {
136+
return nil, errors.Join(err, memfd.Close())
95137
}
138+
if dirty.IsEmpty() {
139+
if closeErr := memfd.Close(); closeErr != nil {
140+
return nil, errors.Join(fmt.Errorf("close memfd: %w", closeErr), cache.Close())
141+
}
96142

97-
return cache, nil
143+
return &MemfdCache{Cache: cache}, nil
144+
}
145+
146+
// Detach from the request context so the copy can outlive Pause; Close
147+
// drives cancellation.
148+
copyCtx, cancel := context.WithCancel(context.WithoutCancel(ctx))
149+
m := &MemfdCache{
150+
Cache: cache,
151+
src: newMemfdSource(memfd, dirty, blockSize),
152+
cancel: cancel,
153+
done: make(chan struct{}),
154+
}
155+
156+
go m.runCopy(copyCtx, dirty, blockSize)
157+
158+
return m, nil
159+
}
160+
161+
func (m *MemfdCache) runCopy(ctx context.Context, dirty *roaring.Bitmap, blockSize int64) {
162+
defer close(m.done)
163+
164+
err := copyFromMemfd(ctx, m.Cache, m.src.memfd, dirty, blockSize)
165+
if err != nil {
166+
m.err.Store(&err)
167+
}
168+
169+
m.mu.Lock()
170+
src := m.src
171+
m.src = nil
172+
m.mu.Unlock()
173+
174+
if closeErr := src.memfd.Close(); closeErr != nil {
175+
joined := errors.Join(err, fmt.Errorf("close memfd: %w", closeErr))
176+
m.err.Store(&joined)
177+
}
178+
}
179+
180+
// Wait blocks until the background copy completes (or ctx is cancelled).
181+
func (m *MemfdCache) Wait(ctx context.Context) error {
182+
if m.done == nil {
183+
return nil
184+
}
185+
select {
186+
case <-ctx.Done():
187+
return ctx.Err()
188+
case <-m.done:
189+
}
190+
if errPtr := m.err.Load(); errPtr != nil {
191+
return *errPtr
192+
}
193+
194+
return nil
195+
}
196+
197+
func (m *MemfdCache) ReadAt(b []byte, off int64) (int, error) {
198+
m.mu.RLock()
199+
if m.src != nil {
200+
defer m.mu.RUnlock()
201+
202+
return m.src.readAt(b, off)
203+
}
204+
m.mu.RUnlock()
205+
206+
return m.Cache.ReadAt(b, off)
207+
}
208+
209+
// Slice returns BytesNotAvailableError while the copy is in flight: the
210+
// memfd-backed slice would outlive the RLock and could be Munmap'd
211+
// asynchronously. Callers fall back to ReadAt or Wait first.
212+
func (m *MemfdCache) Slice(off, length int64) ([]byte, error) {
213+
m.mu.RLock()
214+
defer m.mu.RUnlock()
215+
216+
if m.src != nil {
217+
return nil, BytesNotAvailableError{}
218+
}
219+
220+
return m.Cache.Slice(off, length)
221+
}
222+
223+
func (m *MemfdCache) Close() error {
224+
if m.cancel != nil {
225+
m.cancel()
226+
<-m.done
227+
}
228+
229+
return m.Cache.Close()
230+
}
231+
232+
// memfdSource indexes the memfd-backed ranges by cache offset so reads can
233+
// be served from the memfd while the background copy is still in flight.
234+
type memfdSource struct {
235+
memfd *Memfd
236+
entries []memfdRange
237+
}
238+
// memfdRange describes one contiguous run of bytes: where it sits in the
// cache, where it comes from in the memfd, and how long it is.
type memfdRange struct {
	cacheStart int64 // offset of the run within the cache
	srcStart   int64 // offset of the run within the memfd
	size       int64 // length of the run in bytes
}
244+
245+
func newMemfdSource(memfd *Memfd, dirty *roaring.Bitmap, blockSize int64) *memfdSource {
246+
var entries []memfdRange
247+
var cacheOff int64
248+
for r := range BitsetRanges(dirty, blockSize) {
249+
entries = append(entries, memfdRange{cacheStart: cacheOff, srcStart: r.Start, size: r.Size})
250+
cacheOff += r.Size
251+
}
252+
253+
return &memfdSource{memfd: memfd, entries: entries}
254+
}
255+
256+
func (s *memfdSource) findEntry(cacheOff int64) int {
257+
lo, hi := 0, len(s.entries)
258+
for lo < hi {
259+
mid := (lo + hi) / 2
260+
if s.entries[mid].cacheStart > cacheOff {
261+
hi = mid
262+
} else {
263+
lo = mid + 1
264+
}
265+
}
266+
i := lo - 1
267+
if i < 0 || cacheOff >= s.entries[i].cacheStart+s.entries[i].size {
268+
return -1
269+
}
270+
271+
return i
272+
}
273+
274+
func (s *memfdSource) readAt(b []byte, cacheOff int64) (int, error) {
275+
n := 0
276+
for n < len(b) {
277+
i := s.findEntry(cacheOff + int64(n))
278+
if i < 0 {
279+
return n, nil
280+
}
281+
e := s.entries[i]
282+
offsetInEntry := cacheOff + int64(n) - e.cacheStart
283+
toCopy := min(int64(len(b)-n), e.size-offsetInEntry)
284+
src, err := s.memfd.Slice(e.srcStart+offsetInEntry, toCopy)
285+
if err != nil {
286+
return n, fmt.Errorf("memfd slice: %w", err)
287+
}
288+
copy(b[n:n+int(toCopy)], src)
289+
n += int(toCopy)
290+
}
291+
292+
return n, nil
98293
}

packages/orchestrator/pkg/sandbox/block/memfd_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
package block
44

55
import (
6+
"context"
67
"crypto/rand"
8+
"os"
79
"testing"
810

911
"github.com/RoaringBitmap/roaring/v2"
@@ -76,3 +78,57 @@ func TestNewCacheFromMemfd_NonZeroRangeStart(t *testing.T) {
7678
require.NoError(t, err)
7779
require.Equal(t, expected[pageSize*3:pageSize*5], got)
7880
}
81+
82+
// Cancelling the parent ctx after construction must not abort the in-flight
83+
// copy: NewCacheFromMemfdAsync detaches via context.WithoutCancel so the
84+
// copy outlives the Pause RPC. Cancellation goes through Close.
85+
func TestNewCacheFromMemfdAsync_ParentContextCancellationDoesNotAbort(t *testing.T) {
86+
t.Parallel()
87+
88+
pageSize := int64(header.PageSize)
89+
numPages := uint32(16)
90+
memfd, expected := newTestMemfd(t, pageSize*int64(numPages))
91+
92+
dirty := roaring.New()
93+
dirty.AddRange(0, uint64(numPages))
94+
95+
ctx, cancel := context.WithCancel(t.Context())
96+
cache, err := NewCacheFromMemfdAsync(ctx, pageSize, t.TempDir()+"/cache", memfd, dirty)
97+
require.NoError(t, err)
98+
t.Cleanup(func() { _ = cache.Close() })
99+
100+
cancel()
101+
require.NoError(t, cache.Wait(t.Context()))
102+
103+
got := make([]byte, len(expected))
104+
_, err = cache.ReadAt(got, 0)
105+
require.NoError(t, err)
106+
require.Equal(t, expected, got)
107+
}
108+
109+
// After Wait, the cache file on disk has the full payload and Slice no
110+
// longer returns BytesNotAvailableError.
111+
func TestNewCacheFromMemfdAsync_WaitFlushesToFile(t *testing.T) {
112+
t.Parallel()
113+
114+
pageSize := int64(header.PageSize)
115+
numPages := uint32(12)
116+
memfd, expected := newTestMemfd(t, pageSize*int64(numPages))
117+
118+
dirty := roaring.New()
119+
dirty.AddRange(0, uint64(numPages))
120+
121+
cachePath := t.TempDir() + "/cache"
122+
cache, err := NewCacheFromMemfdAsync(t.Context(), pageSize, cachePath, memfd, dirty)
123+
require.NoError(t, err)
124+
t.Cleanup(func() { _ = cache.Close() })
125+
126+
require.NoError(t, cache.Wait(t.Context()))
127+
128+
_, err = cache.Slice(0, pageSize)
129+
require.NoError(t, err)
130+
131+
fromFile, err := os.ReadFile(cachePath)
132+
require.NoError(t, err)
133+
require.Equal(t, expected, fromFile)
134+
}

packages/orchestrator/pkg/sandbox/build/local_diff.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,14 @@ func (f *LocalDiffFile) CloseToDiff(
8080

8181
type localDiff struct {
8282
cacheKey DiffStoreKey
83-
cache *block.Cache
83+
cache block.DiffSource
8484
}
8585

8686
var _ Diff = (*localDiff)(nil)
8787

8888
func NewLocalDiffFromCache(
8989
cacheKey DiffStoreKey,
90-
cache *block.Cache,
90+
cache block.DiffSource,
9191
) (Diff, error) {
9292
return &localDiff{
9393
cache: cache,
@@ -110,6 +110,14 @@ func newLocalDiff(
110110
}
111111

112112
func (b *localDiff) CachePath() (string, error) {
113+
if w, ok := b.cache.(interface {
114+
Wait(ctx context.Context) error
115+
}); ok {
116+
if err := w.Wait(context.Background()); err != nil {
117+
return "", fmt.Errorf("memfd copy: %w", err)
118+
}
119+
}
120+
113121
return b.cache.Path(), nil
114122
}
115123

packages/orchestrator/pkg/sandbox/fc/memory.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func (p *Process) exportMemoryFromFc(
2828
include *roaring.Bitmap,
2929
cachePath string,
3030
blockSize int64,
31-
) (*block.Cache, error) {
31+
) (block.DiffSource, error) {
3232
m, err := p.client.memoryMapping(ctx)
3333
if err != nil {
3434
return nil, fmt.Errorf("failed to get memory mappings: %w", err)
@@ -64,10 +64,14 @@ func (p *Process) ExportMemory(
6464
cachePath string,
6565
blockSize int64,
6666
memfd *block.Memfd,
67-
) (*block.Cache, error) {
67+
bgCopy bool,
68+
) (block.DiffSource, error) {
6869
if memfd == nil {
6970
return p.exportMemoryFromFc(ctx, include, cachePath, blockSize)
7071
}
72+
if bgCopy {
73+
return block.NewCacheFromMemfdAsync(ctx, blockSize, cachePath, memfd, include)
74+
}
7175

7276
return block.NewCacheFromMemfd(ctx, blockSize, cachePath, memfd, include)
7377
}

packages/orchestrator/pkg/sandbox/sandbox.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1138,6 +1138,7 @@ func (s *Sandbox) Pause(
11381138
s.config.DefaultCacheDir,
11391139
s.process,
11401140
s.memory.Memfd(ctx),
1141+
s.featureFlags.BoolFlag(ctx, featureflags.MemfdBackgroundCopyFlag, sandboxLDContext(s.Runtime, s.Config)),
11411142
)
11421143
if err != nil {
11431144
return nil, fmt.Errorf("error while post processing: %w", err)
@@ -1206,13 +1207,14 @@ func pauseProcessMemory(
12061207
cacheDir string,
12071208
fc *fc.Process,
12081209
memfd *block.Memfd,
1210+
bgCopy bool,
12091211
) (d build.Diff, h *header.Header, e error) {
12101212
ctx, span := tracer.Start(ctx, "process-memory")
12111213
defer span.End()
12121214

12131215
// ExportMemory owns memfd and closes it on all paths.
12141216
memfileDiffPath := build.GenerateDiffCachePath(cacheDir, buildID.String(), build.Memfile)
1215-
cache, err := fc.ExportMemory(ctx, diffMetadata.Dirty, memfileDiffPath, diffMetadata.BlockSize, memfd)
1217+
cache, err := fc.ExportMemory(ctx, diffMetadata.Dirty, memfileDiffPath, diffMetadata.BlockSize, memfd, bgCopy)
12161218
if err != nil {
12171219
return nil, nil, fmt.Errorf("failed to export memory: %w", err)
12181220
}

0 commit comments

Comments
 (0)