Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 92 additions & 38 deletions lib/store/base/buffer_readwriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,54 +17,110 @@ package base
import (
"fmt"
"io"

"github.com/aws/aws-sdk-go/aws"
"sync"
"sync/atomic"
)

// Compile-time assertion that *BufferReadWriter satisfies the FileReadWriter interface.
var _ FileReadWriter = &BufferReadWriter{}

// BufferReadWriter implements FileReadWriter interface for in-memory buffering.
// BufferReadWriter implements FileReadWriter for in-memory buffering.
//
// Pre-sizing (size > 0) allows concurrent WriteAt calls to non-overlapping
// ranges to run in parallel. Without pre-sizing, each write that grows the
// buffer is serialized.
//
// Bytes, Size, Write, Read, ReadAt, and Seek must not be called concurrently
// with each other or with WriteAt.
type BufferReadWriter struct {
buf *aws.WriteAtBuffer
offset int64
mu sync.RWMutex
buf []byte
written atomic.Int64
offset int64
}

// NewBufferReadWriter creates a new BufferReadWriter with an initial capacity of size bytes.
// NewBufferReadWriter creates a new BufferReadWriter pre-allocated to size bytes.
// Pass the exact blob size when known so concurrent WriteAt calls for
// non-overlapping shard ranges can run in parallel without writer serialization.
func NewBufferReadWriter(size uint64) *BufferReadWriter {
bytesSlice := make([]byte, 0, size)
buf := aws.NewWriteAtBuffer(bytesSlice)
// Although this is default, this is explicitly set to notify that we are reserving
// only as much capacity as needed
buf.GrowthCoeff = 1

return &BufferReadWriter{
buf: buf,
offset: 0,
}
return &BufferReadWriter{buf: make([]byte, size)}
}
Comment thread
sambhav-jain-16 marked this conversation as resolved.

// Write implements io.Writer by using WriteAt with current write offset.
// Write implements io.Writer using the current sequential write offset.
func (b *BufferReadWriter) Write(p []byte) (n int, err error) {
n, err = b.buf.WriteAt(p, b.offset)
n, err = b.WriteAt(p, b.offset)
b.offset += int64(n)
return n, err
}

// WriteAt implements io.WriterAt for parallel writes.
func (b *BufferReadWriter) WriteAt(p []byte, off int64) (n int, err error) {
// WriteAt implements io.WriterAt.
//
// Fast path (off+len(p) within pre-allocated buffer): multiple goroutines may
// call WriteAt concurrently, provided their byte ranges do not overlap.
//
// Slow path (write extends beyond current buffer): acquires an exclusive lock
// only to grow the buffer, then copies under a shared read lock so other
// writers can proceed in parallel.
func (b *BufferReadWriter) WriteAt(p []byte, off int64) (int, error) {
if off < 0 {
return 0, fmt.Errorf("negative offset")
}
return b.buf.WriteAt(p, off)
if len(p) == 0 {
return 0, nil
}
end := off + int64(len(p))
Comment thread
sambhav-jain-16 marked this conversation as resolved.
if end < off {
return 0, fmt.Errorf("write at offset %d length %d overflows int64", off, len(p))
}

// fast path
if n, shouldGrowBuffer := b.updateBuffer(p, off, end); !shouldGrowBuffer {
return n, nil
}

// slow path: grow, then call updateBuffer again
b.growBuffer(end)
n, _ := b.updateBuffer(p, off, end)
return n, nil
Comment on lines +80 to +83
}

// growBuffer guarantees that b.buf is at least size bytes long.
//
// When the current buffer is shorter, a new slice of exactly size bytes is
// allocated, the existing contents are copied to its front, and it replaces
// b.buf; the newly added tail is zero-filled. When the buffer is already long
// enough nothing changes. The exclusive side of b.mu is held for the whole
// check-and-replace.
func (b *BufferReadWriter) growBuffer(size int64) {
	b.mu.Lock()
	defer b.mu.Unlock()

	if current := int64(len(b.buf)); current < size {
		replacement := make([]byte, size)
		copy(replacement, b.buf)
		b.buf = replacement
	}
}

// updateBuffer copies p into [off, end).
// Returns (-1, true) if end exceeds the buffer, causing WriteAt to take the slow path.
Comment on lines +99 to +100
func (b *BufferReadWriter) updateBuffer(p []byte, off, end int64) (int, bool) {
b.mu.RLock()
defer b.mu.RUnlock()
if end <= int64(len(b.buf)) {
n := copy(b.buf[off:], p)
for {
cur := b.written.Load()
if end <= cur || b.written.CompareAndSwap(cur, end) {
break
}
}
return n, false
}
return -1, true
}

// Read implements io.Reader for sequential reads.
func (b *BufferReadWriter) Read(p []byte) (n int, err error) {
bufBytes := b.buf.Bytes()
if b.offset >= int64(len(bufBytes)) {
written := b.written.Load()
if b.offset >= written {
return 0, io.EOF
}
n = copy(p, bufBytes[b.offset:])
n = copy(p, b.buf[b.offset:written])
b.offset += int64(n)
Comment thread
sambhav-jain-16 marked this conversation as resolved.
Comment thread
sambhav-jain-16 marked this conversation as resolved.
if n < len(p) {
err = io.EOF
Expand All @@ -77,11 +133,14 @@ func (b *BufferReadWriter) ReadAt(p []byte, off int64) (n int, err error) {
if off < 0 {
return 0, fmt.Errorf("negative offset")
}
bufBytes := b.buf.Bytes()
if off >= int64(len(bufBytes)) {
b.mu.RLock()
buf := b.buf
written := b.written.Load()
b.mu.RUnlock()
if off >= written {
return 0, io.EOF
}
n = copy(p, bufBytes[off:])
n = copy(p, buf[off:written])
if n < len(p) {
Comment thread
sambhav-jain-16 marked this conversation as resolved.
err = io.EOF
}
Expand All @@ -91,23 +150,19 @@ func (b *BufferReadWriter) ReadAt(p []byte, off int64) (n int, err error) {
// Seek implements io.Seeker.
func (b *BufferReadWriter) Seek(offset int64, whence int) (int64, error) {
var newOffset int64
bufSize := int64(len(b.buf.Bytes()))

switch whence {
case io.SeekStart:
newOffset = offset
case io.SeekCurrent:
newOffset = b.offset + offset
case io.SeekEnd:
newOffset = bufSize + offset
newOffset = b.written.Load() + offset
default:
return 0, fmt.Errorf("invalid whence: %d", whence)
}

if newOffset < 0 {
return 0, fmt.Errorf("negative position: %d", newOffset)
}

b.offset = newOffset
return newOffset, nil
}
Expand All @@ -117,10 +172,8 @@ func (b *BufferReadWriter) Close() error {
return nil
}

// Size returns the size of the buffer
func (b *BufferReadWriter) Size() int64 {
return int64(len(b.buf.Bytes()))
}
// Size returns the largest end offset written so far.
func (b *BufferReadWriter) Size() int64 { return b.written.Load() }

// Cancel is no-op
func (b *BufferReadWriter) Cancel() error {
Expand All @@ -132,7 +185,8 @@ func (b *BufferReadWriter) Commit() error {
return nil
}

// Bytes returns the full buffer
// Bytes returns the buffer up to the highest offset written so far.
// Any gaps between writes are zero-filled.
func (b *BufferReadWriter) Bytes() []byte {
return b.buf.Bytes()
return b.buf[:b.written.Load()]
Comment thread
sambhav-jain-16 marked this conversation as resolved.
Comment thread
sambhav-jain-16 marked this conversation as resolved.
}
146 changes: 146 additions & 0 deletions lib/store/base/buffer_readwriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
package base

import (
"errors"
"fmt"
"io"
"runtime"
"sync"
"testing"
"testing/iotest"

Expand Down Expand Up @@ -49,6 +53,7 @@ func TestBufferReadWriter_Write(t *testing.T) {
}

assert.Equal(t, tt.expectedSize, buf.Size())
assert.Equal(t, tt.expectedResult, buf.Bytes())
})
}
}
Expand Down Expand Up @@ -111,6 +116,9 @@ func TestBufferReadWriter_WriteAt(t *testing.T) {
}

assert.Equal(t, tt.expectedSize, buf.Size())
if tt.expectedResult != nil {
assert.Equal(t, tt.expectedResult, buf.Bytes())
}
})
}
}
Expand Down Expand Up @@ -308,3 +316,141 @@ func TestBufferReadWriter_TestReader(t *testing.T) {
err = iotest.TestReader(buf, content)
require.NoError(t, err)
}

// TestBufferReadWriter_ConcurrentWriteAt writes ten non-overlapping 1 KiB
// shards from separate goroutines and checks that the assembled buffer equals
// the source data. It runs once per initial buffer size: the full total, half
// of it, and zero.
func TestBufferReadWriter_ConcurrentWriteAt(t *testing.T) {
	const (
		numShards = 10
		shardSize = 1024
	)
	totalSize := uint64(numShards * shardSize)

	for _, tc := range []struct {
		name     string
		initSize uint64
	}{
		{name: "presized", initSize: totalSize},
		{name: "half_presized", initSize: totalSize / 2},
		{name: "dynamic", initSize: 0},
	} {
		t.Run(tc.name, func(t *testing.T) {
			// Position-dependent pattern so a misplaced shard is detectable.
			want := make([]byte, totalSize)
			for idx := range want {
				want[idx] = byte(idx % 256)
			}

			rw := NewBufferReadWriter(tc.initSize)
			// One slot per goroutine, so the error slice needs no locking.
			writeErrs := make([]error, numShards)

			var wg sync.WaitGroup
			wg.Add(numShards)
			for shard := 0; shard < numShards; shard++ {
				go func(shard int) {
					defer wg.Done()
					start := shard * shardSize
					_, writeErrs[shard] = rw.WriteAt(want[start:start+shardSize], int64(start))
				}(shard)
			}
			wg.Wait()

			require.NoError(t, errors.Join(writeErrs...))
			assert.Equal(t, want, rw.Bytes())
		})
	}
}

// TestBufferReadWriter_WriteAtEmpty checks that a zero-length WriteAt at a
// large offset succeeds, reports zero bytes written, and leaves both Size and
// Bytes empty.
func TestBufferReadWriter_WriteAtEmpty(t *testing.T) {
	const farOffset = 1 << 30

	rw := NewBufferReadWriter(0)
	written, err := rw.WriteAt(nil, farOffset)

	require.NoError(t, err)
	assert.Equal(t, 0, written)
	assert.Equal(t, int64(0), rw.Size())
	assert.Empty(t, rw.Bytes())
}

// totalMutexContentions returns the sum of the Count field over every record
// in the current runtime mutex profile.
//
// runtime.MutexProfile only fills the slice when it is large enough, so the
// call is retried with a larger slice until it reports success.
func totalMutexContentions() int64 {
	for capacity := 1024; ; {
		profile := make([]runtime.BlockProfileRecord, capacity)
		filled, complete := runtime.MutexProfile(profile)
		if !complete {
			// Too small: filled is the number of records needed. Add some
			// headroom in case the profile grows before the next attempt.
			capacity = filled + 64
			continue
		}

		var sum int64
		for i := range profile[:filled] {
			sum += profile[i].Count
		}
		return sum
	}
}

// benchmarkWriteAt is the shared helper for all WriteAt benchmarks.
// numShards goroutines each write a non-overlapping 4 MiB shard concurrently,
// initSize controls the buffer's initial allocation:
// - initSize == totalSize: pre-sized fast path (production case)
// - initSize == 0: dynamic growth (sequential / wrong-size case)
// - initSize == totalSize/2: partial pre-allocation, triggers growth mid-download
func benchmarkWriteAt(b *testing.B, numShards int, initSize uint64) {
b.Helper()
const shardSize = 4 * 1024 * 1024
totalSize := uint64(numShards) * shardSize

shards := make([][]byte, numShards)
for i := range shards {
shards[i] = make([]byte, shardSize)
for j := range shards[i] {
shards[i][j] = byte(i)
}
}

prev := runtime.SetMutexProfileFraction(1)
defer runtime.SetMutexProfileFraction(prev)
startContentions := totalMutexContentions()
b.SetBytes(int64(totalSize))
b.ReportAllocs()
b.ResetTimer()

errs := make([]error, numShards)
for b.Loop() {
buf := NewBufferReadWriter(initSize)
var wg sync.WaitGroup
for shard := 0; shard < numShards; shard++ {
wg.Add(1)
go func(s int) {
defer wg.Done()
_, errs[s] = buf.WriteAt(shards[s], int64(s)*shardSize)
}(shard)
}
wg.Wait()
Comment thread
sambhav-jain-16 marked this conversation as resolved.
if err := errors.Join(errs...); err != nil {
b.Fatal(err)
}
}

b.StopTimer()
if b.N > 0 {
b.ReportMetric(float64(totalMutexContentions()-startContentions)/float64(b.N), "mutex-contentions/op")
}
}

// BenchmarkBufferReadWriter_WriteAt runs benchmarkWriteAt for every
// combination of three initial-size strategies (full, half, zero) and three
// shard counts (1, 4, 10). Sub-benchmarks are named "<strategy>_<n>_shards".
func BenchmarkBufferReadWriter_WriteAt(b *testing.B) {
	const shardSize = 4 * 1024 * 1024

	type strategy struct {
		label    string
		initSize func(total uint64) uint64
	}
	strategies := []strategy{
		// production fast path
		{label: "presized", initSize: func(total uint64) uint64 { return total }},
		// partial pre-alloc, growth needed
		{label: "half_presized", initSize: func(total uint64) uint64 { return total / 2 }},
		// no pre-alloc, always grows
		{label: "dynamic", initSize: func(uint64) uint64 { return 0 }},
	}

	for _, s := range strategies {
		for _, shards := range []int{1, 4, 10} {
			initial := s.initSize(uint64(shards) * shardSize)
			name := fmt.Sprintf("%s_%d_shards", s.label, shards)
			b.Run(name, func(b *testing.B) {
				benchmarkWriteAt(b, shards, initial)
			})
		}
	}
}
Loading