diff --git a/lib/store/base/buffer_readwriter.go b/lib/store/base/buffer_readwriter.go index 67aa5191e..1b009f994 100644 --- a/lib/store/base/buffer_readwriter.go +++ b/lib/store/base/buffer_readwriter.go @@ -17,54 +17,110 @@ package base import ( "fmt" "io" - - "github.com/aws/aws-sdk-go/aws" + "sync" + "sync/atomic" ) var _ FileReadWriter = &BufferReadWriter{} -// BufferReadWriter implements FileReadWriter interface for in-memory buffering. +// BufferReadWriter implements FileReadWriter for in-memory buffering. +// +// Pre-sizing (size > 0) allows concurrent WriteAt calls to non-overlapping +// ranges to run in parallel. Without pre-sizing, each write that grows the +// buffer is serialized. +// +// Bytes, Size, Write, Read, ReadAt, and Seek must not be called concurrently +// with each other or with WriteAt. type BufferReadWriter struct { - buf *aws.WriteAtBuffer - offset int64 + mu sync.RWMutex + buf []byte + written atomic.Int64 + offset int64 } -// NewBufferReadWriter creates a new BufferReadWriter with an initial capacity of size bytes. +// NewBufferReadWriter creates a new BufferReadWriter pre-allocated to size bytes. +// Pass the exact blob size when known so concurrent WriteAt calls for +// non-overlapping shard ranges can run in parallel without writer serialization. func NewBufferReadWriter(size uint64) *BufferReadWriter { - bytesSlice := make([]byte, 0, size) - buf := aws.NewWriteAtBuffer(bytesSlice) - // Although this is default, this is explicitly set to notify that we are reserving - // only as much capacity as needed - buf.GrowthCoeff = 1 - - return &BufferReadWriter{ - buf: buf, - offset: 0, - } + return &BufferReadWriter{buf: make([]byte, size)} } -// Write implements io.Writer by using WriteAt with current write offset. +// Write implements io.Writer using the current sequential write offset. 
func (b *BufferReadWriter) Write(p []byte) (n int, err error) { - n, err = b.buf.WriteAt(p, b.offset) + n, err = b.WriteAt(p, b.offset) b.offset += int64(n) return n, err } -// WriteAt implements io.WriterAt for parallel writes. -func (b *BufferReadWriter) WriteAt(p []byte, off int64) (n int, err error) { +// WriteAt implements io.WriterAt. +// +// Fast path (off+len(p) within pre-allocated buffer): multiple goroutines may +// call WriteAt concurrently, provided their byte ranges do not overlap. +// +// Slow path (write extends beyond current buffer): acquires an exclusive lock +// only to grow the buffer, then copies under a shared read lock so other +// writers can proceed in parallel. +func (b *BufferReadWriter) WriteAt(p []byte, off int64) (int, error) { if off < 0 { return 0, fmt.Errorf("negative offset") } - return b.buf.WriteAt(p, off) + if len(p) == 0 { + return 0, nil + } + end := off + int64(len(p)) + if end < off { + return 0, fmt.Errorf("write at offset %d length %d overflows int64", off, len(p)) + } + + // fast path + if n, shouldGrowBuffer := b.updateBuffer(p, off, end); !shouldGrowBuffer { + return n, nil + } + + // slow path: grow, then call updateBuffer again + b.growBuffer(end) + n, _ := b.updateBuffer(p, off, end) + return n, nil +} + +// growBuffer expands b.buf to at least size bytes. It is a no-op when the +// buffer is already large enough. +func (b *BufferReadWriter) growBuffer(size int64) { + b.mu.Lock() + defer b.mu.Unlock() + if size <= int64(len(b.buf)) { + return + } + grown := make([]byte, size) + copy(grown, b.buf) + b.buf = grown +} + +// updateBuffer copies p into [off, end). +// Returns (-1, true) if end exceeds the buffer, causing WriteAt to take the slow path. 
+func (b *BufferReadWriter) updateBuffer(p []byte, off, end int64) (int, bool) { + b.mu.RLock() + defer b.mu.RUnlock() + if end <= int64(len(b.buf)) { + n := copy(b.buf[off:], p) + for { + cur := b.written.Load() + if end <= cur || b.written.CompareAndSwap(cur, end) { + break + } + } + return n, false + } + return -1, true } // Read implements io.Reader for sequential reads. func (b *BufferReadWriter) Read(p []byte) (n int, err error) { - bufBytes := b.buf.Bytes() - if b.offset >= int64(len(bufBytes)) { + written := b.written.Load() + if b.offset >= written { return 0, io.EOF } - n = copy(p, bufBytes[b.offset:]) + n = copy(p, b.buf[b.offset:written]) b.offset += int64(n) if n < len(p) { err = io.EOF @@ -77,11 +133,14 @@ func (b *BufferReadWriter) ReadAt(p []byte, off int64) (n int, err error) { if off < 0 { return 0, fmt.Errorf("negative offset") } - bufBytes := b.buf.Bytes() - if off >= int64(len(bufBytes)) { + b.mu.RLock() + buf := b.buf + written := b.written.Load() + b.mu.RUnlock() + if off >= written { return 0, io.EOF } - n = copy(p, bufBytes[off:]) + n = copy(p, buf[off:written]) if n < len(p) { err = io.EOF } @@ -91,23 +150,19 @@ func (b *BufferReadWriter) ReadAt(p []byte, off int64) (n int, err error) { // Seek implements io.Seeker. 
func (b *BufferReadWriter) Seek(offset int64, whence int) (int64, error) { var newOffset int64 - bufSize := int64(len(b.buf.Bytes())) - switch whence { case io.SeekStart: newOffset = offset case io.SeekCurrent: newOffset = b.offset + offset case io.SeekEnd: - newOffset = bufSize + offset + newOffset = b.written.Load() + offset default: return 0, fmt.Errorf("invalid whence: %d", whence) } - if newOffset < 0 { return 0, fmt.Errorf("negative position: %d", newOffset) } - b.offset = newOffset return newOffset, nil } @@ -117,10 +172,8 @@ func (b *BufferReadWriter) Close() error { return nil } -// Size returns the size of the buffer -func (b *BufferReadWriter) Size() int64 { - return int64(len(b.buf.Bytes())) -} +// Size returns the largest end offset written so far. +func (b *BufferReadWriter) Size() int64 { return b.written.Load() } // Cancel is no-op func (b *BufferReadWriter) Cancel() error { @@ -132,7 +185,8 @@ func (b *BufferReadWriter) Commit() error { return nil } -// Bytes returns the full buffer +// Bytes returns the buffer up to the highest offset written so far. +// Any gaps between writes are zero-filled. 
func (b *BufferReadWriter) Bytes() []byte { - return b.buf.Bytes() + return b.buf[:b.written.Load()] } diff --git a/lib/store/base/buffer_readwriter_test.go b/lib/store/base/buffer_readwriter_test.go index cfcf7c5bf..7316290ff 100644 --- a/lib/store/base/buffer_readwriter_test.go +++ b/lib/store/base/buffer_readwriter_test.go @@ -15,7 +15,11 @@ package base import ( + "errors" + "fmt" "io" + "runtime" + "sync" "testing" "testing/iotest" @@ -49,6 +53,7 @@ func TestBufferReadWriter_Write(t *testing.T) { } assert.Equal(t, tt.expectedSize, buf.Size()) + assert.Equal(t, tt.expectedResult, buf.Bytes()) }) } } @@ -111,6 +116,9 @@ func TestBufferReadWriter_WriteAt(t *testing.T) { } assert.Equal(t, tt.expectedSize, buf.Size()) + if tt.expectedResult != nil { + assert.Equal(t, tt.expectedResult, buf.Bytes()) + } }) } } @@ -308,3 +316,141 @@ func TestBufferReadWriter_TestReader(t *testing.T) { err = iotest.TestReader(buf, content) require.NoError(t, err) } + +func TestBufferReadWriter_ConcurrentWriteAt(t *testing.T) { + const numShards, shardSize = 10, 1024 + totalSize := uint64(numShards * shardSize) + + tests := []struct { + name string + initSize uint64 + }{ + {"presized", totalSize}, + {"half_presized", totalSize / 2}, + {"dynamic", 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data := make([]byte, totalSize) + for i := range data { + data[i] = byte(i % 256) + } + + buf := NewBufferReadWriter(tt.initSize) + errs := make([]error, numShards) + var wg sync.WaitGroup + for i := 0; i < numShards; i++ { + wg.Add(1) + go func(shard int) { + defer wg.Done() + off := shard * shardSize + _, errs[shard] = buf.WriteAt(data[off:off+shardSize], int64(off)) + }(i) + } + wg.Wait() + + require.NoError(t, errors.Join(errs...)) + assert.Equal(t, data, buf.Bytes()) + }) + } +} + +func TestBufferReadWriter_WriteAtEmpty(t *testing.T) { + buf := NewBufferReadWriter(0) + n, err := buf.WriteAt(nil, 1<<30) + require.NoError(t, err) + assert.Equal(t, 0, n) + 
assert.Equal(t, int64(0), buf.Size()) + assert.Empty(t, buf.Bytes()) +} + +func totalMutexContentions() int64 { + size := 1024 + for { + records := make([]runtime.BlockProfileRecord, size) + n, ok := runtime.MutexProfile(records) + if ok { + var total int64 + for i := 0; i < n; i++ { + total += records[i].Count + } + return total + } + size = n + 64 + } +} + +// benchmarkWriteAt is the shared helper for all WriteAt benchmarks. +// numShards goroutines each write a non-overlapping 4 MiB shard concurrently; +// initSize controls the buffer's initial allocation: +// - initSize == totalSize: pre-sized fast path (production case) +// - initSize == 0: dynamic growth (sequential / wrong-size case) +// - initSize == totalSize/2: partial pre-allocation, triggers growth mid-download +func benchmarkWriteAt(b *testing.B, numShards int, initSize uint64) { + b.Helper() + const shardSize = 4 * 1024 * 1024 + totalSize := uint64(numShards) * shardSize + + shards := make([][]byte, numShards) + for i := range shards { + shards[i] = make([]byte, shardSize) + for j := range shards[i] { + shards[i][j] = byte(i) + } + } + + prev := runtime.SetMutexProfileFraction(1) + defer runtime.SetMutexProfileFraction(prev) + startContentions := totalMutexContentions() + b.SetBytes(int64(totalSize)) + b.ReportAllocs() + b.ResetTimer() + + errs := make([]error, numShards) + for b.Loop() { + buf := NewBufferReadWriter(initSize) + var wg sync.WaitGroup + for shard := 0; shard < numShards; shard++ { + wg.Add(1) + go func(s int) { + defer wg.Done() + _, errs[s] = buf.WriteAt(shards[s], int64(s)*shardSize) + }(shard) + } + wg.Wait() + if err := errors.Join(errs...); err != nil { + b.Fatal(err) + } + } + + b.StopTimer() + if b.N > 0 { + b.ReportMetric(float64(totalMutexContentions()-startContentions)/float64(b.N), "mutex-contentions/op") + } +} + +// BenchmarkBufferReadWriter_WriteAt exercises three buffer initialization +// strategies × three shard counts. 
+func BenchmarkBufferReadWriter_WriteAt(b *testing.B) { + const shardSize = 4 * 1024 * 1024 + cases := []struct { + label string + initFunc func(total uint64) uint64 + }{ + {"presized", func(total uint64) uint64 { return total }}, // production fast path + {"half_presized", func(total uint64) uint64 { return total / 2 }}, // partial pre-alloc, growth needed + {"dynamic", func(total uint64) uint64 { return 0 }}, // no pre-alloc, always grows + } + shardCounts := []int{1, 4, 10} + + for _, tc := range cases { + for _, numShards := range shardCounts { + totalSize := uint64(numShards) * shardSize + initSize := tc.initFunc(totalSize) + b.Run(fmt.Sprintf("%s_%d_shards", tc.label, numShards), func(b *testing.B) { + benchmarkWriteAt(b, numShards, initSize) + }) + } + } +}