diff --git a/.github/actions/build-sandbox-template/action.yml b/.github/actions/build-sandbox-template/action.yml index f42d2b02b8..34e1454f79 100644 --- a/.github/actions/build-sandbox-template/action.yml +++ b/.github/actions/build-sandbox-template/action.yml @@ -1,6 +1,24 @@ name: "Build Sandbox Template" description: "Builds the Firecracker sandbox template." +inputs: + compress_enabled: + description: "Enable compression (true/false)" + required: false + default: "false" + compress_type: + description: "Compression type (zstd, lz4)" + required: false + default: "" + compress_level: + description: "Compression level (zstd: 1=fastest, 2=default; lz4: 0)" + required: false + default: "" + compress_workers: + description: "Number of frame encode workers" + required: false + default: "" + runs: using: "composite" steps: @@ -9,6 +27,10 @@ runs: TEMPLATE_ID: "2j6ly824owf4awgai1xo" KERNEL_VERSION: "vmlinux-6.1.158" FIRECRACKER_VERSION: "v1.14.1_458ca91" + COMPRESS_ENABLED: ${{ inputs.compress_enabled }} + COMPRESS_TYPE: ${{ inputs.compress_type }} + COMPRESS_LEVEL: ${{ inputs.compress_level }} + COMPRESS_FRAME_ENCODE_WORKERS: ${{ inputs.compress_workers }} run: | # Generate an unique build ID for the template for this run export BUILD_ID=$(uuidgen) @@ -17,6 +39,10 @@ runs: echo "TESTS_SANDBOX_TEMPLATE_ID=${TEMPLATE_ID}" >> .env.test echo "TESTS_SANDBOX_BUILD_ID=${BUILD_ID}" >> .env.test + echo "COMPRESS_ENABLED=${COMPRESS_ENABLED}" >> .env.test + echo "COMPRESS_TYPE=${COMPRESS_TYPE}" >> .env.test + echo "COMPRESS_LEVEL=${COMPRESS_LEVEL}" >> .env.test + echo "COMPRESS_FRAME_ENCODE_WORKERS=${COMPRESS_FRAME_ENCODE_WORKERS}" >> .env.test sudo -E make -C packages/orchestrator build-template \ ARTIFACTS_REGISTRY_PROVIDER=Local \ diff --git a/.github/actions/start-services/action.yml b/.github/actions/start-services/action.yml index df4f4fe9f8..1f462f58e1 100644 --- a/.github/actions/start-services/action.yml +++ b/.github/actions/start-services/action.yml @@ -1,6 +1,24 @@ 
name: "Start Services" description: "Sets up and starts the required services, including PostgreSQL." +inputs: + compress_enabled: + description: "Enable compression (true/false)" + required: false + default: "false" + compress_type: + description: "Compression type (zstd, lz4)" + required: false + default: "" + compress_level: + description: "Compression level (zstd: 1=fastest, 2=default; lz4: 0)" + required: false + default: "" + compress_workers: + description: "Number of frame encode workers" + required: false + default: "" + runs: using: "composite" steps: @@ -107,6 +125,11 @@ runs: API_INTERNAL_GRPC_ADDRESS: "localhost:5009" DEFAULT_PERSISTENT_VOLUME_TYPE: "test-volume-type" SANDBOX_STORAGE_BACKEND: "redis" + COMPRESS_ENABLED: ${{ inputs.compress_enabled }} + COMPRESS_TYPE: ${{ inputs.compress_type }} + COMPRESS_LEVEL: ${{ inputs.compress_level }} + COMPRESS_FRAME_ENCODE_WORKERS: ${{ inputs.compress_workers }} + E2B_DEBUG: "true" run: | mkdir -p $SHARED_CHUNK_CACHE_PATH mkdir -p ~/logs diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml index 5ec826262d..e9e1edca56 100644 --- a/.github/workflows/integration_tests.yml +++ b/.github/workflows/integration_tests.yml @@ -10,12 +10,31 @@ on: secrets: CODECOV_TOKEN: { required: false } jobs: - integration_tests: + run: runs-on: infra-tests timeout-minutes: 30 permissions: contents: read id-token: write + strategy: + fail-fast: false + matrix: + include: + - name: uncompressed + compress_enabled: "false" + compress_type: "" + compress_level: "" + compress_workers: "" + - name: zstd1 + compress_enabled: "true" + compress_type: "zstd" + compress_level: "1" + compress_workers: "8" + - name: lz4 + compress_enabled: "true" + compress_type: "lz4" + compress_level: "0" + compress_workers: "8" env: # Surfaced as env so upload steps can gate on presence (skipped on fork PRs). 
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} @@ -32,9 +51,19 @@ jobs: - name: Build Template uses: ./.github/actions/build-sandbox-template + with: + compress_enabled: ${{ matrix.compress_enabled }} + compress_type: ${{ matrix.compress_type }} + compress_level: ${{ matrix.compress_level }} + compress_workers: ${{ matrix.compress_workers }} - name: Start Services uses: ./.github/actions/start-services + with: + compress_enabled: ${{ matrix.compress_enabled }} + compress_type: ${{ matrix.compress_type }} + compress_level: ${{ matrix.compress_level }} + compress_workers: ${{ matrix.compress_workers }} - name: Run Integration Tests env: @@ -80,7 +109,7 @@ jobs: if: ${{ always() && inputs.publish == true }} uses: actions/upload-artifact@v6 with: - name: Integration Tests Results + name: Integration Tests Results (${{ matrix.name }}) path: ./tests/integration/test-results.xml - name: Upload test results to Codecov @@ -96,5 +125,18 @@ jobs: if: ${{ always() && inputs.publish == true }} uses: actions/upload-artifact@v6 with: - name: Service Logs + name: Service Logs (${{ matrix.name }}) path: ~/logs/*.log + + integration_tests: + needs: run + if: always() + runs-on: ubuntu-latest + steps: + - name: Aggregate matrix result + run: | + if [[ "${{ needs.run.result }}" != "success" ]]; then + echo "matrix result: ${{ needs.run.result }}" + exit 1 + fi + echo "all matrix shards succeeded" diff --git a/.mockery.yaml b/.mockery.yaml index 7cb277c42a..20c5b979d7 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -39,34 +39,38 @@ packages: interfaces: featureFlagsClient: config: - dir: packages/shared/pkg/storage/mocks - filename: mockfeatureflagsclient.go - pkgname: storagemocks + dir: packages/shared/pkg/storage + filename: mock_featureflagsclient.go + pkgname: storage + inpackage: true structname: MockFeatureFlagsClient Blob: config: - dir: packages/shared/pkg/storage/mocks - filename: mockobjectprovider.go - pkgname: storagemocks + dir: packages/shared/pkg/storage + filename: 
mock_blob.go + pkgname: storage + inpackage: true Seekable: config: - dir: packages/shared/pkg/storage/mocks - filename: mockseekableobjectprovider.go - pkgname: storagemocks + dir: packages/shared/pkg/storage + filename: mock_seekable.go + pkgname: storage + inpackage: true StorageProvider: config: - dir: packages/shared/pkg/storage/mocks/provider - filename: mockstorageprovider.go - pkgname: providermocks - + dir: packages/shared/pkg/storage + filename: mock_storageprovider.go + pkgname: storage + inpackage: true io: interfaces: Reader: config: - dir: packages/shared/pkg/storage/mocks - filename: mockioreader.go - pkgname: storagemocks + dir: packages/shared/pkg/storage + filename: mock_ioreader.go + pkgname: storage + inpackage: true github.com/e2b-dev/infra/packages/shared/pkg/utils: interfaces: @@ -76,6 +80,14 @@ packages: filename: mocks_test.go pkgname: utils + github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/build: + interfaces: + Diff: + config: + dir: packages/orchestrator/pkg/sandbox/build/mocks + filename: mockdiff.go + pkgname: buildmocks + github.com/e2b-dev/infra/packages/api/internal/handlers: interfaces: featureFlagsClient: diff --git a/packages/orchestrator/Makefile b/packages/orchestrator/Makefile index 0b691fac54..282dd694d4 100644 --- a/packages/orchestrator/Makefile +++ b/packages/orchestrator/Makefile @@ -135,6 +135,10 @@ build-template: fetch-busybox GCP_PROJECT_ID=$(GCP_PROJECT_ID) \ GCP_DOCKER_REPOSITORY_NAME=$(GCP_DOCKER_REPOSITORY_NAME) \ GCP_REGION=$(GCP_REGION) \ + COMPRESS_ENABLED=$(COMPRESS_ENABLED) \ + COMPRESS_TYPE=$(COMPRESS_TYPE) \ + COMPRESS_LEVEL=$(COMPRESS_LEVEL) \ + COMPRESS_FRAME_ENCODE_WORKERS=$(COMPRESS_FRAME_ENCODE_WORKERS) \ ENVIRONMENT=local \ go run cmd/create-build/main.go \ -template $(TEMPLATE_ID) \ diff --git a/packages/orchestrator/benchmarks/benchmark_test.go b/packages/orchestrator/benchmarks/benchmark_test.go index 68178b38e5..64f9427f68 100644 --- a/packages/orchestrator/benchmarks/benchmark_test.go 
+++ b/packages/orchestrator/benchmarks/benchmark_test.go @@ -275,6 +275,7 @@ func BenchmarkBaseImageLaunch(b *testing.B) { sandboxes, templateCache, buildMetrics, + nil, ) buildPath := filepath.Join(os.Getenv("LOCAL_TEMPLATE_STORAGE_BASE_PATH"), buildID, "rootfs.ext4") diff --git a/packages/orchestrator/benchmarks/concurrent_benchmark_test.go b/packages/orchestrator/benchmarks/concurrent_benchmark_test.go index 0143e7462e..594ab92774 100644 --- a/packages/orchestrator/benchmarks/concurrent_benchmark_test.go +++ b/packages/orchestrator/benchmarks/concurrent_benchmark_test.go @@ -324,6 +324,7 @@ func BenchmarkConcurrentResume(b *testing.B) { config.BuilderConfig, l, featureFlags, sandboxFactory, persistenceTemplate, persistenceBuild, artifactRegistry, dockerhubRepository, sandboxProxy, sandboxes, templateCache, buildMetrics, + nil, ) // build template if not cached diff --git a/packages/orchestrator/chunks.proto b/packages/orchestrator/chunks.proto index 89993d17f2..74e1ed2700 100644 --- a/packages/orchestrator/chunks.proto +++ b/packages/orchestrator/chunks.proto @@ -2,17 +2,19 @@ syntax = "proto3"; option go_package = "https://github.com/e2b-dev/infra/orchestrator"; -// ChunkService allows orchestrators to serve snapshot files directly from -// their local cache to peer orchestrators, bypassing GCS during hot resumes. +// ChunkService allows orchestrators to serve snapshot files directly from their +// local cache to peer orchestrators, bypassing remote storage during hot +// resumes. // PeerAvailability carries the routing decision included in every response. // When neither flag is set, the file is available in the peer's local cache. message PeerAvailability { - // not_available is true when the file is not in the local cache. - // The caller should fall back to GCS. + // not_available is true when the file is not in the local cache. The caller + // should fall back to remote storage. 
bool not_available = 1; - // use_storage is true when the GCS upload has completed and the caller - // should switch to reading from GCS/NFS directly instead of this peer. + // use_storage is true when the remote storage upload has completed and the + // caller should switch to reading from remote storage directly instead of + // this peer. bool use_storage = 2; } diff --git a/packages/orchestrator/cmd/create-build/main.go b/packages/orchestrator/cmd/create-build/main.go index dfbcca6779..6a8690b4f3 100644 --- a/packages/orchestrator/cmd/create-build/main.go +++ b/packages/orchestrator/cmd/create-build/main.go @@ -320,10 +320,18 @@ func doBuild( buildMetrics, _ := metrics.NewBuildMetrics(noop.MeterProvider{}) sandboxFactory := sandbox.NewFactory(c.BuilderConfig, networkPool, devicePool, featureFlags, hoststats.NewNoopDelivery(), cgroup.NewNoopManager(), network.NewNoopEgressProxy(), sandboxes) + // Layered V4 builds need the upload coordinator so child layers wait on + // their parents' header finalization. Redis is nil (CLI is single-host — + // no cross-orch signaling needed); local same-orch coordination via + // futures is what matters here. 
+ uploads := sandbox.NewUploads(templateCache, persistenceTemplate, nil) + defer uploads.Stop() + builder := build.NewBuilder( builderConfig, l, featureFlags, sandboxFactory, persistenceTemplate, persistenceBuild, artifactRegistry, dockerhubRepo, sandboxProxy, sandboxes, templateCache, buildMetrics, + uploads, ) l = l.With(zap.String("envID", templateID)).With(zap.String("buildID", buildID)) diff --git a/packages/orchestrator/cmd/resume-build/main.go b/packages/orchestrator/cmd/resume-build/main.go index 51c40bddec..41ecaa4311 100644 --- a/packages/orchestrator/cmd/resume-build/main.go +++ b/packages/orchestrator/cmd/resume-build/main.go @@ -656,21 +656,23 @@ func (r *runner) pauseOnce(ctx context.Context, opts pauseOptions, verbose bool) // Only upload when not in benchmark mode (verbose = true means single run) if verbose { - paths := storage.Paths{BuildID: opts.newBuildID} if opts.isRemoteStorage { fmt.Println("šŸ“¤ Uploading snapshot...") - if err := snapshot.Upload(ctx, r.storage, paths, nil); err != nil { - return timings, fmt.Errorf("failed to upload snapshot: %w", err) - } - fmt.Println("āœ… Snapshot uploaded successfully") } else { fmt.Println("šŸ’¾ Saving snapshot to local storage...") - if err := snapshot.Upload(ctx, r.storage, paths, nil); err != nil { - return timings, fmt.Errorf("failed to save snapshot: %w", err) - } - fmt.Println("āœ… Snapshot saved successfully") } + upload, err := sandbox.NewUpload(ctx, nil, snapshot, r.storage, storage.CompressConfig{}, nil, "", nil) + if err != nil { + return timings, fmt.Errorf("failed to prepare upload: %w", err) + } + + if err := upload.Run(ctx); err != nil { + return timings, fmt.Errorf("failed to upload snapshot: %w", err) + } + + fmt.Println("āœ… Snapshot uploaded successfully") + fmt.Printf("\nāœ… Build finished: %s\n", opts.newBuildID) printArtifactSizes(opts.storagePath, opts.newBuildID) diff --git a/packages/orchestrator/cmd/smoketest/smoke_test.go b/packages/orchestrator/cmd/smoketest/smoke_test.go 
index 5134c6379b..dc20cd89c3 100644 --- a/packages/orchestrator/cmd/smoketest/smoke_test.go +++ b/packages/orchestrator/cmd/smoketest/smoke_test.go @@ -237,6 +237,7 @@ func newTestInfra(t *testing.T, ctx context.Context) *testInfra { builderConfig, l, flags, factory, persistenceTemplate, persistenceBuild, artifactRegistry, dockerhubRepo, sandboxProxy, sandboxes, templateCache, buildMetrics, + nil, ) return ti diff --git a/packages/orchestrator/pkg/factories/run.go b/packages/orchestrator/pkg/factories/run.go index 7a625bdb4f..65397103de 100644 --- a/packages/orchestrator/pkg/factories/run.go +++ b/packages/orchestrator/pkg/factories/run.go @@ -548,6 +548,13 @@ func run(config cfg.Config, opts Options) (success bool) { builder := chrooted.NewBuilder(config) volumeService := volumes.New(config, builder) + uploads := sandbox.NewUploads(templateCache, persistence, redisClient) + closers = append(closers, closer{"pending uploads", func(context.Context) error { + uploads.Stop() + + return nil + }}) + orchestratorService, err := server.New(server.ServiceConfig{ Config: config, SandboxFactory: sandboxFactory, @@ -561,10 +568,14 @@ func run(config cfg.Config, opts Options) (success bool) { FeatureFlags: featureFlags, SbxEventsService: events.NewEventsService(sbxEventsDeliveryTargets), PeerRegistry: peerRegistry, + Uploads: uploads, }) if err != nil { logger.L().Fatal(ctx, "failed to create orchestrator server", zap.Error(err)) } + closers = append(closers, closer{"orchestrator server", func(context.Context) error { + return orchestratorService.Close() + }}) // template manager sandbox logger tmplSbxLoggerExternal := sbxlogger.NewLogger( @@ -639,6 +650,7 @@ func run(config cfg.Config, opts Options) (success bool) { templateCache, persistence, buildPersistence, + uploads, ) if err != nil { logger.L().Fatal(ctx, "failed to create template manager", zap.Error(err)) diff --git a/packages/orchestrator/pkg/sandbox/block/cache_test.go 
b/packages/orchestrator/pkg/sandbox/block/cache_test.go index bf44d52abe..5b33307b9c 100644 --- a/packages/orchestrator/pkg/sandbox/block/cache_test.go +++ b/packages/orchestrator/pkg/sandbox/block/cache_test.go @@ -299,12 +299,11 @@ func TestCacheExportToDiff_ZeroDirtyBlockMapsToSnapshotBuild(t *testing.T) { diffHeader, err := diffMetadata.ToDiffHeader(t.Context(), originalHeader, snapshotBuildID) require.NoError(t, err) - _, _, mappedBuildID, err := diffHeader.GetShiftedMapping(t.Context(), 0) + mapped, err := diffHeader.GetShiftedMapping(t.Context(), 0) require.NoError(t, err) - require.NotNil(t, mappedBuildID) - require.Equal(t, snapshotBuildID, *mappedBuildID, "zero-filled dirty block should map to the snapshot diff when empty detection is skipped") - require.NotEqual(t, uuid.Nil, *mappedBuildID, "zero-filled dirty block should no longer be represented as an empty mapping") + require.Equal(t, snapshotBuildID, mapped.BuildId, "zero-filled dirty block should map to the snapshot diff when empty detection is skipped") + require.NotEqual(t, uuid.Nil, mapped.BuildId, "zero-filled dirty block should no longer be represented as an empty mapping") } func TestCacheExportToDiff_MixedDirtyBlocksKeepsZeroBlockInDiff(t *testing.T) { @@ -358,17 +357,17 @@ func TestCacheExportToDiff_MixedDirtyBlocksKeepsZeroBlockInDiff(t *testing.T) { diffHeader, err := diffMetadata.ToDiffHeader(t.Context(), originalHeader, snapshotBuildID) require.NoError(t, err) - _, _, firstBlockBuildID, err := diffHeader.GetShiftedMapping(t.Context(), 0) + firstBlock, err := diffHeader.GetShiftedMapping(t.Context(), 0) require.NoError(t, err) - require.Equal(t, snapshotBuildID, *firstBlockBuildID, "zero-filled dirty block should still map to the snapshot diff") + require.Equal(t, snapshotBuildID, firstBlock.BuildId, "zero-filled dirty block should still map to the snapshot diff") - _, _, secondBlockBuildID, err := diffHeader.GetShiftedMapping(t.Context(), blockSize) + secondBlock, err := 
diffHeader.GetShiftedMapping(t.Context(), blockSize) require.NoError(t, err) - require.Equal(t, snapshotBuildID, *secondBlockBuildID) + require.Equal(t, snapshotBuildID, secondBlock.BuildId) - _, _, thirdBlockBuildID, err := diffHeader.GetShiftedMapping(t.Context(), 2*blockSize) + thirdBlock, err := diffHeader.GetShiftedMapping(t.Context(), 2*blockSize) require.NoError(t, err) - require.Equal(t, baseBuildID, *thirdBlockBuildID, "clean blocks should keep the base mapping") + require.Equal(t, baseBuildID, thirdBlock.BuildId, "clean blocks should keep the base mapping") } func TestCacheExportToDiff_NonContiguousDirtyBlocksPreserveRangeOrder(t *testing.T) { diff --git a/packages/orchestrator/pkg/sandbox/block/chunk.go b/packages/orchestrator/pkg/sandbox/block/chunk.go deleted file mode 100644 index ad2017d2aa..0000000000 --- a/packages/orchestrator/pkg/sandbox/block/chunk.go +++ /dev/null @@ -1,301 +0,0 @@ -package block - -import ( - "context" - "errors" - "fmt" - "io" - "strconv" - - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/metric" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" - "golang.org/x/sync/singleflight" - - "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block/metrics" - "github.com/e2b-dev/infra/packages/shared/pkg/featureflags" - "github.com/e2b-dev/infra/packages/shared/pkg/logger" - "github.com/e2b-dev/infra/packages/shared/pkg/storage" - "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" - "github.com/e2b-dev/infra/packages/shared/pkg/telemetry" -) - -const ( - pullType = "pull-type" - pullTypeLocal = "local" - pullTypeRemote = "remote" - - failureReason = "failure-reason" - - failureTypeLocalRead = "local-read" - failureTypeLocalReadAgain = "local-read-again" - failureTypeRemoteRead = "remote-read" - failureTypeCacheFetch = "cache-fetch" -) - -type precomputedAttrs struct { - successFromCache metric.MeasurementOption - successFromRemote metric.MeasurementOption - - failCacheRead 
metric.MeasurementOption - failRemoteFetch metric.MeasurementOption - failLocalReadAgain metric.MeasurementOption - - // RemoteReads timer (runFetch) - remoteSuccess metric.MeasurementOption - remoteFailure metric.MeasurementOption -} - -var chunkerAttrs = precomputedAttrs{ - successFromCache: telemetry.PrecomputeAttrs( - telemetry.Success, - attribute.String(pullType, pullTypeLocal)), - - successFromRemote: telemetry.PrecomputeAttrs( - telemetry.Success, - attribute.String(pullType, pullTypeRemote)), - - failCacheRead: telemetry.PrecomputeAttrs( - telemetry.Failure, - attribute.String(pullType, pullTypeLocal), - attribute.String(failureReason, failureTypeLocalRead)), - - failRemoteFetch: telemetry.PrecomputeAttrs( - telemetry.Failure, - attribute.String(pullType, pullTypeRemote), - attribute.String(failureReason, failureTypeCacheFetch)), - - failLocalReadAgain: telemetry.PrecomputeAttrs( - telemetry.Failure, - attribute.String(pullType, pullTypeLocal), - attribute.String(failureReason, failureTypeLocalReadAgain)), - - remoteSuccess: telemetry.PrecomputeAttrs( - telemetry.Success), - - remoteFailure: telemetry.PrecomputeAttrs( - telemetry.Failure, - attribute.String(failureReason, failureTypeRemoteRead)), -} - -// Chunker is the interface satisfied by both FullFetchChunker and StreamingChunker. -type Chunker interface { - Slice(ctx context.Context, off, length int64) ([]byte, error) - ReadAt(ctx context.Context, b []byte, off int64) (int, error) - WriteTo(ctx context.Context, w io.Writer) (int64, error) - Close() error - FileSize() (int64, error) -} - -// NewChunker creates a Chunker based on the chunker-config feature flag. -// It reads the flag internally so callers don't need to parse flag values. 
-func NewChunker( - ctx context.Context, - featureFlags *featureflags.Client, - size, blockSize int64, - upstream storage.Seekable, - cachePath string, - metrics metrics.Metrics, -) (Chunker, error) { - useStreaming, minReadBatchSizeKB := getChunkerConfig(ctx, featureFlags) - - if useStreaming { - return NewStreamingChunker(size, blockSize, upstream, cachePath, metrics, int64(minReadBatchSizeKB)*1024, featureFlags) - } - - return NewFullFetchChunker(size, blockSize, upstream, cachePath, metrics) -} - -// getChunkerConfig fetches the chunker-config feature flag and returns the parsed values. -func getChunkerConfig(ctx context.Context, ff *featureflags.Client) (useStreaming bool, minReadBatchSizeKB int) { - value := ff.JSONFlag(ctx, featureflags.ChunkerConfigFlag) - - if v := value.GetByKey("useStreaming"); v.IsDefined() { - useStreaming = v.BoolValue() - } - - if v := value.GetByKey("minReadBatchSizeKB"); v.IsDefined() { - minReadBatchSizeKB = v.IntValue() - } - - return useStreaming, minReadBatchSizeKB -} - -type FullFetchChunker struct { - base storage.SeekableReader - cache *Cache - metrics metrics.Metrics - - size int64 - - fetchers singleflight.Group -} - -func NewFullFetchChunker( - size, blockSize int64, - base storage.SeekableReader, - cachePath string, - metrics metrics.Metrics, -) (*FullFetchChunker, error) { - cache, err := NewCache(size, blockSize, cachePath, false) - if err != nil { - return nil, fmt.Errorf("failed to create file cache: %w", err) - } - - chunker := &FullFetchChunker{ - size: size, - base: base, - cache: cache, - metrics: metrics, - } - - return chunker, nil -} - -func (c *FullFetchChunker) ReadAt(ctx context.Context, b []byte, off int64) (int, error) { - slice, err := c.Slice(ctx, off, int64(len(b))) - if err != nil { - return 0, fmt.Errorf("failed to slice cache at %d-%d: %w", off, off+int64(len(b)), err) - } - - return copy(b, slice), nil -} - -func (c *FullFetchChunker) WriteTo(ctx context.Context, w io.Writer) (int64, error) { - for 
i := int64(0); i < c.size; i += storage.MemoryChunkSize { - chunk := make([]byte, storage.MemoryChunkSize) - - n, err := c.ReadAt(ctx, chunk, i) - if err != nil { - return 0, fmt.Errorf("failed to slice cache at %d-%d: %w", i, i+storage.MemoryChunkSize, err) - } - - _, err = w.Write(chunk[:n]) - if err != nil { - return 0, fmt.Errorf("failed to write chunk %d to writer: %w", i, err) - } - } - - return c.size, nil -} - -func (c *FullFetchChunker) Slice(ctx context.Context, off, length int64) ([]byte, error) { - timer := c.metrics.SlicesTimerFactory.Begin() - - b, err := c.cache.Slice(off, length) - if err == nil { - timer.RecordRaw(ctx, length, chunkerAttrs.successFromCache) - - return b, nil - } - - if !errors.As(err, &BytesNotAvailableError{}) { - timer.RecordRaw(ctx, length, chunkerAttrs.failCacheRead) - - return nil, fmt.Errorf("failed read from cache at offset %d: %w", off, err) - } - - chunkErr := c.fetchToCache(ctx, off, length) - if chunkErr != nil { - timer.RecordRaw(ctx, length, chunkerAttrs.failRemoteFetch) - - return nil, fmt.Errorf("failed to ensure data at %d-%d: %w", off, off+length, chunkErr) - } - - b, cacheErr := c.cache.Slice(off, length) - if cacheErr != nil { - timer.RecordRaw(ctx, length, chunkerAttrs.failLocalReadAgain) - - return nil, fmt.Errorf("failed to read from cache after ensuring data at %d-%d: %w", off, off+length, cacheErr) - } - - timer.RecordRaw(ctx, length, chunkerAttrs.successFromRemote) - - return b, nil -} - -// fetchToCache ensures that the data at the given offset and length is available in the cache. -func (c *FullFetchChunker) fetchToCache(ctx context.Context, off, length int64) error { - var eg errgroup.Group - - chunks := header.BlocksOffsets(length, storage.MemoryChunkSize) - - startingChunk := header.BlockIdx(off, storage.MemoryChunkSize) - startingChunkOffset := header.BlockOffset(startingChunk, storage.MemoryChunkSize) - - for _, chunkOff := range chunks { - // Ensure the closure captures the correct block offset. 
- fetchOff := startingChunkOffset + chunkOff - - eg.Go(func() (err error) { - defer func() { - if r := recover(); r != nil { - logger.L().Error(ctx, "recovered from panic in the fetch handler", zap.Any("error", r)) - err = fmt.Errorf("recovered from panic in the fetch handler: %v", r) - } - }() - - key := strconv.FormatInt(fetchOff, 10) - - _, err, _ = c.fetchers.Do(key, func() (any, error) { - // Check early to prevent overwriting data, Slice requires thread safety - if c.cache.isCached(fetchOff, storage.MemoryChunkSize) { - return nil, nil - } - - select { - case <-ctx.Done(): - return nil, fmt.Errorf("error fetching range %d-%d: %w", fetchOff, fetchOff+storage.MemoryChunkSize, ctx.Err()) - default: - } - - // The size of the buffer is adjusted if the last chunk is not a multiple of the block size. - b, releaseCacheCloseLock, err := c.cache.addressBytes(fetchOff, storage.MemoryChunkSize) - if err != nil { - return nil, err - } - - defer releaseCacheCloseLock() - - fetchSW := c.metrics.RemoteReadsTimerFactory.Begin() - - readBytes, err := c.base.ReadAt(ctx, b, fetchOff) - if err != nil { - fetchSW.RecordRaw(ctx, int64(readBytes), chunkerAttrs.remoteFailure) - - return nil, fmt.Errorf("failed to read chunk from base %d: %w", fetchOff, err) - } - - if readBytes != len(b) { - fetchSW.RecordRaw(ctx, int64(readBytes), chunkerAttrs.remoteFailure) - - return nil, fmt.Errorf("failed to read chunk from base %d: expected %d bytes, got %d bytes", fetchOff, len(b), readBytes) - } - - c.cache.setIsCached(fetchOff, int64(readBytes)) - - fetchSW.RecordRaw(ctx, int64(readBytes), chunkerAttrs.remoteSuccess) - - return nil, nil - }) - - return err - }) - } - - err := eg.Wait() - if err != nil { - return fmt.Errorf("failed to ensure data at %d-%d: %w", off, off+length, err) - } - - return nil -} - -func (c *FullFetchChunker) Close() error { - return c.cache.Close() -} - -func (c *FullFetchChunker) FileSize() (int64, error) { - return c.cache.FileSize() -} diff --git 
a/packages/orchestrator/pkg/sandbox/block/chunk_bench_test.go b/packages/orchestrator/pkg/sandbox/block/chunk_bench_test.go deleted file mode 100644 index b49be35bc6..0000000000 --- a/packages/orchestrator/pkg/sandbox/block/chunk_bench_test.go +++ /dev/null @@ -1,58 +0,0 @@ -package block - -import ( - "context" - "path/filepath" - "testing" - - sdkmetric "go.opentelemetry.io/otel/sdk/metric" - - blockmetrics "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block/metrics" -) - -const ( - cbBlockSize int64 = 4096 - cbNumBlocks int64 = 16384 // 64 MiB - cbCacheSize int64 = cbNumBlocks * cbBlockSize -) - -// BenchmarkChunkerSlice_CacheHit benchmarks the full FullFetchChunker.Slice -// hot path on a cache hit: bitmap check + mmap slice return + OTEL -// timer.Success with attribute construction. -func BenchmarkChunkerSlice_CacheHit(b *testing.B) { - provider := sdkmetric.NewMeterProvider() - b.Cleanup(func() { provider.Shutdown(context.Background()) }) - - m, err := blockmetrics.NewMetrics(provider) - if err != nil { - b.Fatal(err) - } - - chunker, err := NewFullFetchChunker( - cbCacheSize, cbBlockSize, - nil, // base is never called on cache hit - filepath.Join(b.TempDir(), "cache"), - m, - ) - if err != nil { - b.Fatal(err) - } - b.Cleanup(func() { chunker.Close() }) - - // Pre-populate the cache so every Slice hits. 
- chunker.cache.setIsCached(0, cbCacheSize) - - ctx := context.Background() - - b.ResetTimer() - for i := range b.N { - off := int64(i%int(cbNumBlocks)) * cbBlockSize - s, sliceErr := chunker.Slice(ctx, off, cbBlockSize) - if sliceErr != nil { - b.Fatal(sliceErr) - } - if len(s) == 0 { - b.Fatal("empty slice") - } - } -} diff --git a/packages/orchestrator/pkg/sandbox/block/chunk_test.go b/packages/orchestrator/pkg/sandbox/block/chunk_test.go deleted file mode 100644 index c9350a80f5..0000000000 --- a/packages/orchestrator/pkg/sandbox/block/chunk_test.go +++ /dev/null @@ -1,165 +0,0 @@ -package block - -import ( - "context" - "errors" - "fmt" - "sync/atomic" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/sync/errgroup" - - "github.com/e2b-dev/infra/packages/shared/pkg/storage" -) - -// failingUpstream returns an error on ReadAt for specific offsets. -type failingUpstream struct { - data []byte - failCount atomic.Int32 // incremented on each failed ReadAt - failErr error -} - -func (u *failingUpstream) ReadAt(_ context.Context, buffer []byte, off int64) (int, error) { - if u.failErr != nil { - u.failCount.Add(1) - - return 0, u.failErr - } - - end := min(off+int64(len(buffer)), int64(len(u.data))) - n := copy(buffer, u.data[off:end]) - - return n, nil -} - -func (u *failingUpstream) Size(_ context.Context) (int64, error) { - return int64(len(u.data)), nil -} - -func TestFullFetchChunker_BasicSlice(t *testing.T) { - t.Parallel() - - data := makeTestData(t, storage.MemoryChunkSize) - upstream := &fastUpstream{data: data, blockSize: testBlockSize} - - chunker, err := NewFullFetchChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - ) - require.NoError(t, err) - defer chunker.Close() - - slice, err := chunker.Slice(t.Context(), 0, testBlockSize) - require.NoError(t, err) - assert.Equal(t, data[:testBlockSize], slice) -} - -func 
TestFullFetchChunker_RetryAfterError(t *testing.T) { - t.Parallel() - - data := makeTestData(t, storage.MemoryChunkSize) - - upstream := &failingUpstream{ - data: data, - failErr: errors.New("connection pool exhausted"), - } - - chunker, err := NewFullFetchChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - ) - require.NoError(t, err) - defer chunker.Close() - - // First call fails - _, err = chunker.Slice(t.Context(), 0, testBlockSize) - require.Error(t, err) - - firstFailCount := upstream.failCount.Load() - require.Positive(t, firstFailCount) - - // Clear the error — simulate saturation passing - upstream.failErr = nil - - // Retry should succeed — singleflight does not cache errors - slice, err := chunker.Slice(t.Context(), 0, testBlockSize) - require.NoError(t, err) - assert.Equal(t, data[:testBlockSize], slice) -} - -func TestFullFetchChunker_ConcurrentSameChunk(t *testing.T) { - t.Parallel() - - data := makeTestData(t, storage.MemoryChunkSize) - readCount := atomic.Int64{} - - upstream := &countingUpstream{ - inner: &fastUpstream{data: data, blockSize: testBlockSize}, - readCount: &readCount, - } - - chunker, err := NewFullFetchChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - ) - require.NoError(t, err) - defer chunker.Close() - - numGoroutines := 10 - results := make([][]byte, numGoroutines) - - var eg errgroup.Group - - for i := range numGoroutines { - eg.Go(func() error { - slice, err := chunker.Slice(t.Context(), 0, testBlockSize) - if err != nil { - return fmt.Errorf("goroutine %d failed: %w", i, err) - } - - results[i] = make([]byte, len(slice)) - copy(results[i], slice) - - return nil - }) - } - - require.NoError(t, eg.Wait()) - - for i := range numGoroutines { - assert.Equal(t, data[:testBlockSize], results[i], "goroutine %d got wrong data", i) - } -} - -func TestFullFetchChunker_DifferentChunksIndependent(t *testing.T) { - t.Parallel() - - // Two 4MB 
chunks - size := storage.MemoryChunkSize * 2 - data := makeTestData(t, size) - upstream := &fastUpstream{data: data, blockSize: testBlockSize} - - chunker, err := NewFullFetchChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - ) - require.NoError(t, err) - defer chunker.Close() - - // Read from chunk 0 - slice0, err := chunker.Slice(t.Context(), 0, testBlockSize) - require.NoError(t, err) - assert.Equal(t, data[:testBlockSize], slice0) - - // Read from chunk 1 - off1 := int64(storage.MemoryChunkSize) - slice1, err := chunker.Slice(t.Context(), off1, testBlockSize) - require.NoError(t, err) - assert.Equal(t, data[off1:off1+testBlockSize], slice1) -} diff --git a/packages/orchestrator/pkg/sandbox/block/device.go b/packages/orchestrator/pkg/sandbox/block/device.go index 39a1cae845..9f81d58abe 100644 --- a/packages/orchestrator/pkg/sandbox/block/device.go +++ b/packages/orchestrator/pkg/sandbox/block/device.go @@ -8,23 +8,35 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) +// BytesNotAvailableError indicates the requested range is not yet cached. type BytesNotAvailableError struct{} func (BytesNotAvailableError) Error() string { return "The requested bytes are not available on the device" } +type FramedReader interface { + ReadAt(ctx context.Context, p []byte, off int64, ft *storage.FrameTable) (int, error) +} + +type FramedSlicer interface { + Slice(ctx context.Context, off, length int64, ft *storage.FrameTable) ([]byte, error) +} + +// Slicer provides plain block reads (no FrameTable). Used by UFFD/NBD. 
type Slicer interface { Slice(ctx context.Context, off, length int64) ([]byte, error) BlockSize() int64 } type ReadonlyDevice interface { - storage.SeekableReader + ReadAt(ctx context.Context, p []byte, off int64) (int, error) + Size(ctx context.Context) (int64, error) io.Closer Slicer BlockSize() int64 Header() *header.Header + SwapHeader(h *header.Header) } type Device interface { diff --git a/packages/orchestrator/pkg/sandbox/block/empty.go b/packages/orchestrator/pkg/sandbox/block/empty.go index 8574b2a75d..e7e1795af8 100644 --- a/packages/orchestrator/pkg/sandbox/block/empty.go +++ b/packages/orchestrator/pkg/sandbox/block/empty.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sync/atomic" "github.com/google/uuid" @@ -11,7 +12,7 @@ import ( ) type Empty struct { - header *header.Header + header atomic.Pointer[header.Header] } var _ ReadonlyDevice = (*Empty)(nil) @@ -26,9 +27,10 @@ func NewEmpty(size int64, blockSize int64, buildID uuid.UUID) (*Empty, error) { return nil, fmt.Errorf("failed to create header: %w", err) } - return &Empty{ - header: h, - }, nil + e := &Empty{} + e.header.Store(h) + + return e, nil } func (e *Empty) ReadAt(ctx context.Context, p []byte, off int64) (int, error) { @@ -41,11 +43,11 @@ func (e *Empty) ReadAt(ctx context.Context, p []byte, off int64) (int, error) { } func (e *Empty) Size(_ context.Context) (int64, error) { - return int64(e.header.Metadata.Size), nil + return int64(e.Header().Metadata.Size), nil } func (e *Empty) BlockSize() int64 { - return int64(e.header.Metadata.BlockSize) + return int64(e.Header().Metadata.BlockSize) } func (e *Empty) Close() error { @@ -54,7 +56,7 @@ func (e *Empty) Close() error { func (e *Empty) Slice(_ context.Context, off, length int64) ([]byte, error) { end := off + length - size := int64(e.header.Metadata.Size) + size := int64(e.Header().Metadata.Size) if end > size { end = size length = end - off @@ -65,7 +67,11 @@ func (e *Empty) Slice(_ context.Context, off, length int64) ([]byte, 
error) { } func (e *Empty) Header() *header.Header { - return e.header + return e.header.Load() +} + +func (e *Empty) SwapHeader(h *header.Header) { + e.header.Store(h) } func (e *Empty) UpdateSize() error { diff --git a/packages/orchestrator/pkg/sandbox/block/fetch_session.go b/packages/orchestrator/pkg/sandbox/block/fetch_session.go new file mode 100644 index 0000000000..eddcbd5679 --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/block/fetch_session.go @@ -0,0 +1,155 @@ +package block + +import ( + "context" + "fmt" + "sync" + "sync/atomic" +) + +type fetchSession struct { + // chunk is what we are fetching, can be >= 1 block. chunkOff/chunkLen are absolute offsets in U-space. + chunkOff int64 + chunkLen int64 + cache *Cache + + mu sync.Mutex + cond sync.Cond // broadcast on progress; lazily initialized with mu + + fetchErr error + done bool // true once terminated (success or error) + + // bytesReady is the byte count (from chunkOff) up to which all blocks + // are fully written and marked cached. Atomic so registerAndWait can + // do a lock-free fast-path check: bytesReady only increases. + bytesReady atomic.Int64 +} + +// contains reports whether the session covers the byte range [off, off+length). +func (s *fetchSession) contains(off, length int64) bool { + return s.chunkOff <= off && s.chunkOff+s.chunkLen >= off+length +} + +// terminated reports whether the session reached a terminal state. +// Must be called with mu held. +func (s *fetchSession) terminated() bool { + return s.done +} + +func newFetchSession(chunkOff, chunkLen int64, cache *Cache) *fetchSession { + s := &fetchSession{ + chunkOff: chunkOff, + chunkLen: chunkLen, + cache: cache, + } + s.cond.L = &s.mu + + return s +} + +// registerAndWait blocks until the block at blockOff is cached, the session +// terminates, or ctx is cancelled. Each caller requests exactly one block. 
+func (s *fetchSession) registerAndWait(ctx context.Context, blockOff int64) error { + blockSize := s.cache.blockSize + + if blockOff%blockSize != 0 { + return fmt.Errorf("blockOff %d is not aligned to block size %d", blockOff, blockSize) + } + + if blockOff < s.chunkOff || blockOff >= s.chunkOff+s.chunkLen { + return fmt.Errorf("blockOff %d is outside session range [%d, %d)", blockOff, s.chunkOff, s.chunkOff+s.chunkLen) + } + + // endByte is the byte offset (relative to chunkOff) that must be ready + // for our block to be fully written. + relEnd := blockOff + blockSize - s.chunkOff + endByte := min(relEnd, s.chunkLen) + + // Lock-free fast path: bytesReady only increases, so >= endByte + // guarantees data is available. + if s.bytesReady.Load() >= endByte { + return nil + } + + // Set up context cancellation to unblock cond.Wait. + stop := context.AfterFunc(ctx, func() { + s.cond.Broadcast() + }) + defer stop() + + s.mu.Lock() + defer s.mu.Unlock() + + for { + if s.bytesReady.Load() >= endByte { + return nil + } + + // Terminal but block not covered — only happens on error. + // setDone sets bytesReady=chunkLen, so terminated() with bytesReady < endByte + // means fetchErr != nil. Check cache in case a prior session already fetched this block. + if s.terminated() { + // isCached reads an atomic bitset — safe to call under mu. + if s.cache.isCached(blockOff, blockSize) { + return nil + } + + if s.fetchErr == nil { + return fmt.Errorf("fetch session terminated without error but block %d not cached (bytesReady=%d, endByte=%d)", + blockOff/blockSize, s.bytesReady.Load(), endByte) + } + + return fmt.Errorf("fetch failed: %w", s.fetchErr) + } + + if ctx.Err() != nil { + return ctx.Err() + } + + s.cond.Wait() + } +} + +// advance updates progress and wakes all waiters. 
+func (s *fetchSession) advance(bytesReady int64) { + s.mu.Lock() + defer s.mu.Unlock() + + s.bytesReady.Store(bytesReady) + s.cond.Broadcast() +} + +// setDone marks the session as successfully completed and wakes all waiters. +func (s *fetchSession) setDone() { + s.mu.Lock() + s.bytesReady.Store(s.chunkLen) + s.done = true + s.mu.Unlock() + + s.cond.Broadcast() +} + +// fail records the error unconditionally and wakes all waiters. +func (s *fetchSession) fail(err error) { + s.mu.Lock() + s.fetchErr = err + s.done = true + s.mu.Unlock() + + s.cond.Broadcast() +} + +// failIfRunning records the error only if the session has not already +// terminated — used in panic recovery and safety-net defers to avoid +// overriding a successful completion. Always broadcasts to ensure no +// waiter blocks forever. +func (s *fetchSession) failIfRunning(err error) { + s.mu.Lock() + if !s.terminated() { + s.fetchErr = err + s.done = true + } + s.mu.Unlock() + + s.cond.Broadcast() +} diff --git a/packages/orchestrator/pkg/sandbox/block/fetch_session_test.go b/packages/orchestrator/pkg/sandbox/block/fetch_session_test.go new file mode 100644 index 0000000000..ee38d3cf67 --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/block/fetch_session_test.go @@ -0,0 +1,350 @@ +package block + +import ( + "context" + "errors" + "path/filepath" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +const fetchSessionBlockSize int64 = 4096 + +func makeTestCacheForSession(t *testing.T, numBlocks int64) *Cache { + t.Helper() + + size := fetchSessionBlockSize * numBlocks + c, err := NewCache(size, fetchSessionBlockSize, filepath.Join(t.TempDir(), "cache"), false) + require.NoError(t, err) + t.Cleanup(func() { _ = c.Close() }) + + return c +} + +func TestFetchSession_FastPath(t *testing.T) { + t.Parallel() + + const blockSize = fetchSessionBlockSize + + cache := makeTestCacheForSession(t, 4) + s := newFetchSession(0, 4*blockSize, cache) + s.bytesReady.Store(4 
* blockSize) + + // All blocks already covered — must return immediately via atomic fast path. + start := time.Now() + require.NoError(t, s.registerAndWait(context.Background(), 0)) + require.NoError(t, s.registerAndWait(context.Background(), 3*blockSize)) + require.Less(t, time.Since(start), time.Millisecond) +} + +func TestFetchSession_ProgressiveAdvance(t *testing.T) { + t.Parallel() + + const blockSize = fetchSessionBlockSize + const numBlocks = 512 // 2 MB total + + cache := makeTestCacheForSession(t, numBlocks) + s := newFetchSession(0, int64(numBlocks)*blockSize, cache) + + var returned atomic.Int64 + + type result struct { + blockIdx int + err error + } + ch := make(chan result, numBlocks) + + for i := range numBlocks { + go func(idx int) { + err := s.registerAndWait(context.Background(), int64(idx)*blockSize) + returned.Add(1) + ch <- result{idx, err} + }(i) + } + + // bytesReady is 0 — no waiter can have returned yet. + time.Sleep(time.Millisecond) + require.Equal(t, int64(0), returned.Load()) + + for covered := 2; covered <= numBlocks; covered += 2 { + s.advance(int64(covered) * blockSize) + + got := [2]int{} + for j := range 2 { + r := <-ch + require.NoError(t, r.err, "block %d", r.blockIdx) + got[j] = r.blockIdx + } + + if got[0] > got[1] { + got[0], got[1] = got[1], got[0] + } + + require.Equal(t, [2]int{covered - 2, covered - 1}, got, + "advance to %d should unblock exactly blocks %d and %d", covered, covered-2, covered-1) + require.Equal(t, int64(covered), returned.Load()) + } +} + +func TestFetchSession_SetDoneUnblocksAll(t *testing.T) { + t.Parallel() + + const blockSize = fetchSessionBlockSize + const numBlocks = 64 + + cache := makeTestCacheForSession(t, numBlocks) + s := newFetchSession(0, int64(numBlocks)*blockSize, cache) + + var returned atomic.Int64 + ch := make(chan error, numBlocks) + + for i := range numBlocks { + go func(idx int) { + err := s.registerAndWait(context.Background(), int64(idx)*blockSize) + returned.Add(1) + ch <- err + 
}(i) + } + + require.Equal(t, int64(0), returned.Load()) + + s.setDone() + + for range numBlocks { + require.NoError(t, <-ch) + } + + require.Equal(t, int64(numBlocks), returned.Load()) +} + +func TestFetchSession_FailPropagatesError(t *testing.T) { + t.Parallel() + + const blockSize = fetchSessionBlockSize + + cache := makeTestCacheForSession(t, 4) + s := newFetchSession(0, 4*blockSize, cache) + + var returned atomic.Int64 + ch := make(chan error, 1) + + go func() { + err := s.registerAndWait(context.Background(), 0) + returned.Add(1) + ch <- err + }() + + require.Equal(t, int64(0), returned.Load()) + + sentinel := errors.New("storage unavailable") + s.fail(sentinel) + + require.ErrorIs(t, <-ch, sentinel) + require.Equal(t, int64(1), returned.Load()) +} + +func TestFetchSession_ContextCancellation(t *testing.T) { + t.Parallel() + + const blockSize = fetchSessionBlockSize + + cache := makeTestCacheForSession(t, 4) + s := newFetchSession(0, 4*blockSize, cache) + + ctx, cancel := context.WithCancel(context.Background()) + + var returned atomic.Int64 + ch := make(chan error, 1) + + go func() { + err := s.registerAndWait(ctx, 0) + returned.Add(1) + ch <- err + }() + + require.Equal(t, int64(0), returned.Load()) + + cancel() + + require.ErrorIs(t, <-ch, context.Canceled) + require.Equal(t, int64(1), returned.Load()) +} + +func TestFetchSession_TerminatedButCachedByPriorSession(t *testing.T) { + t.Parallel() + + const blockSize = fetchSessionBlockSize + + cache := makeTestCacheForSession(t, 4) + s := newFetchSession(0, 4*blockSize, cache) + + // Mark block 0 as cached externally (simulates a prior fetch session). + cache.setIsCached(0, blockSize) + + // Fail this session — but block 0 is already in the cache. 
+ s.fail(errors.New("some error")) + + err := s.registerAndWait(context.Background(), 0) + require.NoError(t, err) +} + +func TestFetchSession_TerminatedNoErrorBlockNotCached(t *testing.T) { + t.Parallel() + + const blockSize = fetchSessionBlockSize + + cache := makeTestCacheForSession(t, 4) + s := newFetchSession(0, 4*blockSize, cache) + + // Manually put in a terminal state without setting bytesReady = chunkLen. + // (setDone always sets bytesReady, so this is a defensive-code path.) + s.mu.Lock() + s.done = true + s.mu.Unlock() + s.cond.Broadcast() + + err := s.registerAndWait(context.Background(), 2*blockSize) + require.Error(t, err) + require.Contains(t, err.Error(), "terminated without error but block") +} + +func TestFetchSession_FailIfRunning_NoOpAfterSetDone(t *testing.T) { + t.Parallel() + + const blockSize = fetchSessionBlockSize + + cache := makeTestCacheForSession(t, 4) + s := newFetchSession(0, 4*blockSize, cache) + + s.setDone() + s.failIfRunning(errors.New("should be ignored")) + + // setDone set bytesReady = chunkLen, so all blocks are covered. 
+ require.NoError(t, s.registerAndWait(context.Background(), 0)) + require.NoError(t, s.registerAndWait(context.Background(), 3*blockSize)) + require.NoError(t, s.fetchErr) +} + +func TestFetchSession_FailIfRunning_BeforeDone(t *testing.T) { + t.Parallel() + + const blockSize = fetchSessionBlockSize + + cache := makeTestCacheForSession(t, 4) + s := newFetchSession(0, 4*blockSize, cache) + + sentinel := errors.New("panic recovery") + s.failIfRunning(sentinel) + + err := s.registerAndWait(context.Background(), 0) + require.ErrorIs(t, err, sentinel) +} + +func TestFetchSession_NonZeroChunkOffset(t *testing.T) { + t.Parallel() + + const blockSize = fetchSessionBlockSize + + cache := makeTestCacheForSession(t, 8) + + chunkOff := 2 * blockSize // chunk starts at block 2 + chunkLen := blockSize + 1000 // 1.24 blocks — not aligned, exercises min clamp + lastBlockOff := chunkOff + blockSize // block 3 + s := newFetchSession(chunkOff, chunkLen, cache) + + var returned atomic.Int64 + ch := make(chan error, 2) + + go func() { + err := s.registerAndWait(context.Background(), chunkOff) // block 2 + returned.Add(1) + ch <- err + }() + + go func() { + // Block 3 extends past chunkOff+chunkLen, so endByte is clamped to chunkLen. + // relEnd = lastBlockOff + blockSize - chunkOff = 2*blockSize = 8192 + // endByte = min(8192, 5096) = 5096 (chunkLen) + err := s.registerAndWait(context.Background(), lastBlockOff) + returned.Add(1) + ch <- err + }() + + require.Equal(t, int64(0), returned.Load()) + + // Advance covers block 2 only. + // relEnd for block 2 = chunkOff + blockSize - chunkOff = blockSize = 4096 + // endByte = min(4096, 5096) = 4096 + s.advance(blockSize) + + require.NoError(t, <-ch) + require.Equal(t, int64(1), returned.Load()) + + // Advance to chunkLen — enough for the partial last block. 
+ s.advance(chunkLen) + + require.NoError(t, <-ch) + require.Equal(t, int64(2), returned.Load()) +} + +func TestFetchSession_ConcurrentWaitersAndCancel(t *testing.T) { + t.Parallel() + + const blockSize = fetchSessionBlockSize + + cache := makeTestCacheForSession(t, 8) + s := newFetchSession(0, 8*blockSize, cache) + + ctx, cancel := context.WithCancel(context.Background()) + + var returned atomic.Int64 + + var wg sync.WaitGroup + + // 4 waiters with cancellable context, 4 with background context. + cancelErrs := make([]error, 4) + bgErrs := make([]error, 4) + + for i := range 4 { + wg.Add(2) + + go func(idx int) { + defer wg.Done() + cancelErrs[idx] = s.registerAndWait(ctx, int64(idx)*blockSize) + returned.Add(1) + }(i) + + go func(idx int) { + defer wg.Done() + bgErrs[idx] = s.registerAndWait(context.Background(), int64(idx+4)*blockSize) + returned.Add(1) + }(i) + } + + require.Equal(t, int64(0), returned.Load()) + + // Cancel the first group. + cancel() + + // Complete the session for the second group. + s.setDone() + + wg.Wait() + + require.Equal(t, int64(8), returned.Load()) + + for i, err := range cancelErrs { + // May have been cancelled OR completed — both are OK since setDone races with cancel. 
+ if err != nil { + require.ErrorIs(t, err, context.Canceled, "cancel waiter %d", i) + } + } + + for i, err := range bgErrs { + require.NoError(t, err, "bg waiter %d", i) + } +} diff --git a/packages/orchestrator/pkg/sandbox/block/local.go b/packages/orchestrator/pkg/sandbox/block/local.go index 013c4b8940..8400c3cc0f 100644 --- a/packages/orchestrator/pkg/sandbox/block/local.go +++ b/packages/orchestrator/pkg/sandbox/block/local.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "os" + "sync/atomic" "github.com/google/uuid" @@ -15,7 +16,7 @@ type Local struct { f *os.File path string - header *header.Header + header atomic.Pointer[header.Header] } var _ ReadonlyDevice = (*Local)(nil) @@ -44,11 +45,10 @@ func NewLocal(path string, blockSize int64, buildID uuid.UUID) (*Local, error) { return nil, fmt.Errorf("failed to create header: %w", err) } - return &Local{ - f: f, - path: path, - header: h, - }, nil + d := &Local{f: f, path: path} + d.header.Store(h) + + return d, nil } func (d *Local) Path() string { @@ -65,11 +65,11 @@ func (d *Local) ReadAt(ctx context.Context, p []byte, off int64) (int, error) { } func (d *Local) Size(_ context.Context) (int64, error) { - return int64(d.header.Metadata.Size), nil + return int64(d.Header().Metadata.Size), nil } func (d *Local) BlockSize() int64 { - return int64(d.header.Metadata.BlockSize) + return int64(d.Header().Metadata.BlockSize) } func (d *Local) Close() (e error) { @@ -83,7 +83,7 @@ func (d *Local) Close() (e error) { func (d *Local) Slice(_ context.Context, off, length int64) ([]byte, error) { end := off + length - size := int64(d.header.Metadata.Size) + size := int64(d.Header().Metadata.Size) if end > size { end = size length = end - off @@ -99,7 +99,11 @@ func (d *Local) Slice(_ context.Context, off, length int64) ([]byte, error) { } func (d *Local) Header() *header.Header { - return d.header + return d.header.Load() +} + +func (d *Local) SwapHeader(h *header.Header) { + d.header.Store(h) } func (d *Local) 
UpdateHeaderSize() error { @@ -108,7 +112,16 @@ func (d *Local) UpdateHeaderSize() error { return fmt.Errorf("failed to get file info: %w", err) } - d.header.Metadata.Size = uint64(info.Size()) + h := d.Header() + metaCopy := *h.Metadata + metaCopy.Size = uint64(info.Size()) + + updated := &header.Header{ + Metadata: &metaCopy, + Builds: h.Builds, + Mapping: h.Mapping, + } + d.SwapHeader(updated) return nil } diff --git a/packages/orchestrator/pkg/sandbox/block/mocks/mockreadonlydevice.go b/packages/orchestrator/pkg/sandbox/block/mocks/mockreadonlydevice.go index d0c464b661..8336b299c1 100644 --- a/packages/orchestrator/pkg/sandbox/block/mocks/mockreadonlydevice.go +++ b/packages/orchestrator/pkg/sandbox/block/mocks/mockreadonlydevice.go @@ -173,8 +173,8 @@ func (_c *MockReadonlyDevice_Header_Call) RunAndReturn(run func() *header.Header } // ReadAt provides a mock function for the type MockReadonlyDevice -func (_mock *MockReadonlyDevice) ReadAt(ctx context.Context, buffer []byte, off int64) (int, error) { - ret := _mock.Called(ctx, buffer, off) +func (_mock *MockReadonlyDevice) ReadAt(ctx context.Context, p []byte, off int64) (int, error) { + ret := _mock.Called(ctx, p, off) if len(ret) == 0 { panic("no return value specified for ReadAt") @@ -183,15 +183,15 @@ func (_mock *MockReadonlyDevice) ReadAt(ctx context.Context, buffer []byte, off var r0 int var r1 error if returnFunc, ok := ret.Get(0).(func(context.Context, []byte, int64) (int, error)); ok { - return returnFunc(ctx, buffer, off) + return returnFunc(ctx, p, off) } if returnFunc, ok := ret.Get(0).(func(context.Context, []byte, int64) int); ok { - r0 = returnFunc(ctx, buffer, off) + r0 = returnFunc(ctx, p, off) } else { r0 = ret.Get(0).(int) } if returnFunc, ok := ret.Get(1).(func(context.Context, []byte, int64) error); ok { - r1 = returnFunc(ctx, buffer, off) + r1 = returnFunc(ctx, p, off) } else { r1 = ret.Error(1) } @@ -205,13 +205,13 @@ type MockReadonlyDevice_ReadAt_Call struct { // ReadAt is a helper 
method to define mock.On call // - ctx context.Context -// - buffer []byte +// - p []byte // - off int64 -func (_e *MockReadonlyDevice_Expecter) ReadAt(ctx interface{}, buffer interface{}, off interface{}) *MockReadonlyDevice_ReadAt_Call { - return &MockReadonlyDevice_ReadAt_Call{Call: _e.mock.On("ReadAt", ctx, buffer, off)} +func (_e *MockReadonlyDevice_Expecter) ReadAt(ctx interface{}, p interface{}, off interface{}) *MockReadonlyDevice_ReadAt_Call { + return &MockReadonlyDevice_ReadAt_Call{Call: _e.mock.On("ReadAt", ctx, p, off)} } -func (_c *MockReadonlyDevice_ReadAt_Call) Run(run func(ctx context.Context, buffer []byte, off int64)) *MockReadonlyDevice_ReadAt_Call { +func (_c *MockReadonlyDevice_ReadAt_Call) Run(run func(ctx context.Context, p []byte, off int64)) *MockReadonlyDevice_ReadAt_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -239,7 +239,7 @@ func (_c *MockReadonlyDevice_ReadAt_Call) Return(n int, err error) *MockReadonly return _c } -func (_c *MockReadonlyDevice_ReadAt_Call) RunAndReturn(run func(ctx context.Context, buffer []byte, off int64) (int, error)) *MockReadonlyDevice_ReadAt_Call { +func (_c *MockReadonlyDevice_ReadAt_Call) RunAndReturn(run func(ctx context.Context, p []byte, off int64) (int, error)) *MockReadonlyDevice_ReadAt_Call { _c.Call.Return(run) return _c } @@ -377,3 +377,43 @@ func (_c *MockReadonlyDevice_Slice_Call) RunAndReturn(run func(ctx context.Conte _c.Call.Return(run) return _c } + +// SwapHeader provides a mock function for the type MockReadonlyDevice +func (_mock *MockReadonlyDevice) SwapHeader(h *header.Header) { + _mock.Called(h) + return +} + +// MockReadonlyDevice_SwapHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SwapHeader' +type MockReadonlyDevice_SwapHeader_Call struct { + *mock.Call +} + +// SwapHeader is a helper method to define mock.On call +// - h *header.Header +func (_e *MockReadonlyDevice_Expecter) 
SwapHeader(h interface{}) *MockReadonlyDevice_SwapHeader_Call { + return &MockReadonlyDevice_SwapHeader_Call{Call: _e.mock.On("SwapHeader", h)} +} + +func (_c *MockReadonlyDevice_SwapHeader_Call) Run(run func(h *header.Header)) *MockReadonlyDevice_SwapHeader_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 *header.Header + if args[0] != nil { + arg0 = args[0].(*header.Header) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockReadonlyDevice_SwapHeader_Call) Return() *MockReadonlyDevice_SwapHeader_Call { + _c.Call.Return() + return _c +} + +func (_c *MockReadonlyDevice_SwapHeader_Call) RunAndReturn(run func(h *header.Header)) *MockReadonlyDevice_SwapHeader_Call { + _c.Run(run) + return _c +} diff --git a/packages/orchestrator/pkg/sandbox/block/overlay.go b/packages/orchestrator/pkg/sandbox/block/overlay.go index 499aa23ada..0e9987937f 100644 --- a/packages/orchestrator/pkg/sandbox/block/overlay.go +++ b/packages/orchestrator/pkg/sandbox/block/overlay.go @@ -89,3 +89,7 @@ func (o *Overlay) Close() error { func (o *Overlay) Header() *header.Header { return o.device.Header() } + +func (o *Overlay) SwapHeader(h *header.Header) { + o.device.SwapHeader(h) +} diff --git a/packages/orchestrator/pkg/sandbox/block/streaming_chunk.go b/packages/orchestrator/pkg/sandbox/block/streaming_chunk.go index ca648397f3..94e9cf8ef4 100644 --- a/packages/orchestrator/pkg/sandbox/block/streaming_chunk.go +++ b/packages/orchestrator/pkg/sandbox/block/streaming_chunk.go @@ -1,200 +1,70 @@ package block import ( - "cmp" "context" "errors" "fmt" "io" - "slices" "sync" - "sync/atomic" "time" - "golang.org/x/sync/errgroup" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block/metrics" "github.com/e2b-dev/infra/packages/shared/pkg/featureflags" "github.com/e2b-dev/infra/packages/shared/pkg/storage" - "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" + 
"github.com/e2b-dev/infra/packages/shared/pkg/telemetry" ) const ( // defaultFetchTimeout is the maximum time a single 4MB chunk fetch may take. // Acts as a safety net: if the upstream hangs, the goroutine won't live forever. defaultFetchTimeout = 60 * time.Second - - // defaultMinReadBatchSize is the floor for the read batch size when blockSize - // is very small (e.g. 4KB rootfs). The actual batch is max(blockSize, minReadBatchSize). - defaultMinReadBatchSize = 16 * 1024 // 16 KB ) -type rangeWaiter struct { - // endByte is the byte offset (relative to chunkOff) at which this waiter's - // entire requested range is cached. Equal to the end of the last block - // overlapping the requested range. Always a multiple of blockSize. - endByte int64 - ch chan error // buffered cap 1 -} - -type fetchSession struct { - mu sync.Mutex - chunkOff int64 - chunkLen int64 - cache *Cache - waiters []*rangeWaiter // sorted by endByte ascending - fetchErr error - - // bytesReady is the byte count (from chunkOff) up to which all blocks are - // fully written to mmap and marked cached. Always a multiple of blockSize - // during progressive reads. Used to cheaply determine which sorted waiters - // are satisfied without calling isCached. - // - // Atomic so registerAndWait can do a lock-free fast-path check: - // bytesReady only increases, so a Load() >= endByte guarantees data - // availability without taking the mutex. - bytesReady atomic.Int64 -} - -// terminated reports whether the fetch session has reached a terminal state -// (done or errored). Must be called with s.mu held. -func (s *fetchSession) terminated() bool { - return s.fetchErr != nil || s.bytesReady.Load() == s.chunkLen -} - -// registerAndWait adds a waiter for the given range and blocks until the range -// is cached or the context is cancelled. Returns nil if the range was already -// cached before registering. 
-func (s *fetchSession) registerAndWait(ctx context.Context, off, length int64) error { - blockSize := s.cache.BlockSize() - lastBlockIdx := (off + length - 1 - s.chunkOff) / blockSize - endByte := (lastBlockIdx + 1) * blockSize - - // Lock-free fast path: bytesReady only increases, so >= endByte - // guarantees data is available without taking the lock. - if s.bytesReady.Load() >= endByte { - return nil - } - - s.mu.Lock() - - // Re-check under lock. - if endByte <= s.bytesReady.Load() { - s.mu.Unlock() - - return nil - } - - // Terminal but range not covered — only happens on error - // (Done sets bytesReady=chunkLen). Check cache for prior session data. - if s.terminated() { - fetchErr := s.fetchErr - s.mu.Unlock() - if s.cache.isCached(off, length) { - return nil - } - - if fetchErr != nil { - return fmt.Errorf("fetch failed: %w", fetchErr) - } - - return fmt.Errorf("fetch completed but range %d-%d not cached", off, off+length) - } - - // Fetch in progress — register waiter. - w := &rangeWaiter{endByte: endByte, ch: make(chan error, 1)} - idx, _ := slices.BinarySearchFunc(s.waiters, endByte, func(w *rangeWaiter, target int64) int { - return cmp.Compare(w.endByte, target) - }) - s.waiters = slices.Insert(s.waiters, idx, w) - s.mu.Unlock() - - select { - case err := <-w.ch: - return err - case <-ctx.Done(): - return ctx.Err() - } -} - -// notifyWaiters notifies waiters whose ranges are satisfied. -// -// Because waiters are sorted by endByte and the fetch fills the chunk -// sequentially, we only need to walk from the front until we hit a waiter -// whose endByte exceeds bytesReady — all subsequent waiters are unsatisfied. -// -// In terminal states (done/errored) all remaining waiters are notified. -// Must be called with s.mu held. -func (s *fetchSession) notifyWaiters(sendErr error) { - ready := s.bytesReady.Load() - - // Terminal: notify every remaining waiter. 
- if s.terminated() { - for _, w := range s.waiters { - if sendErr != nil && w.endByte > ready { - w.ch <- sendErr - } - close(w.ch) - } - s.waiters = nil - - return - } - - // Progress: pop satisfied waiters from the sorted front. - i := 0 - for i < len(s.waiters) && s.waiters[i].endByte <= ready { - close(s.waiters[i].ch) - i++ - } - s.waiters = s.waiters[i:] -} - -type StreamingChunker struct { - upstream storage.StreamingReader - cache *Cache - metrics metrics.Metrics - fetchTimeout time.Duration - featureFlags *featureflags.Client - minReadBatchSize int64 +type Chunker struct { + upstream storage.StreamingReader + cache *Cache + metrics metrics.Metrics + fetchTimeout time.Duration + featureFlags *featureflags.Client size int64 - fetchMu sync.Mutex - fetchMap map[int64]*fetchSession + fetchMu sync.Mutex + fetchSessions []*fetchSession } -func NewStreamingChunker( +var ( + _ FramedReader = (*Chunker)(nil) + _ FramedSlicer = (*Chunker)(nil) +) + +func NewChunker( + ff *featureflags.Client, size, blockSize int64, upstream storage.StreamingReader, cachePath string, metrics metrics.Metrics, - minReadBatchSize int64, - ff *featureflags.Client, -) (*StreamingChunker, error) { +) (*Chunker, error) { cache, err := NewCache(size, blockSize, cachePath, false) if err != nil { return nil, fmt.Errorf("failed to create file cache: %w", err) } - if minReadBatchSize <= 0 { - minReadBatchSize = defaultMinReadBatchSize - } - - return &StreamingChunker{ - size: size, - upstream: upstream, - cache: cache, - metrics: metrics, - featureFlags: ff, - fetchTimeout: defaultFetchTimeout, - minReadBatchSize: minReadBatchSize, - fetchMap: make(map[int64]*fetchSession), + return &Chunker{ + size: size, + upstream: upstream, + cache: cache, + metrics: metrics, + featureFlags: ff, + fetchTimeout: defaultFetchTimeout, }, nil } -func (c *StreamingChunker) ReadAt(ctx context.Context, b []byte, off int64) (int, error) { - slice, err := c.Slice(ctx, off, int64(len(b))) +func (c *Chunker) ReadAt(ctx 
context.Context, b []byte, off int64, ft *storage.FrameTable) (int, error) { + slice, err := c.Slice(ctx, off, int64(len(b)), ft) if err != nil { return 0, fmt.Errorf("failed to slice cache at %d-%d: %w", off, off+int64(len(b)), err) } @@ -202,70 +72,29 @@ func (c *StreamingChunker) ReadAt(ctx context.Context, b []byte, off int64) (int return copy(b, slice), nil } -func (c *StreamingChunker) WriteTo(ctx context.Context, w io.Writer) (int64, error) { - chunk := make([]byte, storage.MemoryChunkSize) - - for i := int64(0); i < c.size; i += storage.MemoryChunkSize { - n, err := c.ReadAt(ctx, chunk, i) - if err != nil { - return 0, fmt.Errorf("failed to slice cache at %d-%d: %w", i, i+storage.MemoryChunkSize, err) - } - - _, err = w.Write(chunk[:n]) - if err != nil { - return 0, fmt.Errorf("failed to write chunk %d to writer: %w", i, err) - } +func (c *Chunker) Slice(ctx context.Context, off, length int64, ft *storage.FrameTable) ([]byte, error) { + attrs := chunkerAttrs + if ft.IsCompressed() { + attrs = chunkerAttrsCompressed } - - return c.size, nil -} - -func (c *StreamingChunker) Slice(ctx context.Context, off, length int64) ([]byte, error) { timer := c.metrics.SlicesTimerFactory.Begin() // Fast path: already cached b, err := c.cache.Slice(off, length) if err == nil { - timer.RecordRaw(ctx, length, chunkerAttrs.successFromCache) + timer.RecordRaw(ctx, length, attrs.successFromCache) return b, nil } if !errors.As(err, &BytesNotAvailableError{}) { - timer.RecordRaw(ctx, length, chunkerAttrs.failCacheRead) + timer.RecordRaw(ctx, length, attrs.failCacheRead) return nil, fmt.Errorf("failed read from cache at offset %d: %w", off, err) } - // Compute which 4MB chunks overlap with the requested range - firstChunkOff := header.BlockOffset(header.BlockIdx(off, storage.MemoryChunkSize), storage.MemoryChunkSize) - lastChunkOff := header.BlockOffset(header.BlockIdx(off+length-1, storage.MemoryChunkSize), storage.MemoryChunkSize) - - var eg errgroup.Group - - for fetchOff := 
firstChunkOff; fetchOff <= lastChunkOff; fetchOff += storage.MemoryChunkSize { - eg.Go(func() error { - // Clip request to this chunk's boundaries - chunkEnd := fetchOff + storage.MemoryChunkSize - clippedOff := max(off, fetchOff) - clippedEnd := min(off+length, chunkEnd, c.size) - clippedLen := clippedEnd - clippedOff - - if clippedLen <= 0 { - return nil - } - - session, justGotCached := c.getOrCreateSession(ctx, fetchOff) - if justGotCached { - return nil - } - - return session.registerAndWait(ctx, clippedOff, clippedLen) - }) - } - - if err := eg.Wait(); err != nil { - timer.RecordRaw(ctx, length, chunkerAttrs.failRemoteFetch) + if err := c.fetch(ctx, off, ft); err != nil { + timer.RecordRaw(ctx, length, attrs.failRemoteFetch) return nil, fmt.Errorf("failed to ensure data at %d-%d: %w", off, off+length, err) } @@ -273,193 +102,288 @@ func (c *StreamingChunker) Slice(ctx context.Context, off, length int64) ([]byte // sliceDirect skips isCached — the waiter already confirmed the data is in the mmap. b, cacheErr := c.cache.sliceDirect(off, length) if cacheErr != nil { - timer.RecordRaw(ctx, length, chunkerAttrs.failLocalReadAgain) + timer.RecordRaw(ctx, length, attrs.failLocalReadAgain) return nil, fmt.Errorf("failed to read from cache after ensuring data at %d-%d: %w", off, off+length, cacheErr) } - timer.RecordRaw(ctx, length, chunkerAttrs.successFromRemote) + timer.RecordRaw(ctx, length, attrs.successFromRemote) return b, nil } -// getOrCreateSession returns a fetch session for the chunk at fetchOff, or -// (nil, true) if the data is already fully cached. -// -// Slice() checks isCached() before calling this method as a lock-free fast -// path. A TOCTOU race exists between that check and the fetchMap lookup: -// a fetch can finish (writing the dirty bitmap) and delete itself from -// fetchMap in between, so the caller misses both. To close this we re-check -// isCached under fetchMu. 
This is safe because runFetch calls setIsCached -// before acquiring fetchMu to delete, so the lock provides a happens-before -// guarantee that the bitmap writes are visible here. -func (c *StreamingChunker) getOrCreateSession(ctx context.Context, fetchOff int64) (_ *fetchSession, cached bool) { - chunkLen := min(int64(storage.MemoryChunkSize), c.size-fetchOff) - +// getOrCreateSession returns a fetch session for the chunk at [off, off+length), +// or (nil, true) if the data is already fully cached. +func (c *Chunker) getOrCreateSession(ctx context.Context, off, length int64, ft *storage.FrameTable) (_ *fetchSession, cached bool) { c.fetchMu.Lock() + defer c.fetchMu.Unlock() - if existing, ok := c.fetchMap[fetchOff]; ok { - c.fetchMu.Unlock() - - return existing, false + for _, s := range c.fetchSessions { + if s.contains(off, length) { + return s, false + } } - if c.cache.isCached(fetchOff, chunkLen) { - c.fetchMu.Unlock() - + // Re-check cache under fetchMu. A fetch can finish (marking blocks + // cached via setIsCached) and remove itself from sessions between + // the lock-free Slice() and the session scan above. The lock + // provides a happens-before guarantee that the bitmap writes are visible. + if c.cache.isCached(off, length) { return nil, true } - s := &fetchSession{ - chunkOff: fetchOff, - chunkLen: chunkLen, - cache: c.cache, - } - c.fetchMap[fetchOff] = s - c.fetchMu.Unlock() + s := newFetchSession(off, length, c.cache) + c.fetchSessions = append(c.fetchSessions, s) // Detach from the caller's cancel signal so the shared fetch goroutine // continues even if the first caller's context is cancelled. Trace/value // context is preserved for metrics. 
- go c.runFetch(context.WithoutCancel(ctx), s) + go c.runFetch(context.WithoutCancel(ctx), s, ft) return s, false } -func (s *fetchSession) setDone() { - s.mu.Lock() - defer s.mu.Unlock() - - s.bytesReady.Store(s.chunkLen) - s.notifyWaiters(nil) -} - -func (s *fetchSession) setError(err error, onlyIfRunning bool) { - s.mu.Lock() - defer s.mu.Unlock() +// fetch ensures the frame/chunk covering off is fetched into the mmap cache, +// then waits until the block at off is available. Deduplicates concurrent +// requests for the same region via the session list. +func (c *Chunker) fetch(ctx context.Context, off int64, ft *storage.FrameTable) error { + chunkOff, chunkLen, err := c.locateChunk(off, ft) + if err != nil { + return fmt.Errorf("failed to locate chunk for offset %d: %w", off, err) + } - if onlyIfRunning && s.terminated() { - return + session, justGotCached := c.getOrCreateSession(ctx, chunkOff, chunkLen, ft) + if justGotCached { + return nil } - s.fetchErr = err - s.notifyWaiters(err) + blockSize := c.cache.BlockSize() + blockOff := (off / blockSize) * blockSize + + return session.registerAndWait(ctx, blockOff) } -func (c *StreamingChunker) runFetch(ctx context.Context, s *fetchSession) { +// runFetch fetches data from storage into the mmap cache. Runs in a background goroutine. +func (c *Chunker) runFetch(ctx context.Context, s *fetchSession, ft *storage.FrameTable) { ctx, cancel := context.WithTimeout(ctx, c.fetchTimeout) defer cancel() - defer func() { - c.fetchMu.Lock() - delete(c.fetchMap, s.chunkOff) - c.fetchMu.Unlock() - }() + defer c.releaseSession(s) - // Panic recovery: ensure waiters are always notified even if the fetch - // goroutine panics (e.g. nil pointer in upstream reader, mmap fault). - // Without this, waiters would block forever on their channels. + // Unconditionally terminate the session on exit so registerAndWait + // never blocks forever — whether the fetch succeeded, failed, or panicked. 
defer func() { if r := recover(); r != nil { - err := fmt.Errorf("fetch panicked: %v", r) - s.setError(err, true) + s.failIfRunning(fmt.Errorf("fetch panicked: %v", r)) + + return } + + // Safety net: if no code path called setDone/fail, terminate now. + s.failIfRunning(errors.New("fetch exited without completing")) }() mmapSlice, releaseLock, err := c.cache.addressBytes(s.chunkOff, s.chunkLen) if err != nil { - s.setError(err, false) + s.fail(err) return } defer releaseLock() + attrs := chunkerAttrs + if ft.IsCompressed() { + attrs = chunkerAttrsCompressed + } fetchTimer := c.metrics.RemoteReadsTimerFactory.Begin() - err = c.progressiveRead(ctx, s, mmapSlice) + readBytes, err := c.progressiveRead(ctx, s, mmapSlice, ft) if err != nil { - fetchTimer.RecordRaw(ctx, s.chunkLen, chunkerAttrs.remoteFailure) + fetchTimer.RecordRaw(ctx, readBytes, attrs.remoteFailure) - s.setError(err, false) + s.fail(err) return } // Mark entire chunk as cached BEFORE releasing waiters. - // This ensures isCached returns true before the session is removed from fetchMap, + // This ensures isCached returns true before the session is removed from fetchSessions, // closing the TOCTOU window in getOrCreateSession. 
c.cache.setIsCached(s.chunkOff, s.chunkLen) - fetchTimer.RecordRaw(ctx, s.chunkLen, chunkerAttrs.remoteSuccess) + fetchTimer.RecordRaw(ctx, readBytes, attrs.remoteSuccess) s.setDone() } -func (c *StreamingChunker) progressiveRead(ctx context.Context, s *fetchSession, mmapSlice []byte) error { - reader, err := c.upstream.OpenRangeReader(ctx, s.chunkOff, s.chunkLen) +func (c *Chunker) progressiveRead(ctx context.Context, s *fetchSession, mmapSlice []byte, ft *storage.FrameTable) (totalRead int64, err error) { + reader, err := c.upstream.OpenRangeReader(ctx, s.chunkOff, s.chunkLen, ft) if err != nil { - return fmt.Errorf("failed to open range reader at %d: %w", s.chunkOff, err) + return 0, fmt.Errorf("failed to open range reader at %d: %w", s.chunkOff, err) } - defer reader.Close() + defer func() { + if closeErr := reader.Close(); closeErr != nil && err == nil { + err = closeErr + } + }() blockSize := c.cache.BlockSize() - readBatch := max(blockSize, c.getMinReadBatchSize(ctx)) - var totalRead int64 - var prevCompleted int64 + readBatch := max(blockSize, int64(c.featureFlags.IntFlag(ctx, featureflags.MinChunkerReadSizeKB))*1024) for totalRead < s.chunkLen { - // Read in batches of max(blockSize, 16KB) to align notification + // Read in batches of max(blockSize, minReadBatchSize) to align notification // granularity with the read size and minimize lock/notify overhead. readEnd := min(totalRead+readBatch, s.chunkLen) - n, readErr := reader.Read(mmapSlice[totalRead:readEnd]) + n, readErr := io.ReadFull(reader, mmapSlice[totalRead:readEnd]) totalRead += int64(n) - completedBlocks := totalRead / blockSize - if completedBlocks > prevCompleted { - prevCompleted = completedBlocks - - // Notify waiters at block granularity via bytesReady. + if n > 0 { // Dirty marking is deferred to runFetch after the full chunk is fetched. 
- s.mu.Lock() - s.bytesReady.Store(completedBlocks * blockSize) - s.notifyWaiters(nil) - s.mu.Unlock() - } - - if errors.Is(readErr, io.EOF) { - // Remaining waiters are notified in runFetch via the Done state. - break + // With coarse dirty granularity, marking here would expose partially-written data. + s.advance(totalRead) } if readErr != nil { - return fmt.Errorf("failed reading at offset %d after %d bytes: %w", s.chunkOff, totalRead, readErr) + if totalRead >= s.chunkLen { + break // all bytes received; trailing EOF is expected + } + + return totalRead, fmt.Errorf("failed reading at offset %d after %d bytes: %w", s.chunkOff, totalRead, readErr) } } - if totalRead < s.chunkLen { - return fmt.Errorf("short read: expected %d bytes, got %d", s.chunkLen, totalRead) - } + return totalRead, nil +} - return nil +// releaseSession removes s from the active list (swap-delete). +func (c *Chunker) releaseSession(s *fetchSession) { + c.fetchMu.Lock() + defer c.fetchMu.Unlock() + + for i, a := range c.fetchSessions { + if a == s { + c.fetchSessions[i] = c.fetchSessions[len(c.fetchSessions)-1] + c.fetchSessions[len(c.fetchSessions)-1] = nil + c.fetchSessions = c.fetchSessions[:len(c.fetchSessions)-1] + + return + } + } } -// getMinReadBatchSize returns the effective min read batch size. When a feature -// flags client is available, the value is read just-in-time from the flag so -// it can be tuned without restarting the service. -func (c *StreamingChunker) getMinReadBatchSize(ctx context.Context) int64 { - if c.featureFlags != nil { - _, minKB := getChunkerConfig(ctx, c.featureFlags) - if minKB > 0 { - return int64(minKB) * 1024 +// locateChunk returns the aligned (offset, length) of the chunk containing off. +// For compressed data the frame table defines chunk boundaries; for +// uncompressed data chunks are MemoryChunkSize-aligned (for backwards +// compatibility) and clamped to file size. 
+func (c *Chunker) locateChunk(off int64, ft *storage.FrameTable) (chunkOff, chunkLen int64, err error) { + if ft.IsCompressed() { + r, err := ft.LocateUncompressed(off) + if err != nil { + return 0, 0, err } + + return r.Offset, int64(r.Length), nil } - return c.minReadBatchSize + chunkOff = (off / storage.MemoryChunkSize) * storage.MemoryChunkSize + + return chunkOff, min(int64(storage.MemoryChunkSize), c.size-chunkOff), nil } -func (c *StreamingChunker) Close() error { +func (c *Chunker) Close() error { return c.cache.Close() } -func (c *StreamingChunker) FileSize() (int64, error) { +func (c *Chunker) FileSize() (int64, error) { return c.cache.FileSize() } + +const ( + compressedAttr = "compressed" + pullType = "pull-type" + pullTypeLocal = "local" + pullTypeRemote = "remote" + + failureReason = "failure-reason" + + failureTypeLocalRead = "local-read" + failureTypeLocalReadAgain = "local-read-again" + failureTypeRemoteRead = "remote-read" + failureTypeCacheFetch = "cache-fetch" +) + +type precomputedAttrs struct { + successFromCache metric.MeasurementOption + successFromRemote metric.MeasurementOption + + failCacheRead metric.MeasurementOption + failRemoteFetch metric.MeasurementOption + failLocalReadAgain metric.MeasurementOption + + // RemoteReads timer (runFetch) + remoteSuccess metric.MeasurementOption + remoteFailure metric.MeasurementOption +} + +var chunkerAttrs = precomputedAttrs{ + successFromCache: telemetry.PrecomputeAttrs( + telemetry.Success, + attribute.String(pullType, pullTypeLocal)), + + successFromRemote: telemetry.PrecomputeAttrs( + telemetry.Success, + attribute.String(pullType, pullTypeRemote)), + + failCacheRead: telemetry.PrecomputeAttrs( + telemetry.Failure, + attribute.String(pullType, pullTypeLocal), + attribute.String(failureReason, failureTypeLocalRead)), + + failRemoteFetch: telemetry.PrecomputeAttrs( + telemetry.Failure, + attribute.String(pullType, pullTypeRemote), + attribute.String(failureReason, failureTypeCacheFetch)), + + 
failLocalReadAgain: telemetry.PrecomputeAttrs( + telemetry.Failure, + attribute.String(pullType, pullTypeLocal), + attribute.String(failureReason, failureTypeLocalReadAgain)), + + remoteSuccess: telemetry.PrecomputeAttrs( + telemetry.Success), + + remoteFailure: telemetry.PrecomputeAttrs( + telemetry.Failure, + attribute.String(failureReason, failureTypeRemoteRead)), +} + +var chunkerAttrsCompressed = precomputedAttrs{ + successFromCache: telemetry.PrecomputeAttrs( + telemetry.Success, attribute.Bool(compressedAttr, true), + attribute.String(pullType, pullTypeLocal)), + + successFromRemote: telemetry.PrecomputeAttrs( + telemetry.Success, attribute.Bool(compressedAttr, true), + attribute.String(pullType, pullTypeRemote)), + + failCacheRead: telemetry.PrecomputeAttrs( + telemetry.Failure, attribute.Bool(compressedAttr, true), + attribute.String(pullType, pullTypeLocal), + attribute.String(failureReason, failureTypeLocalRead)), + + failRemoteFetch: telemetry.PrecomputeAttrs( + telemetry.Failure, attribute.Bool(compressedAttr, true), + attribute.String(pullType, pullTypeRemote), + attribute.String(failureReason, failureTypeCacheFetch)), + + failLocalReadAgain: telemetry.PrecomputeAttrs( + telemetry.Failure, attribute.Bool(compressedAttr, true), + attribute.String(pullType, pullTypeLocal), + attribute.String(failureReason, failureTypeLocalReadAgain)), + + remoteSuccess: telemetry.PrecomputeAttrs( + telemetry.Success, attribute.Bool(compressedAttr, true)), + + remoteFailure: telemetry.PrecomputeAttrs( + telemetry.Failure, attribute.Bool(compressedAttr, true), + attribute.String(failureReason, failureTypeRemoteRead)), +} diff --git a/packages/orchestrator/pkg/sandbox/block/streaming_chunk_test.go b/packages/orchestrator/pkg/sandbox/block/streaming_chunk_test.go index 661a04b04c..28ed991cbc 100644 --- a/packages/orchestrator/pkg/sandbox/block/streaming_chunk_test.go +++ b/packages/orchestrator/pkg/sandbox/block/streaming_chunk_test.go @@ -3,498 +3,441 @@ package block 
import ( "bytes" "context" - "crypto/rand" - "errors" "fmt" "io" - mathrand "math/rand/v2" + "math/rand/v2" "sync/atomic" "testing" - "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric/noop" "golang.org/x/sync/errgroup" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block/metrics" + "github.com/e2b-dev/infra/packages/shared/pkg/featureflags" "github.com/e2b-dev/infra/packages/shared/pkg/storage" "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) const ( testBlockSize = header.PageSize // 4KB + testFrameSize = 256 * 1024 // 256 KB per frame for fast tests + testFileSize = testFrameSize * 4 ) -// slowUpstream simulates GCS: implements both SeekableReader and StreamingReader. -// OpenRangeReader returns a reader that yields blockSize bytes per Read() call -// with a configurable delay between calls. -type slowUpstream struct { - data []byte - blockSize int64 - delay time.Duration -} - -var ( - _ storage.SeekableReader = (*slowUpstream)(nil) - _ storage.StreamingReader = (*slowUpstream)(nil) -) - -func (s *slowUpstream) ReadAt(_ context.Context, buffer []byte, off int64) (int, error) { - end := min(off+int64(len(buffer)), int64(len(s.data))) - n := copy(buffer, s.data[off:end]) +func newTestMetrics(tb testing.TB) metrics.Metrics { + tb.Helper() - return n, nil -} + m, err := metrics.NewMetrics(noop.NewMeterProvider()) + require.NoError(tb, err) -func (s *slowUpstream) Size(_ context.Context) (int64, error) { - return int64(len(s.data)), nil + return m } -func (s *slowUpstream) OpenRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { - end := min(off+length, int64(len(s.data))) +func makeTestData(size int) []byte { + rng := rand.New(rand.NewPCG(42, 0)) //nolint:gosec // deterministic test data + data := make([]byte, size) + for i := range data { + data[i] = byte(rng.IntN(256)) + } - return &slowReader{ - data: s.data[off:end], - blockSize: int(s.blockSize), 
- delay: s.delay, - }, nil + return data } -type slowReader struct { - data []byte - pos int - blockSize int - delay time.Duration +// fakeSeekable implements storage.Seekable backed by in-memory data. +// When ctrl is non-nil, reads are gated through its channels for concurrency tests. +type fakeSeekable struct { + data []byte + failAfter int64 // >0: truncate reads at this offset; 0 = disabled + fetchCount atomic.Int64 + ctrl *testControl // nil = ungated immediate reads } -func (r *slowReader) Read(p []byte) (int, error) { - if r.pos >= len(r.data) { - return 0, io.EOF - } - - if r.delay > 0 { - time.Sleep(r.delay) - } - - end := min(r.pos+r.blockSize, len(r.data)) +var _ storage.Seekable = (*fakeSeekable)(nil) - n := copy(p, r.data[r.pos:end]) - r.pos += n - - if r.pos >= len(r.data) { - return n, io.EOF - } - - return n, nil +// testControl provides channel-based flow control for fakeSeekable. +type testControl struct { + advance chan struct{} // close to release reads + consumed chan struct{} // receives after each read step + opened chan struct{} // receives when OpenRangeReader is called + closed chan struct{} // receives when reader is closed (fetch done) + onOpen func() // optional callback on OpenRangeReader } -func (r *slowReader) Close() error { - return nil -} - -// fastUpstream simulates NFS: same interfaces but no delay. -type fastUpstream = slowUpstream +func newTestChunker(t *testing.T, file storage.Seekable, size int64) *Chunker { + t.Helper() + c, err := NewChunker(&featureflags.Client{}, size, testBlockSize, file, t.TempDir()+"/cache", newTestMetrics(t)) + require.NoError(t, err) -// streamingFunc adapts a function into a StreamingReader. 
-type streamingFunc func(ctx context.Context, off, length int64) (io.ReadCloser, error) + return c +} -func (f streamingFunc) OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { - return f(ctx, off, length) +func (s *fakeSeekable) Size(_ context.Context) (int64, error) { + return int64(len(s.data)), nil } -// errorAfterNUpstream fails after reading n bytes. -type errorAfterNUpstream struct { - data []byte - failAfter int64 - blockSize int64 +func (s *fakeSeekable) StoreFile(context.Context, string, ...storage.PutOption) (*storage.FrameTable, [32]byte, error) { + panic("not used") } -var _ storage.StreamingReader = (*errorAfterNUpstream)(nil) +func (s *fakeSeekable) OpenRangeReader(_ context.Context, offsetU int64, length int64, frameTable *storage.FrameTable) (io.ReadCloser, error) { + s.fetchCount.Add(1) -func (u *errorAfterNUpstream) OpenRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { - end := min(off+length, int64(len(u.data))) + if s.ctrl != nil { + if s.ctrl.onOpen != nil { + s.ctrl.onOpen() + } - return &errorAfterNReader{ - data: u.data[off:end], - blockSize: int(u.blockSize), - failAfter: int(u.failAfter - off), - }, nil -} + select { + case s.ctrl.opened <- struct{}{}: + default: + } -type errorAfterNReader struct { - data []byte - pos int - blockSize int - failAfter int -} + end := min(offsetU+length, int64(len(s.data))) -func (r *errorAfterNReader) Read(p []byte) (int, error) { - if r.pos >= len(r.data) { - return 0, io.EOF + return &controlledReader{ + data: s.data[offsetU:end], + step: max(16*1024, testBlockSize), + advance: s.ctrl.advance, + consumed: s.ctrl.consumed, + closed: s.ctrl.closed, + }, nil } - if r.pos >= r.failAfter { - return 0, errors.New("simulated upstream error") - } + var fetchOff, fetchLen int64 + if frameTable.IsCompressed() { + r, err := frameTable.LocateCompressed(offsetU) + if err != nil { + return nil, fmt.Errorf("frame lookup: %w", err) + } - end := 
min(r.pos+r.blockSize, len(r.data)) + fetchOff = r.Offset + fetchLen = int64(r.Length) + } else { + fetchOff = offsetU + fetchLen = length + } - n := copy(p, r.data[r.pos:end]) - r.pos += n + end := min(fetchOff+fetchLen, int64(len(s.data))) + if s.failAfter > 0 { + end = min(end, s.failAfter) + } - if r.pos >= len(r.data) { - return n, io.EOF + r := io.Reader(bytes.NewReader(s.data[fetchOff:end])) + if frameTable.IsCompressed() { + return storage.NewDecompressingReader(r, frameTable.CompressionType()) } - return n, nil + return io.NopCloser(r), nil } -func (r *errorAfterNReader) Close() error { - return nil -} +func makeCompressedTestData(tb testing.TB, data []byte) (*storage.FrameTable, *fakeSeekable) { + tb.Helper() -func newTestMetrics(t *testing.T) metrics.Metrics { - t.Helper() + ft, compressed, _, err := storage.CompressBytes(context.Background(), data, storage.CompressConfig{ + Enabled: true, + Type: "lz4", + EncoderConcurrency: 1, + FrameEncodeWorkers: 1, + FrameSizeKB: testFrameSize / 1024, + MinPartSizeMB: 50, + }) + require.NoError(tb, err) - m, err := metrics.NewMetrics(noop.NewMeterProvider()) - require.NoError(t, err) + return ft, &fakeSeekable{data: compressed} +} - return m +type chunkerTestCase struct { + name string + newChunker func(t *testing.T, data []byte) (*Chunker, *storage.FrameTable) } -func makeTestData(t *testing.T, size int) []byte { - t.Helper() +var allChunkerTestCases = []chunkerTestCase{ + { + name: "Compressed", + newChunker: func(t *testing.T, data []byte) (*Chunker, *storage.FrameTable) { + t.Helper() + ft, getter := makeCompressedTestData(t, data) - data := make([]byte, size) - _, err := rand.Read(data) - require.NoError(t, err) + return newTestChunker(t, getter, int64(len(data))), ft + }, + }, + { + name: "Uncompressed", + newChunker: func(t *testing.T, data []byte) (*Chunker, *storage.FrameTable) { + t.Helper() - return data + return newTestChunker(t, &fakeSeekable{data: data}, int64(len(data))), nil + }, + }, } -func 
TestStreamingChunker_BasicSlice(t *testing.T) { +func TestChunker_BasicSlice(t *testing.T) { t.Parallel() - data := makeTestData(t, storage.MemoryChunkSize) - upstream := &fastUpstream{data: data, blockSize: testBlockSize} + for _, tc := range allChunkerTestCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) - defer chunker.Close() + data := makeTestData(testFileSize) + chunker, ft := tc.newChunker(t, data) + defer chunker.Close() - // Read first page - slice, err := chunker.Slice(t.Context(), 0, testBlockSize) - require.NoError(t, err) - require.Equal(t, data[:testBlockSize], slice) + slice, err := chunker.Slice(t.Context(), 0, testBlockSize, ft) + require.NoError(t, err) + require.Equal(t, data[:testBlockSize], slice) + }) + } } -func TestStreamingChunker_CacheHit(t *testing.T) { +// TestChunker_CacheHit verifies that a second read of the same block +// is served from cache without an additional upstream fetch. +func TestChunker_CacheHit(t *testing.T) { t.Parallel() - data := makeTestData(t, storage.MemoryChunkSize) - readCount := atomic.Int64{} - - upstream := &countingUpstream{ - inner: &fastUpstream{data: data, blockSize: testBlockSize}, - readCount: &readCount, - } + data := makeTestData(testFileSize) - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) + // Uncompressed only — we need direct access to the fakeSeekable to count fetches. + file := &fakeSeekable{data: data} + chunker := newTestChunker(t, file, int64(len(data))) defer chunker.Close() - // First read: triggers fetch - _, err = chunker.Slice(t.Context(), 0, testBlockSize) + // First read triggers a fetch. 
+ slice1, err := chunker.Slice(t.Context(), 0, testBlockSize, nil) require.NoError(t, err) + require.Equal(t, data[:testBlockSize], slice1) - // Wait for the full chunk to be fetched - time.Sleep(50 * time.Millisecond) - - firstCount := readCount.Load() - require.Positive(t, firstCount) + firstFetches := file.fetchCount.Load() + require.Positive(t, firstFetches) - // Second read: should hit cache - slice, err := chunker.Slice(t.Context(), 0, testBlockSize) + // Second read of the same block — should hit cache. + slice2, err := chunker.Slice(t.Context(), 0, testBlockSize, nil) require.NoError(t, err) - require.Equal(t, data[:testBlockSize], slice) - - // No additional reads should have happened - assert.Equal(t, firstCount, readCount.Load()) -} - -type countingUpstream struct { - inner *fastUpstream - readCount *atomic.Int64 -} - -var ( - _ storage.SeekableReader = (*countingUpstream)(nil) - _ storage.StreamingReader = (*countingUpstream)(nil) -) - -func (c *countingUpstream) ReadAt(ctx context.Context, buffer []byte, off int64) (int, error) { - c.readCount.Add(1) - - return c.inner.ReadAt(ctx, buffer, off) -} - -func (c *countingUpstream) Size(ctx context.Context) (int64, error) { - return c.inner.Size(ctx) -} - -func (c *countingUpstream) OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { - c.readCount.Add(1) - - return c.inner.OpenRangeReader(ctx, off, length) + require.Equal(t, data[:testBlockSize], slice2) + require.Equal(t, firstFetches, file.fetchCount.Load(), "expected no additional upstream fetch") } -func TestStreamingChunker_FullChunkCachedAfterPartialRequest(t *testing.T) { +// TestChunker_FullChunkCachedAfterPartialRequest verifies that requesting the +// first block triggers a full background fetch of the entire chunk/frame, so +// the last block becomes available without additional upstream fetches. 
+func TestChunker_FullChunkCachedAfterPartialRequest(t *testing.T) { t.Parallel() - data := makeTestData(t, storage.MemoryChunkSize) - openCount := atomic.Int64{} + for _, tc := range allChunkerTestCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - upstream := &countingUpstream{ - inner: &fastUpstream{data: data, blockSize: testBlockSize}, - readCount: &openCount, - } + data := makeTestData(testFileSize) + chunker, ft := tc.newChunker(t, data) + defer chunker.Close() - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) - defer chunker.Close() + _, err := chunker.Slice(t.Context(), 0, testBlockSize, ft) + require.NoError(t, err) - // Request only the FIRST block of the 4MB chunk. - _, err = chunker.Slice(t.Context(), 0, testBlockSize) - require.NoError(t, err) - - // The background goroutine should continue fetching the remaining data. - // Use a blocking Slice call (with timeout) instead of require.Eventually - // to avoid racing condition goroutines against defer chunker.Close(). - lastOff := int64(storage.MemoryChunkSize) - testBlockSize - ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) - defer cancel() - - slice, err := chunker.Slice(ctx, lastOff, testBlockSize) - require.NoError(t, err) - require.True(t, bytes.Equal(data[lastOff:], slice)) - - // Exactly one OpenRangeReader call should have been made for the entire - // chunk, not one per requested block. - assert.Equal(t, int64(1), openCount.Load(), - "expected 1 OpenRangeReader call (full chunk fetched in background), got %d", openCount.Load()) + // The second Slice joins the in-flight session (or hits + // cache if the fetch already completed). Either way it blocks + // until the data is available — no polling needed. 
+ lastOff := int64(testFileSize) - testBlockSize + slice, err := chunker.Slice(t.Context(), lastOff, testBlockSize, ft) + require.NoError(t, err) + require.Equal(t, data[lastOff:lastOff+testBlockSize], slice) + }) + } } -func TestStreamingChunker_ConcurrentSameChunk(t *testing.T) { +// TestChunker_ConcurrentSameChunk verifies that concurrent requests for the same +// chunk don't cause duplicate upstream fetches. +func TestChunker_ConcurrentSameChunk(t *testing.T) { t.Parallel() - data := makeTestData(t, storage.MemoryChunkSize) - // Use a slow upstream so requests will overlap - upstream := &slowUpstream{ - data: data, - blockSize: testBlockSize, - delay: 50 * time.Microsecond, - } + data := makeTestData(testFileSize) - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) + var fetchCount atomic.Int64 + chunker := newControlledChunker(t, data) + chunker.onOpen = func() { fetchCount.Add(1) } defer chunker.Close() - numGoroutines := 10 - offsets := make([]int64, numGoroutines) - for i := range numGoroutines { - offsets[i] = int64(i) * testBlockSize - } - - results := make([][]byte, numGoroutines) + const numGoroutines = 10 var eg errgroup.Group - - for i := range numGoroutines { + started := make(chan struct{}) + for range numGoroutines { eg.Go(func() error { - slice, err := chunker.Slice(t.Context(), offsets[i], testBlockSize) - if err != nil { - return fmt.Errorf("goroutine %d failed: %w", i, err) - } - results[i] = make([]byte, len(slice)) - copy(results[i], slice) + <-started + _, sliceErr := chunker.Slice(t.Context(), 0, testBlockSize, nil) - return nil + return sliceErr }) } + // Release goroutines, wait for the fetch to start (blocked on advance), + // then release data. 
+ close(started) + <-chunker.opened + close(chunker.advance) + require.NoError(t, eg.Wait()) - for i := range numGoroutines { - require.Equal(t, data[offsets[i]:offsets[i]+testBlockSize], results[i], - "goroutine %d got wrong data", i) - } + require.Equal(t, int64(1), fetchCount.Load(), + "expected 1 fetch (dedup), got %d", fetchCount.Load()) } -func TestStreamingChunker_ErrorKeepsPartialData(t *testing.T) { +func TestChunker_EarlyReturn(t *testing.T) { t.Parallel() - chunkSize := storage.MemoryChunkSize - data := makeTestData(t, chunkSize) - failAfter := int64(chunkSize / 2) // Fail at 2MB + data := makeTestData(testFileSize) + chunker := newControlledChunker(t, data) + defer chunker.Close() - upstream := &errorAfterNUpstream{ - data: data, - failAfter: failAfter, - blockSize: testBlockSize, - } + lastOff := int64(len(data)) - testBlockSize - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) - defer chunker.Close() + type result struct { + data []byte + err error + } - // Request the last page — this should fail because upstream dies at 2MB - lastOff := int64(chunkSize) - testBlockSize - _, err = chunker.Slice(t.Context(), lastOff, testBlockSize) - require.Error(t, err) + earlyDone := make(chan result, 1) + lateDone := make(chan result, 1) + + go func() { + slice, sliceErr := chunker.Slice(t.Context(), 0, testBlockSize, nil) + earlyDone <- result{data: bytes.Clone(slice), err: sliceErr} // clone: slice backed by mutable mmap + }() + go func() { + slice, sliceErr := chunker.Slice(t.Context(), lastOff, testBlockSize, nil) + lateDone <- result{data: bytes.Clone(slice), err: sliceErr} + }() + + // Advance exactly one read step (16KB). This covers offset 0 but is + // far from the last block, and no further reads can proceed until we + // send more signals — eliminating the scheduling race. 
+ chunker.advance <- struct{}{} + <-chunker.consumed + + // Offset 0 is within the first readBatch — should be available now. + r := <-earlyDone + require.NoError(t, r.err) + require.Equal(t, data[:testBlockSize], r.data) + + // No more reads have been allowed, so the last offset is unreachable. + select { + case <-lateDone: + t.Fatal("late reader completed before its data was delivered") + default: + } - // But first page (within first 2MB) should still be cached and servable - slice, err := chunker.Slice(t.Context(), 0, testBlockSize) - require.NoError(t, err) - require.Equal(t, data[:testBlockSize], slice) + // Release all remaining reads so the late reader can complete. + close(chunker.advance) + r = <-lateDone + require.NoError(t, r.err) + require.Equal(t, data[lastOff:lastOff+testBlockSize], r.data) } -func TestStreamingChunker_ContextCancellation(t *testing.T) { +// TestChunker_ErrorKeepsPartialData verifies that an upstream error at the +// midpoint of a chunk still allows data before the error to be served. 
+func TestChunker_ErrorKeepsPartialData(t *testing.T) { t.Parallel() - data := makeTestData(t, storage.MemoryChunkSize) - upstream := &slowUpstream{ - data: data, - blockSize: testBlockSize, - delay: 1 * time.Millisecond, - } + data := makeTestData(testFileSize) - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) + chunker := newTestChunker(t, &fakeSeekable{data: data, failAfter: int64(testFileSize / 2)}, int64(len(data))) defer chunker.Close() - // Request with a context that we'll cancel quickly - ctx, cancel := context.WithTimeout(t.Context(), 1*time.Millisecond) - defer cancel() - - lastOff := int64(storage.MemoryChunkSize) - testBlockSize - _, err = chunker.Slice(ctx, lastOff, testBlockSize) - // This should fail with context cancellation + lastOff := int64(testFileSize) - testBlockSize + _, err := chunker.Slice(t.Context(), lastOff, testBlockSize, nil) require.Error(t, err) - // But another caller with a valid context should still get the data - // because the fetch goroutine uses background context - time.Sleep(200 * time.Millisecond) // Wait for fetch to complete - slice, err := chunker.Slice(t.Context(), 0, testBlockSize) + slice, err := chunker.Slice(t.Context(), 0, testBlockSize, nil) require.NoError(t, err) require.Equal(t, data[:testBlockSize], slice) } -func TestStreamingChunker_LastBlockPartial(t *testing.T) { +// TestChunker_ContextCancellation verifies that a cancelled caller context +// doesn't kill the background fetch — another caller can still get data. 
+func TestChunker_ContextCancellation(t *testing.T) { t.Parallel() - // File size not aligned to blockSize - size := storage.MemoryChunkSize - 100 - data := makeTestData(t, size) - upstream := &fastUpstream{data: data, blockSize: testBlockSize} - - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) + data := makeTestData(testFileSize) + chunker := newControlledChunker(t, data) defer chunker.Close() - // Read the last partial block - lastBlockOff := (int64(size) / testBlockSize) * testBlockSize - remaining := int64(size) - lastBlockOff + ctx, cancel := context.WithCancel(t.Context()) + + done := make(chan error, 1) + go func() { + _, sliceErr := chunker.Slice(ctx, 0, testBlockSize, nil) + done <- sliceErr + }() + + // Wait for the fetch goroutine to be blocked on the reader, then cancel. + <-chunker.opened + cancel() + + require.Error(t, <-done) - slice, err := chunker.Slice(t.Context(), lastBlockOff, remaining) + // Release the fetch — it runs with context.WithoutCancel so it continues. + close(chunker.advance) + <-chunker.closed + + // Fetch completed — data is now cached. + slice, err := chunker.Slice(t.Context(), 0, testBlockSize, nil) require.NoError(t, err) - require.Equal(t, data[lastBlockOff:], slice) + require.Equal(t, data[:testBlockSize], slice) } -func TestStreamingChunker_MultiChunkSlice(t *testing.T) { +// TestChunker_LastBlockPartial verifies correct handling of a file whose size +// is not aligned to blockSize — the final block is shorter than blockSize. 
+func TestChunker_LastBlockPartial(t *testing.T) { t.Parallel() - // Two 4MB chunks - size := storage.MemoryChunkSize * 2 - data := makeTestData(t, size) - upstream := &fastUpstream{data: data, blockSize: testBlockSize} - - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) - defer chunker.Close() + size := testFileSize - 100 + data := makeTestData(size) - // Request spanning two chunks: last page of chunk 0 + first page of chunk 1 - off := int64(storage.MemoryChunkSize) - testBlockSize - length := testBlockSize * 2 + for _, tc := range allChunkerTestCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - slice, err := chunker.Slice(t.Context(), off, int64(length)) - require.NoError(t, err) - require.Equal(t, data[off:off+int64(length)], slice) + chunker, ft := tc.newChunker(t, data) + defer chunker.Close() + + lastBlockOff := (int64(size) / testBlockSize) * testBlockSize + remaining := int64(size) - lastBlockOff + + slice, err := chunker.Slice(t.Context(), lastBlockOff, remaining, ft) + require.NoError(t, err) + require.Equal(t, data[lastBlockOff:], slice) + }) + } } -// panicUpstream panics during Read after delivering a configurable number of bytes. -type panicUpstream struct { +// panicSeekable panics during Read after delivering panicAfter bytes. 
+type panicSeekable struct { data []byte - blockSize int64 - panicAfter int64 // byte offset at which to panic (0 = panic immediately) + panicAfter int64 } -var _ storage.StreamingReader = (*panicUpstream)(nil) +var _ storage.Seekable = (*panicSeekable)(nil) -func (u *panicUpstream) OpenRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { - end := min(off+length, int64(len(u.data))) +func (s *panicSeekable) Size(_ context.Context) (int64, error) { + return int64(len(s.data)), nil +} + +func (s *panicSeekable) StoreFile(context.Context, string, ...storage.PutOption) (*storage.FrameTable, [32]byte, error) { + panic("not used") +} + +func (s *panicSeekable) OpenRangeReader(_ context.Context, off int64, length int64, _ *storage.FrameTable) (io.ReadCloser, error) { + end := min(off+length, int64(len(s.data))) return &panicReader{ - data: u.data[off:end], - blockSize: int(u.blockSize), - panicAfter: int(u.panicAfter - off), + data: s.data[off:end], + panicAfter: int(s.panicAfter - off), }, nil } type panicReader struct { data []byte pos int - blockSize int panicAfter int } @@ -507,7 +450,7 @@ func (r *panicReader) Read(p []byte) (int, error) { return 0, io.EOF } - end := min(r.pos+r.blockSize, len(r.data)) + end := min(r.pos+len(p), len(r.data)) n := copy(p, r.data[r.pos:end]) r.pos += n @@ -518,340 +461,125 @@ func (r *panicReader) Close() error { return nil } -func TestStreamingChunker_PanicRecovery(t *testing.T) { +func TestChunker_PanicRecovery(t *testing.T) { t.Parallel() - data := makeTestData(t, storage.MemoryChunkSize) - panicAt := int64(storage.MemoryChunkSize / 2) // Panic at 2MB - - upstream := &panicUpstream{ - data: data, - blockSize: testBlockSize, - panicAfter: panicAt, - } + data := makeTestData(testFileSize) + panicAt := int64(testFileSize / 2) - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) + chunker := 
newTestChunker(t, &panicSeekable{data: data, panicAfter: panicAt}, int64(len(data))) defer chunker.Close() // Request data past the panic point — should get an error, not hang or crash - lastOff := int64(storage.MemoryChunkSize) - testBlockSize - _, err = chunker.Slice(t.Context(), lastOff, testBlockSize) + lastOff := int64(testFileSize) - testBlockSize + _, err := chunker.Slice(t.Context(), lastOff, testBlockSize, nil) require.Error(t, err) - assert.Contains(t, err.Error(), "panicked") // Data before the panic point should still be cached - slice, err := chunker.Slice(t.Context(), 0, testBlockSize) + slice, err := chunker.Slice(t.Context(), 0, testBlockSize, nil) require.NoError(t, err) require.Equal(t, data[:testBlockSize], slice) } -func TestStreamingChunker_ConcurrentSameChunk_SharedSession(t *testing.T) { +func TestChunker_ConcurrentStress(t *testing.T) { t.Parallel() - data := makeTestData(t, storage.MemoryChunkSize) - - gate := make(chan struct{}) - openCount := atomic.Int64{} - - // OpenRangeReader blocks on the gate, keeping the session in fetchMap - // until both callers have entered. This removes the scheduling-dependent - // race in the old slow-upstream version of this test. - upstream := streamingFunc(func(_ context.Context, off, length int64) (io.ReadCloser, error) { - openCount.Add(1) - <-gate - - end := min(off+length, int64(len(data))) - - return io.NopCloser(bytes.NewReader(data[off:end])), nil - }) - - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) - defer chunker.Close() - - // Two different ranges inside the same 4MB chunk. 
- offA := int64(0) - offB := int64(storage.MemoryChunkSize) - testBlockSize // last block - - var eg errgroup.Group - var sliceA, sliceB []byte - - eg.Go(func() error { - s, err := chunker.Slice(t.Context(), offA, testBlockSize) - if err != nil { - return err - } - sliceA = make([]byte, len(s)) - copy(sliceA, s) - - return nil - }) - eg.Go(func() error { - s, err := chunker.Slice(t.Context(), offB, testBlockSize) - if err != nil { - return err - } - sliceB = make([]byte, len(s)) - copy(sliceB, s) - - return nil - }) - - // Let both goroutines enter getOrCreateSession, then release the fetch. - time.Sleep(10 * time.Millisecond) - close(gate) - - require.NoError(t, eg.Wait()) - - assert.Equal(t, data[offA:offA+testBlockSize], sliceA) - assert.Equal(t, data[offB:offB+testBlockSize], sliceB) - assert.Equal(t, int64(1), openCount.Load(), - "expected exactly 1 OpenRangeReader call (shared session), got %d", openCount.Load()) -} - -// --- Benchmarks --- -// -// Uses a bandwidth-limited upstream with real time.Sleep to simulate GCS and -// NFS backends. Measures actual wall-clock latency per caller. -// -// Backend parameters (tuned to match observed production latencies): -// GCS: 20ms TTFB + 100 MB/s → 4MB chunk ā‰ˆ 62ms (observed ~60ms) -// NFS: 1ms TTFB + 500 MB/s → 4MB chunk ā‰ˆ 9ms (observed ~9-10ms) -// -// All sub-benchmarks share a pre-generated offset sequence so results are -// directly comparable across chunker types and backends. -// -// Recommended invocation (~1 minute): -// go test -bench BenchmarkRandomAccess -benchtime 150x -count=3 -run '^$' ./... 
- -func newBenchmarkMetrics(b *testing.B) metrics.Metrics { - b.Helper() + for _, tc := range allChunkerTestCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + data := makeTestData(testFileSize) + chunker, ft := tc.newChunker(t, data) + defer chunker.Close() + + const numGoroutines = 50 + const opsPerGoroutine = 5 + readLen := int64(testBlockSize) + + var eg errgroup.Group + + for i := range numGoroutines { + eg.Go(func() error { + for j := range opsPerGoroutine { + off := int64(((i*opsPerGoroutine)+j)%(len(data)/int(readLen))) * readLen + slice, err := chunker.Slice(t.Context(), off, readLen, ft) + if err != nil { + return fmt.Errorf("goroutine %d op %d: %w", i, j, err) + } + if !bytes.Equal(data[off:off+readLen], slice) { + return fmt.Errorf("goroutine %d op %d: data mismatch at off=%d", i, j, off) + } + } - m, err := metrics.NewMetrics(noop.NewMeterProvider()) - require.NoError(b, err) + return nil + }) + } - return m + require.NoError(t, eg.Wait()) + }) + } } -// realisticUpstream simulates a storage backend with configurable time-to-first-byte -// and bandwidth. ReadAt blocks for the full transfer duration (bulk fetch model). -// OpenRangeReader returns a bandwidth-limited progressive reader. -type realisticUpstream struct { - data []byte - blockSize int64 - ttfb time.Duration - bytesPerSec float64 +// controlledChunker wraps a Chunker with channel-based flow control for tests. +// advance gates reads; opened/consumed/closed signal fetch lifecycle events. 
+type controlledChunker struct { + *Chunker + *testControl } -var ( - _ storage.SeekableReader = (*realisticUpstream)(nil) - _ storage.StreamingReader = (*realisticUpstream)(nil) -) +func newControlledChunker(t *testing.T, data []byte) *controlledChunker { + t.Helper() -func (u *realisticUpstream) ReadAt(_ context.Context, buffer []byte, off int64) (int, error) { - transferTime := time.Duration(float64(len(buffer)) / u.bytesPerSec * float64(time.Second)) - time.Sleep(u.ttfb + transferTime) + ctrl := &testControl{ + advance: make(chan struct{}), + consumed: make(chan struct{}, 10), + opened: make(chan struct{}, 10), + closed: make(chan struct{}, 10), + } - end := min(off+int64(len(buffer)), int64(len(u.data))) - n := copy(buffer, u.data[off:end]) + file := &fakeSeekable{data: data, ctrl: ctrl} - return n, nil + return &controlledChunker{ + Chunker: newTestChunker(t, file, int64(len(data))), + testControl: ctrl, + } } -func (u *realisticUpstream) Size(_ context.Context) (int64, error) { - return int64(len(u.data)), nil +// controlledReader yields data in fixed-size steps, blocking on advance +// before each Read. After advance is closed, reads proceed immediately. +type controlledReader struct { + data []byte + pos int + step int + advance chan struct{} + consumed chan struct{} + closed chan struct{} } -func (u *realisticUpstream) OpenRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { - end := min(off+length, int64(len(u.data))) - - return &bandwidthReader{ - data: u.data[off:end], - blockSize: int(u.blockSize), - ttfb: u.ttfb, - bytesPerSec: u.bytesPerSec, - }, nil -} - -// bandwidthReader delivers data at a steady rate after an initial TTFB delay. -// Uses cumulative timing (time since first byte) so OS scheduling jitter does -// not compound across blocks. 
-type bandwidthReader struct { - data []byte - pos int - blockSize int - ttfb time.Duration - bytesPerSec float64 - startTime time.Time - started bool -} - -func (r *bandwidthReader) Read(p []byte) (int, error) { - if !r.started { - r.started = true - time.Sleep(r.ttfb) - r.startTime = time.Now() - } - +func (r *controlledReader) Read(p []byte) (int, error) { if r.pos >= len(r.data) { return 0, io.EOF } - end := min(r.pos+r.blockSize, len(r.data)) + <-r.advance + + end := min(r.pos+min(len(p), r.step), len(r.data)) n := copy(p, r.data[r.pos:end]) r.pos += n - // Enforce bandwidth: sleep until this many bytes should have arrived. - expectedArrival := r.startTime.Add(time.Duration(float64(r.pos) / r.bytesPerSec * float64(time.Second))) - if wait := time.Until(expectedArrival); wait > 0 { - time.Sleep(wait) - } - - if r.pos >= len(r.data) { - return n, io.EOF + select { + case r.consumed <- struct{}{}: + default: } return n, nil } -func (r *bandwidthReader) Close() error { - return nil -} - -type benchChunker interface { - Slice(ctx context.Context, off, length int64) ([]byte, error) - Close() error -} - -func BenchmarkRandomAccess(b *testing.B) { - size := int64(storage.MemoryChunkSize) - data := make([]byte, size) - - backends := []struct { - name string - upstream *realisticUpstream - }{ - { - name: "GCS", - upstream: &realisticUpstream{ - data: data, - blockSize: testBlockSize, - ttfb: 20 * time.Millisecond, - bytesPerSec: 100e6, // 100 MB/s — full 4MB chunk ā‰ˆ 62ms (observed ~60ms) - }, - }, - { - name: "NFS", - upstream: &realisticUpstream{ - data: data, - blockSize: testBlockSize, - ttfb: 1 * time.Millisecond, - bytesPerSec: 500e6, // 500 MB/s — full 4MB chunk ā‰ˆ 9ms (observed ~9-10ms) - }, - }, - } - - chunkerTypes := []struct { - name string - newChunker func(b *testing.B, m metrics.Metrics, upstream *realisticUpstream) benchChunker - }{ - { - name: "StreamingChunker", - newChunker: func(b *testing.B, m metrics.Metrics, upstream *realisticUpstream) 
benchChunker { - b.Helper() - c, err := NewStreamingChunker(size, testBlockSize, upstream, b.TempDir()+"/cache", m, 0, nil) - require.NoError(b, err) - - return c - }, - }, - { - name: "FullFetchChunker", - newChunker: func(b *testing.B, m metrics.Metrics, upstream *realisticUpstream) benchChunker { - b.Helper() - c, err := NewFullFetchChunker(size, testBlockSize, upstream, b.TempDir()+"/cache", m) - require.NoError(b, err) - - return c - }, - }, - } - - // Realistic concurrency: UFFD faults are limited by vCPU count (typically - // 1-2 for Firecracker VMs) and NBD requests are largely sequential. - const numCallers = 3 - - // Pre-generate a fixed sequence of random offsets so all sub-benchmarks - // use identical access patterns, making results directly comparable. - const maxIters = 500 - numBlocks := size / testBlockSize - rng := mathrand.New(mathrand.NewPCG(42, 0)) - - allOffsets := make([][]int64, maxIters) - for i := range allOffsets { - offsets := make([]int64, numCallers) - for j := range offsets { - offsets[j] = rng.Int64N(numBlocks) * testBlockSize - } - allOffsets[i] = offsets +func (r *controlledReader) Close() error { + select { + case r.closed <- struct{}{}: + default: } - for _, backend := range backends { - for _, ct := range chunkerTypes { - b.Run(backend.name+"/"+ct.name, func(b *testing.B) { - m := newBenchmarkMetrics(b) - - b.ReportMetric(0, "ns/op") - - var sumAvg, sumMax float64 - - for i := range b.N { - offsets := allOffsets[i%maxIters] - - chunker := ct.newChunker(b, m, backend.upstream) - - latencies := make([]time.Duration, numCallers) - - var eg errgroup.Group - for ci, off := range offsets { - eg.Go(func() error { - start := time.Now() - _, err := chunker.Slice(context.Background(), off, testBlockSize) - latencies[ci] = time.Since(start) - - return err - }) - } - require.NoError(b, eg.Wait()) - - var totalLatency time.Duration - var maxLatency time.Duration - for _, l := range latencies { - totalLatency += l - maxLatency = 
max(maxLatency, l) - } - - avgUs := float64(totalLatency.Microseconds()) / float64(numCallers) - sumAvg += avgUs - sumMax = max(sumMax, float64(maxLatency.Microseconds())) - - chunker.Close() - } - - b.ReportMetric(sumAvg/float64(b.N), "avg-us/caller") - b.ReportMetric(sumMax, "worst-us/caller") - }) - } - } + return nil } diff --git a/packages/orchestrator/pkg/sandbox/build/build.go b/packages/orchestrator/pkg/sandbox/build/build.go index 149caab3d6..c01cd4f3c1 100644 --- a/packages/orchestrator/pkg/sandbox/build/build.go +++ b/packages/orchestrator/pkg/sandbox/build/build.go @@ -2,10 +2,14 @@ package build import ( "context" + "errors" "fmt" "io" + "sync/atomic" + "time" "github.com/google/uuid" + "go.uber.org/zap" blockmetrics "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block/metrics" "github.com/e2b-dev/infra/packages/shared/pkg/logger" @@ -13,8 +17,12 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) +// swapReadHeaderBudget bounds how long the read-path swap polls GCS for the +// V4 header to appear. 
+const swapReadHeaderBudget = 30 * time.Second + type File struct { - header *header.Header + header atomic.Pointer[header.Header] store *DiffStore fileType DiffType persistence storage.StorageProvider @@ -28,25 +36,36 @@ func NewFile( persistence storage.StorageProvider, metrics blockmetrics.Metrics, ) *File { - return &File{ - header: header, + f := &File{ store: store, fileType: fileType, persistence: persistence, metrics: metrics, } + f.header.Store(header) + + return f +} + +func (b *File) Header() *header.Header { + return b.header.Load() +} + +func (b *File) SwapHeader(h *header.Header) { + b.header.Store(h) } func (b *File) ReadAt(ctx context.Context, p []byte, off int64) (n int, err error) { for n < len(p) { - mappedOffset, mappedLength, buildID, err := b.header.GetShiftedMapping(ctx, off+int64(n)) + h := b.header.Load() + + mappedToBuild, err := h.GetShiftedMapping(ctx, off+int64(n)) if err != nil { return 0, fmt.Errorf("failed to get mapping: %w", err) } remainingReadLength := int64(len(p)) - int64(n) - - readLength := min(mappedLength, remainingReadLength) + readLength := min(int64(mappedToBuild.Length), remainingReadLength) if readLength <= 0 { logger.L().Error(ctx, fmt.Sprintf( @@ -54,13 +73,13 @@ func (b *File) ReadAt(ctx context.Context, p []byte, off int64) (n int, err erro len(p)-n, off, readLength, - buildID, + mappedToBuild.BuildId, b.fileType, - mappedOffset, + mappedToBuild.Offset, n, int64(n)+readLength, n, - mappedLength, + mappedToBuild.Length, remainingReadLength, )) @@ -70,22 +89,31 @@ func (b *File) ReadAt(ctx context.Context, p []byte, off int64) (n int, err erro // Skip reading when the uuid is nil. // We will use this to handle base builds that are already diffs. // The passed slice p must start as empty, otherwise we would need to copy the empty values there. 
- if *buildID == uuid.Nil { + if mappedToBuild.BuildId == uuid.Nil { n += int(readLength) continue } - mappedBuild, err := b.getBuild(ctx, buildID) + size := b.buildFileSize(h, mappedToBuild.BuildId) + ft := h.GetBuildFrameData(mappedToBuild.BuildId) + mappedBuild, err := b.getBuild(ctx, mappedToBuild.BuildId, size, ft.CompressionType()) if err != nil { return 0, fmt.Errorf("failed to get build: %w", err) } buildN, err := mappedBuild.ReadAt(ctx, p[n:int64(n)+readLength], - mappedOffset, + int64(mappedToBuild.Offset), + ft, ) if err != nil { + if retry, swapErr := b.retryOnTransition(ctx, err); retry { + continue + } else if swapErr != nil { + return 0, swapErr + } + return 0, fmt.Errorf("failed to read from source: %w", err) } @@ -97,32 +125,84 @@ func (b *File) ReadAt(ctx context.Context, p []byte, off int64) (n int, err erro // The slice access must be in the predefined blocksize of the build. func (b *File) Slice(ctx context.Context, off, _ int64) ([]byte, error) { - mappedOffset, _, buildID, err := b.header.GetShiftedMapping(ctx, off) - if err != nil { - return nil, fmt.Errorf("failed to get mapping: %w", err) + for { + h := b.header.Load() + + mappedBuild, err := h.GetShiftedMapping(ctx, off) + if err != nil { + return nil, fmt.Errorf("failed to get mapping: %w", err) + } + + // Pass empty huge page when the build id is nil. 
+ if mappedBuild.BuildId == uuid.Nil { + return header.EmptyHugePage, nil + } + + size := b.buildFileSize(h, mappedBuild.BuildId) + ft := h.GetBuildFrameData(mappedBuild.BuildId) + diff, err := b.getBuild(ctx, mappedBuild.BuildId, size, ft.CompressionType()) + if err != nil { + return nil, fmt.Errorf("failed to get build: %w", err) + } + + result, err := diff.Slice(ctx, int64(mappedBuild.Offset), int64(h.Metadata.BlockSize), ft) + if err != nil { + if retry, swapErr := b.retryOnTransition(ctx, err); retry { + continue + } else if swapErr != nil { + return nil, swapErr + } + + return nil, err + } + + return result, nil } +} - // Pass empty huge page when the build id is nil. - if *buildID == uuid.Nil { - return header.EmptyHugePage, nil +// retryOnTransition catches a PeerTransitionedError and swaps the header from +// storage. Returns (true, nil) to signal the caller should continue the loop, +// or (false, swapErr) if the swap itself failed. peerSeekable emits the +// transition error at most once per seekable, so the loop is naturally +// bounded — no retry counter needed here. +func (b *File) retryOnTransition(ctx context.Context, err error) (bool, error) { + var transErr *storage.PeerTransitionedError + if !errors.As(err, &transErr) { + return false, nil } - build, err := b.getBuild(ctx, buildID) - if err != nil { - return nil, fmt.Errorf("failed to get build: %w", err) + logger.L().Info(ctx, "peer transition detected, swapping header", + zap.String("file_type", string(b.fileType)), + ) + + h, loadErr := PollRemoteStorageForHeader(ctx, b.persistence, b.header.Load().Metadata.BuildId, b.fileType, nil, swapReadHeaderBudget) + if loadErr != nil { + return false, fmt.Errorf("failed to swap header: %w", loadErr) + } + b.SwapHeader(h) + + return true, nil +} + +// buildFileSize returns the uncompressed file size for a build. Returns 0 for +// V3 headers, which signals the read path to fall back to a Size() RPC. 
+func (b *File) buildFileSize(h *header.Header, buildID uuid.UUID) int64 { + if bd, ok := h.Builds[buildID]; ok { + return bd.Size } - return build.Slice(ctx, mappedOffset, int64(b.header.Metadata.BlockSize)) + return 0 } -func (b *File) getBuild(ctx context.Context, buildID *uuid.UUID) (Diff, error) { +func (b *File) getBuild(ctx context.Context, buildID uuid.UUID, uncompressedSize int64, ct storage.CompressionType) (Diff, error) { storageDiff, err := newStorageDiff( b.store.cachePath, buildID.String(), b.fileType, - int64(b.header.Metadata.BlockSize), + int64(b.Header().Metadata.BlockSize), b.metrics, b.persistence, + uncompressedSize, ct, b.store.flags, ) if err != nil { diff --git a/packages/orchestrator/pkg/sandbox/build/diff.go b/packages/orchestrator/pkg/sandbox/build/diff.go index b817235aa9..5a0f33e318 100644 --- a/packages/orchestrator/pkg/sandbox/build/diff.go +++ b/packages/orchestrator/pkg/sandbox/build/diff.go @@ -27,10 +27,11 @@ const ( type Diff interface { io.Closer storage.SeekableReader - block.Slicer + block.FramedSlicer CacheKey() DiffStoreKey CachePath() (string, error) FileSize() (int64, error) + BlockSize() int64 Init(ctx context.Context) error } @@ -39,10 +40,10 @@ type NoDiff struct{} var _ Diff = (*NoDiff)(nil) func (n *NoDiff) CachePath() (string, error) { - return "", NoDiffError{} + return "", nil } -func (n *NoDiff) Slice(_ context.Context, _, _ int64) ([]byte, error) { +func (n *NoDiff) Slice(_ context.Context, _, _ int64, _ *storage.FrameTable) ([]byte, error) { return nil, NoDiffError{} } @@ -50,7 +51,7 @@ func (n *NoDiff) Close() error { return nil } -func (n *NoDiff) ReadAt(_ context.Context, _ []byte, _ int64) (int, error) { +func (n *NoDiff) ReadAt(_ context.Context, _ []byte, _ int64, _ *storage.FrameTable) (int, error) { return 0, NoDiffError{} } diff --git a/packages/orchestrator/pkg/sandbox/build/header_load.go b/packages/orchestrator/pkg/sandbox/build/header_load.go new file mode 100644 index 0000000000..f779aac6b8 --- 
/dev/null +++ b/packages/orchestrator/pkg/sandbox/build/header_load.go @@ -0,0 +1,76 @@ +package build + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/google/uuid" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +const ( + loadV4InitialBackoff = 100 * time.Millisecond + loadV4MaxBackoff = 5 * time.Second + loadV4MaxTransientErrors = 3 +) + +// PollRemoteStorageForHeader polls storage for the post-upload V4 header for buildID/fileType. +// ErrObjectNotExist is retried until the budget expires; other LoadHeader +// errors are tolerated up to loadV4MaxTransientErrors consecutive occurrences +// (e.g. transient GCS hiccups during the rare window between the upload-done +// signal and object visibility) before giving up. +// +// hint is an optional accelerator. A nil error received on the channel says +// "the upload just finished, poll storage now"; a non-nil error says "the +// upload failed" and PollRemoteStorageForHeader returns it immediately without further polling. +// A nil channel never fires, so callers without hint plumbing fall through to +// the ticker-only path. budget bounds total wait time. 
+func PollRemoteStorageForHeader( + ctx context.Context, + store storage.StorageProvider, + buildID uuid.UUID, + t DiffType, + hint <-chan error, + budget time.Duration, +) (*header.Header, error) { + hdrPath := storage.Paths{BuildID: buildID.String()}.HeaderFile(string(t)) + deadline := time.Now().Add(budget) + + backoff := loadV4InitialBackoff + transientErrs := 0 + for { + h, err := header.LoadHeader(ctx, store, hdrPath) + if err == nil { + return h, nil + } + if !errors.Is(err, storage.ErrObjectNotExist) { + transientErrs++ + if transientErrs >= loadV4MaxTransientErrors { + return nil, fmt.Errorf("load V4 header for %s/%s after %d attempts: %w", buildID, t, transientErrs, err) + } + } else { + transientErrs = 0 + } + if !time.Now().Before(deadline) { + return nil, fmt.Errorf("V4 header for %s/%s not visible after %s: %w", buildID, t, budget, err) + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case hintErr := <-hint: + if hintErr != nil { + return nil, fmt.Errorf("upload signaled failure for %s/%s: %w", buildID, t, hintErr) + } + backoff = loadV4InitialBackoff + case <-time.After(backoff): + if backoff < loadV4MaxBackoff { + backoff *= 2 + } + } + } +} diff --git a/packages/orchestrator/pkg/sandbox/build/local_diff.go b/packages/orchestrator/pkg/sandbox/build/local_diff.go index df5fec4ea7..ea43d9bb00 100644 --- a/packages/orchestrator/pkg/sandbox/build/local_diff.go +++ b/packages/orchestrator/pkg/sandbox/build/local_diff.go @@ -6,6 +6,7 @@ import ( "os" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block" + "github.com/e2b-dev/infra/packages/shared/pkg/storage" ) type LocalDiffFile struct { @@ -114,11 +115,11 @@ func (b *localDiff) Close() error { return b.cache.Close() } -func (b *localDiff) ReadAt(_ context.Context, p []byte, off int64) (int, error) { +func (b *localDiff) ReadAt(_ context.Context, p []byte, off int64, _ *storage.FrameTable) (int, error) { return b.cache.ReadAt(p, off) } -func (b *localDiff) Slice(_ 
context.Context, off, length int64) ([]byte, error) { +func (b *localDiff) Slice(_ context.Context, off, length int64, _ *storage.FrameTable) ([]byte, error) { return b.cache.Slice(off, length) } diff --git a/packages/orchestrator/pkg/sandbox/build/mocks/mockdiff.go b/packages/orchestrator/pkg/sandbox/build/mocks/mockdiff.go index ea61e38b25..b52ed79aad 100644 --- a/packages/orchestrator/pkg/sandbox/build/mocks/mockdiff.go +++ b/packages/orchestrator/pkg/sandbox/build/mocks/mockdiff.go @@ -8,6 +8,7 @@ import ( "context" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/build" + "github.com/e2b-dev/infra/packages/shared/pkg/storage" mock "github.com/stretchr/testify/mock" ) @@ -328,8 +329,8 @@ func (_c *MockDiff_Init_Call) RunAndReturn(run func(ctx context.Context) error) } // ReadAt provides a mock function for the type MockDiff -func (_mock *MockDiff) ReadAt(ctx context.Context, buffer []byte, off int64) (int, error) { - ret := _mock.Called(ctx, buffer, off) +func (_mock *MockDiff) ReadAt(ctx context.Context, buffer []byte, off int64, ft *storage.FrameTable) (int, error) { + ret := _mock.Called(ctx, buffer, off, ft) if len(ret) == 0 { panic("no return value specified for ReadAt") @@ -337,16 +338,16 @@ func (_mock *MockDiff) ReadAt(ctx context.Context, buffer []byte, off int64) (in var r0 int var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, []byte, int64) (int, error)); ok { - return returnFunc(ctx, buffer, off) + if returnFunc, ok := ret.Get(0).(func(context.Context, []byte, int64, *storage.FrameTable) (int, error)); ok { + return returnFunc(ctx, buffer, off, ft) } - if returnFunc, ok := ret.Get(0).(func(context.Context, []byte, int64) int); ok { - r0 = returnFunc(ctx, buffer, off) + if returnFunc, ok := ret.Get(0).(func(context.Context, []byte, int64, *storage.FrameTable) int); ok { + r0 = returnFunc(ctx, buffer, off, ft) } else { r0 = ret.Get(0).(int) } - if returnFunc, ok := ret.Get(1).(func(context.Context, []byte, int64) error); 
ok { - r1 = returnFunc(ctx, buffer, off) + if returnFunc, ok := ret.Get(1).(func(context.Context, []byte, int64, *storage.FrameTable) error); ok { + r1 = returnFunc(ctx, buffer, off, ft) } else { r1 = ret.Error(1) } @@ -362,11 +363,12 @@ type MockDiff_ReadAt_Call struct { // - ctx context.Context // - buffer []byte // - off int64 -func (_e *MockDiff_Expecter) ReadAt(ctx interface{}, buffer interface{}, off interface{}) *MockDiff_ReadAt_Call { - return &MockDiff_ReadAt_Call{Call: _e.mock.On("ReadAt", ctx, buffer, off)} +// - ft *storage.FrameTable +func (_e *MockDiff_Expecter) ReadAt(ctx interface{}, buffer interface{}, off interface{}, ft interface{}) *MockDiff_ReadAt_Call { + return &MockDiff_ReadAt_Call{Call: _e.mock.On("ReadAt", ctx, buffer, off, ft)} } -func (_c *MockDiff_ReadAt_Call) Run(run func(ctx context.Context, buffer []byte, off int64)) *MockDiff_ReadAt_Call { +func (_c *MockDiff_ReadAt_Call) Run(run func(ctx context.Context, buffer []byte, off int64, ft *storage.FrameTable)) *MockDiff_ReadAt_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -380,10 +382,15 @@ func (_c *MockDiff_ReadAt_Call) Run(run func(ctx context.Context, buffer []byte, if args[2] != nil { arg2 = args[2].(int64) } + var arg3 *storage.FrameTable + if args[3] != nil { + arg3 = args[3].(*storage.FrameTable) + } run( arg0, arg1, arg2, + arg3, ) }) return _c @@ -394,7 +401,7 @@ func (_c *MockDiff_ReadAt_Call) Return(n int, err error) *MockDiff_ReadAt_Call { return _c } -func (_c *MockDiff_ReadAt_Call) RunAndReturn(run func(ctx context.Context, buffer []byte, off int64) (int, error)) *MockDiff_ReadAt_Call { +func (_c *MockDiff_ReadAt_Call) RunAndReturn(run func(ctx context.Context, buffer []byte, off int64, ft *storage.FrameTable) (int, error)) *MockDiff_ReadAt_Call { _c.Call.Return(run) return _c } @@ -460,8 +467,8 @@ func (_c *MockDiff_Size_Call) RunAndReturn(run func(ctx context.Context) (int64, } // Slice provides a mock function for the 
type MockDiff -func (_mock *MockDiff) Slice(ctx context.Context, off int64, length int64) ([]byte, error) { - ret := _mock.Called(ctx, off, length) +func (_mock *MockDiff) Slice(ctx context.Context, off int64, length int64, ft *storage.FrameTable) ([]byte, error) { + ret := _mock.Called(ctx, off, length, ft) if len(ret) == 0 { panic("no return value specified for Slice") @@ -469,18 +476,18 @@ func (_mock *MockDiff) Slice(ctx context.Context, off int64, length int64) ([]by var r0 []byte var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, int64, int64) ([]byte, error)); ok { - return returnFunc(ctx, off, length) + if returnFunc, ok := ret.Get(0).(func(context.Context, int64, int64, *storage.FrameTable) ([]byte, error)); ok { + return returnFunc(ctx, off, length, ft) } - if returnFunc, ok := ret.Get(0).(func(context.Context, int64, int64) []byte); ok { - r0 = returnFunc(ctx, off, length) + if returnFunc, ok := ret.Get(0).(func(context.Context, int64, int64, *storage.FrameTable) []byte); ok { + r0 = returnFunc(ctx, off, length, ft) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]byte) } } - if returnFunc, ok := ret.Get(1).(func(context.Context, int64, int64) error); ok { - r1 = returnFunc(ctx, off, length) + if returnFunc, ok := ret.Get(1).(func(context.Context, int64, int64, *storage.FrameTable) error); ok { + r1 = returnFunc(ctx, off, length, ft) } else { r1 = ret.Error(1) } @@ -496,11 +503,12 @@ type MockDiff_Slice_Call struct { // - ctx context.Context // - off int64 // - length int64 -func (_e *MockDiff_Expecter) Slice(ctx interface{}, off interface{}, length interface{}) *MockDiff_Slice_Call { - return &MockDiff_Slice_Call{Call: _e.mock.On("Slice", ctx, off, length)} +// - ft *storage.FrameTable +func (_e *MockDiff_Expecter) Slice(ctx interface{}, off interface{}, length interface{}, ft interface{}) *MockDiff_Slice_Call { + return &MockDiff_Slice_Call{Call: _e.mock.On("Slice", ctx, off, length, ft)} } -func (_c *MockDiff_Slice_Call) Run(run 
func(ctx context.Context, off int64, length int64)) *MockDiff_Slice_Call { +func (_c *MockDiff_Slice_Call) Run(run func(ctx context.Context, off int64, length int64, ft *storage.FrameTable)) *MockDiff_Slice_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -514,10 +522,15 @@ func (_c *MockDiff_Slice_Call) Run(run func(ctx context.Context, off int64, leng if args[2] != nil { arg2 = args[2].(int64) } + var arg3 *storage.FrameTable + if args[3] != nil { + arg3 = args[3].(*storage.FrameTable) + } run( arg0, arg1, arg2, + arg3, ) }) return _c @@ -528,7 +541,7 @@ func (_c *MockDiff_Slice_Call) Return(bytes []byte, err error) *MockDiff_Slice_C return _c } -func (_c *MockDiff_Slice_Call) RunAndReturn(run func(ctx context.Context, off int64, length int64) ([]byte, error)) *MockDiff_Slice_Call { +func (_c *MockDiff_Slice_Call) RunAndReturn(run func(ctx context.Context, off int64, length int64, ft *storage.FrameTable) ([]byte, error)) *MockDiff_Slice_Call { _c.Call.Return(run) return _c } diff --git a/packages/orchestrator/pkg/sandbox/build/storage_diff.go b/packages/orchestrator/pkg/sandbox/build/storage_diff.go index 1a4736d3d5..008afd21b4 100644 --- a/packages/orchestrator/pkg/sandbox/build/storage_diff.go +++ b/packages/orchestrator/pkg/sandbox/build/storage_diff.go @@ -3,7 +3,6 @@ package build import ( "context" "fmt" - "io" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block" blockmetrics "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block/metrics" @@ -11,21 +10,18 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/storage" ) -func storagePath(buildId string, diffType DiffType) string { - return fmt.Sprintf("%s/%s", buildId, diffType) -} - type StorageDiff struct { - chunker block.Chunker + chunker *block.Chunker cachePath string cacheKey DiffStoreKey storagePath string storageObjectType storage.SeekableObjectType - blockSize int64 - metrics blockmetrics.Metrics - persistence 
storage.StorageProvider - featureFlags *featureflags.Client + blockSize int64 + metrics blockmetrics.Metrics + persistence storage.StorageProvider + featureFlags *featureflags.Client + uncompressedSize int64 } var _ Diff = (*StorageDiff)(nil) @@ -45,9 +41,10 @@ func newStorageDiff( blockSize int64, metrics blockmetrics.Metrics, persistence storage.StorageProvider, - featureFlags *featureflags.Client, + uncompressedSize int64, + ct storage.CompressionType, + ff *featureflags.Client, ) (*StorageDiff, error) { - storagePath := storagePath(buildId, diffType) storageObjectType, ok := storageObjectType(diffType) if !ok { return nil, UnknownDiffTypeError{diffType} @@ -56,13 +53,14 @@ func newStorageDiff( cachePath := GenerateDiffCachePath(basePath, buildId, diffType) return &StorageDiff{ - storagePath: storagePath, + storagePath: storage.Paths{BuildID: buildId}.DataFile(string(diffType), ct), storageObjectType: storageObjectType, cachePath: cachePath, blockSize: blockSize, metrics: metrics, persistence: persistence, - featureFlags: featureFlags, + featureFlags: ff, + uncompressedSize: uncompressedSize, cacheKey: GetDiffStoreKey(buildId, diffType), }, nil } @@ -88,12 +86,15 @@ func (b *StorageDiff) Init(ctx context.Context) error { return err } - size, err := obj.Size(ctx) - if err != nil { - return fmt.Errorf("failed to get object size: %w", err) + size := b.uncompressedSize + if size == 0 { + size, err = obj.Size(ctx) + if err != nil { + return fmt.Errorf("failed to get object size: %w", err) + } } - c, err := block.NewChunker(ctx, b.featureFlags, size, b.blockSize, obj, b.cachePath, b.metrics) + c, err := block.NewChunker(b.featureFlags, size, b.blockSize, obj, b.cachePath, b.metrics) if err != nil { return fmt.Errorf("failed to create chunker: %w", err) } @@ -111,16 +112,12 @@ func (b *StorageDiff) Close() error { return b.chunker.Close() } -func (b *StorageDiff) ReadAt(ctx context.Context, p []byte, off int64) (int, error) { - return b.chunker.ReadAt(ctx, p, off) -} - 
-func (b *StorageDiff) Slice(ctx context.Context, off, length int64) ([]byte, error) { - return b.chunker.Slice(ctx, off, length) +func (b *StorageDiff) ReadAt(ctx context.Context, p []byte, off int64, ft *storage.FrameTable) (int, error) { + return b.chunker.ReadAt(ctx, p, off, ft) } -func (b *StorageDiff) WriteTo(ctx context.Context, w io.Writer) (int64, error) { - return b.chunker.WriteTo(ctx, w) +func (b *StorageDiff) Slice(ctx context.Context, off, length int64, ft *storage.FrameTable) ([]byte, error) { + return b.chunker.Slice(ctx, off, length, ft) } // The local file might not be synced. diff --git a/packages/orchestrator/pkg/sandbox/build_upload.go b/packages/orchestrator/pkg/sandbox/build_upload.go new file mode 100644 index 0000000000..c0f70fe184 --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/build_upload.go @@ -0,0 +1,185 @@ +package sandbox + +import ( + "context" + "errors" + "fmt" + + "github.com/google/uuid" + "github.com/launchdarkly/go-sdk-common/v3/ldcontext" + + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/build" + "github.com/e2b-dev/infra/packages/shared/pkg/featureflags" + "github.com/e2b-dev/infra/packages/shared/pkg/storage" + headers "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" + "github.com/e2b-dev/infra/packages/shared/pkg/utils" +) + +type Upload struct { + buildID uuid.UUID + snap *Snapshot + paths storage.Paths + uploads *Uploads + store storage.StorageProvider + mem storage.CompressConfig + root storage.CompressConfig + objectMetadata storage.ObjectMetadata + future *utils.ErrorOnce +} + +func NewUpload( + ctx context.Context, + uploads *Uploads, + snap *Snapshot, + store storage.StorageProvider, + cfg storage.CompressConfig, + ff *featureflags.Client, + useCase string, + objectMetadata storage.ObjectMetadata, +) (*Upload, error) { + mem, err := resolveCompressConfig(ctx, cfg, ff, storage.MemfileName, snap.MemfileDiffHeader.Metadata.BlockSize, useCase) + if err != nil { + return nil, 
fmt.Errorf("resolve memfile compress config: %w", err) + } + root, err := resolveCompressConfig(ctx, cfg, ff, storage.RootfsName, snap.RootfsDiffHeader.Metadata.BlockSize, useCase) + if err != nil { + return nil, fmt.Errorf("resolve rootfs compress config: %w", err) + } + + u := &Upload{ + buildID: snap.BuildID, + snap: snap, + paths: storage.Paths{BuildID: snap.BuildID.String()}, + uploads: uploads, + store: store, + mem: mem, + root: root, + objectMetadata: objectMetadata, + } + + if uploads != nil { + fut, err := uploads.Start(snap.BuildID) + if err != nil { + return nil, err + } + u.future = fut + } + + return u, nil +} + +func (u *Upload) Run(ctx context.Context) error { + if !u.mem.IsCompressionEnabled() && !u.root.IsCompressionEnabled() { + return u.runV3(ctx) + } + + return u.runV4(ctx) +} + +// Finish signals the upload's terminal outcome. Same-orch waiters wake on the +// future; cross-orch waiters wake on the Redis hint published here. +func (u *Upload) Finish(ctx context.Context, uploadErr error) { + if u.future != nil { + _ = u.future.SetError(uploadErr) + } + if u.uploads != nil { + u.uploads.publishUploadDoneToRedis(ctx, u.buildID, uploadErr) + } +} + +// publish swaps a finalized header into the local cached device so peers and +// Wait()ers see the build as complete. ErrBuildNotInCache is the one acceptable +// failure mode: nothing was cached locally, nothing to swap. +func (u *Upload) publish(ctx context.Context, t build.DiffType, h *headers.Header) error { + if u.uploads == nil { + return nil + } + + dev, err := u.uploads.find(ctx, u.buildID, t) + if errors.Is(err, ErrBuildNotInCache) { + return nil + } + if err != nil { + return fmt.Errorf("load %s for swap: %w", t, err) + } + + dev.SwapHeader(h) + + return nil +} + +// resolveCompressConfig returns the effective compression config for a given +// file type and use case. Feature flags override the base config when active. +// Returns zero-value CompressConfig when compression is disabled. 
+// +// fileType and useCase are added to the LD evaluation context so that +// LaunchDarkly targeting rules can differentiate (e.g. compress memfile +// but not rootfs, or compress builds but not pauses). blockSize is the +// in-VM read granularity for this fileType (from the diff header) and +// constrains the legal frame sizes — see validateCompressConfig. +// +// The resolved config is validated; an invalid env or LD-derived config +// surfaces as an error so the upload fails fast rather than streaming with +// a misconfigured frame size. +func resolveCompressConfig(ctx context.Context, base storage.CompressConfig, ff *featureflags.Client, fileType string, blockSize uint64, useCase string) (storage.CompressConfig, error) { + resolved := base + + if ff != nil { + var extra []ldcontext.Context + if fileType != "" { + extra = append(extra, featureflags.CompressFileTypeContext(fileType)) + } + if useCase != "" { + extra = append(extra, featureflags.CompressUseCaseContext(useCase)) + } + ctx = featureflags.AddToContext(ctx, extra...) + + v := ff.JSONFlag(ctx, featureflags.CompressConfigFlag).AsValueMap() + + if v.Get("compressBuilds").BoolValue() { + ct := v.Get("compressionType").StringValue() + ldCfg := storage.CompressConfig{ + Enabled: true, + Type: ct, + Level: v.Get("compressionLevel").IntValue(), + FrameSizeKB: v.Get("frameSizeKB").IntValue(), + MinPartSizeMB: v.Get("minPartSizeMB").IntValue(), + FrameEncodeWorkers: v.Get("frameEncodeWorkers").IntValue(), + EncoderConcurrency: v.Get("encoderConcurrency").IntValue(), + } + if ldCfg.CompressionType() != storage.CompressionNone { + resolved = ldCfg + } + } + } + + if !resolved.IsCompressionEnabled() { + return storage.CompressConfig{}, nil + } + + if err := validateCompressConfig(resolved, blockSize); err != nil { + return storage.CompressConfig{}, err + } + + return resolved, nil +} + +// validateCompressConfig checks that the resolved config is internally +// consistent for the given block size. 
Frame size must be a positive multiple +// of blockSize so that every block-sized read served by the chunker lies +// inside one frame — otherwise Chunker.fetch fetches only the start frame and +// cache.sliceDirect returns uninitialized mmap bytes for the tail. +func validateCompressConfig(c storage.CompressConfig, blockSize uint64) error { + fs := c.FrameSize() + if fs <= 0 { + return fmt.Errorf("frame size must be positive, got %d KB", c.FrameSizeKB) + } + if blockSize == 0 { + return errors.New("block size must be positive") + } + if uint64(fs)%blockSize != 0 { + return fmt.Errorf("frame size (%d) must be a multiple of block size (%d)", fs, blockSize) + } + + return nil +} diff --git a/packages/orchestrator/pkg/sandbox/build_upload_v3.go b/packages/orchestrator/pkg/sandbox/build_upload_v3.go new file mode 100644 index 0000000000..134c7ede7d --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/build_upload_v3.go @@ -0,0 +1,112 @@ +package sandbox + +import ( + "context" + "fmt" + + "golang.org/x/sync/errgroup" + + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/build" + "github.com/e2b-dev/infra/packages/shared/pkg/storage" + headers "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +func (u *Upload) runV3(ctx context.Context) error { + memfilePath, err := u.snap.MemfileDiff.CachePath() + if err != nil { + return fmt.Errorf("error getting memfile diff path: %w", err) + } + + rootfsPath, err := u.snap.RootfsDiff.CachePath() + if err != nil { + return fmt.Errorf("error getting rootfs diff path: %w", err) + } + + eg, egCtx := errgroup.WithContext(ctx) + + eg.Go(func() error { + if u.snap.MemfileDiffHeader == nil { + return nil + } + + return headers.StoreHeader(egCtx, u.store, u.paths.MemfileHeader(), finalizeV3(u.snap.MemfileDiffHeader)) + }) + + eg.Go(func() error { + if u.snap.RootfsDiffHeader == nil { + return nil + } + + return headers.StoreHeader(egCtx, u.store, u.paths.RootfsHeader(), finalizeV3(u.snap.RootfsDiffHeader)) + }) + 
+ meta := storage.WithMetadata(u.objectMetadata) + + eg.Go(func() error { + if memfilePath == "" { + return nil + } + + _, _, err := storage.UploadFramed(egCtx, u.store, u.paths.Memfile(), storage.MemfileObjectType, memfilePath, meta) + + return err + }) + + eg.Go(func() error { + if rootfsPath == "" { + return nil + } + + _, _, err := storage.UploadFramed(egCtx, u.store, u.paths.Rootfs(), storage.RootFSObjectType, rootfsPath, meta) + + return err + }) + + eg.Go(func() error { + return storage.UploadBlob(egCtx, u.store, u.paths.Snapfile(), storage.SnapfileObjectType, u.snap.Snapfile.Path(), meta) + }) + + eg.Go(func() error { + return storage.UploadBlob(egCtx, u.store, u.paths.Metadata(), storage.MetadataObjectType, u.snap.Metafile.Path(), meta) + }) + + if err := eg.Wait(); err != nil { + return err + } + + if u.snap.MemfileDiffHeader != nil { + if _, err := u.collectAncestorBuilds(ctx, u.snap.MemfileDiffHeader.Mapping, build.Memfile); err != nil { + return err + } + } + if h := finalizeV3(u.snap.MemfileDiffHeader); h != nil { + if err := u.publish(ctx, build.Memfile, h); err != nil { + return err + } + } + + if u.snap.RootfsDiffHeader != nil { + if _, err := u.collectAncestorBuilds(ctx, u.snap.RootfsDiffHeader.Mapping, build.Rootfs); err != nil { + return err + } + } + if h := finalizeV3(u.snap.RootfsDiffHeader); h != nil { + if err := u.publish(ctx, build.Rootfs, h); err != nil { + return err + } + } + + return nil +} + +// finalizeV3 returns a shallow copy of src with IncompletePendingUpload cleared, +// or nil if src is nil. Safe shallow copy: only the bool field is mutated. 
+func finalizeV3(src *headers.Header) *headers.Header { + if src == nil { + return nil + } + h := *src + h.IncompletePendingUpload = false + + return &h +} diff --git a/packages/orchestrator/pkg/sandbox/build_upload_v4.go b/packages/orchestrator/pkg/sandbox/build_upload_v4.go new file mode 100644 index 0000000000..0344bf4326 --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/build_upload_v4.go @@ -0,0 +1,151 @@ +package sandbox + +import ( + "context" + "fmt" + + "github.com/google/uuid" + "golang.org/x/sync/errgroup" + + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/build" + "github.com/e2b-dev/infra/packages/shared/pkg/storage" + headers "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +func (u *Upload) runV4(ctx context.Context) error { + memSrc, err := u.snap.MemfileDiff.CachePath() + if err != nil { + return fmt.Errorf("memfile diff path: %w", err) + } + + rootfsSrc, err := u.snap.RootfsDiff.CachePath() + if err != nil { + return fmt.Errorf("rootfs diff path: %w", err) + } + + eg, ctx := errgroup.WithContext(ctx) + + if u.snap.MemfileDiffHeader != nil { + eg.Go(func() error { + return u.uploadFramed(ctx, build.Memfile, memSrc, u.snap.MemfileDiffHeader, u.mem) + }) + } + + if u.snap.RootfsDiffHeader != nil { + eg.Go(func() error { + return u.uploadFramed(ctx, build.Rootfs, rootfsSrc, u.snap.RootfsDiffHeader, u.root) + }) + } + + meta := storage.WithMetadata(u.objectMetadata) + + eg.Go(func() error { + return storage.UploadBlob(ctx, u.store, u.paths.Snapfile(), storage.SnapfileObjectType, u.snap.Snapfile.Path(), meta) + }) + + eg.Go(func() error { + return storage.UploadBlob(ctx, u.store, u.paths.Metadata(), storage.MetadataObjectType, u.snap.Metafile.Path(), meta) + }) + + return eg.Wait() +} + +func (u *Upload) uploadFramed( + ctx context.Context, + fileType build.DiffType, + srcPath string, + srcHeader *headers.Header, + cfg storage.CompressConfig, +) error { + var selfBuild headers.BuildData + + if srcPath != "" { + ft, 
checksum, err := storage.UploadFramed(ctx, u.store, u.paths.DataFile(string(fileType), cfg.CompressionType()), seekableTypeFor(fileType), srcPath, storage.WithCompressConfig(cfg), storage.WithMetadata(u.objectMetadata)) + if err != nil { + return fmt.Errorf("%s upload: %w", fileType, err) + } + + // FrameTable count, not os.Stat: sparse memfile diffs stream less than + // they appear on disk. + selfBuild = headers.BuildData{Size: ft.UncompressedSize(), Checksum: checksum} + if ft.IsCompressed() { + selfBuild.FrameData = ft + } + } + + h := srcHeader.CloneForUpload(headers.MetadataVersionV4) + h.IncompletePendingUpload = false + + // Dependency closure is the set of buildIDs referenced by mappings, minus + // self. Each ancestor's BuildData lives in its own finalized header's + // self-entry; Wait routes to local future, peer, or remote storage as + // appropriate. Already-final ancestors resolve immediately (remote storage + // round-trip beats blocking on whatever the immediate parent's upload is + // doing). + ancestors, err := u.collectAncestorBuilds(ctx, srcHeader.Mapping, fileType) + if err != nil { + return err + } + + // Empty diffs still represent a layer descendants must record as an ancestor. + h.Builds = ancestors + h.Builds[u.buildID] = selfBuild + + if err := headers.StoreHeader(ctx, u.store, u.paths.HeaderFile(string(fileType)), h); err != nil { + return fmt.Errorf("store %s header: %w", fileType, err) + } + + return u.publish(ctx, fileType, h) +} + +// collectAncestorBuilds resolves every unique buildID referenced by mappings +// (excluding self) to its finalized BuildData. Local ancestors resolve from the +// in-memory futures map without any I/O; cross-orch ancestors take a single +// remote storage round-trip each. Sequential — the critical path is the slowest +// pending Wait either way, and serial keeps the code simple. 
+func (u *Upload) collectAncestorBuilds( + ctx context.Context, + mappings []headers.BuildMap, + fileType build.DiffType, +) (map[uuid.UUID]headers.BuildData, error) { + out := make(map[uuid.UUID]headers.BuildData) + if u.uploads == nil { + return out, nil + } + + for _, m := range mappings { + if m.BuildId == u.buildID || m.BuildId == uuid.Nil { + continue + } + if _, dup := out[m.BuildId]; dup { + continue + } + + h, err := u.uploads.Wait(ctx, m.BuildId, fileType) + if err != nil { + return nil, fmt.Errorf("wait for ancestor %s/%s: %w", m.BuildId, fileType, err) + } + // V3 ancestors have Builds=nil (FrameTable is V4-only); their data is + // raw bytes and the read path doesn't consult Builds for them. Skip + // silently so V4 descendants of V3 ancestors still upload. + bd, ok := h.Builds[m.BuildId] + if !ok { + continue + } + + out[m.BuildId] = bd + } + + return out, nil +} + +func seekableTypeFor(fileType build.DiffType) storage.SeekableObjectType { + switch fileType { + case build.Memfile: + return storage.MemfileObjectType + case build.Rootfs: + return storage.RootFSObjectType + } + + return storage.UnknownSeekableObjectType +} diff --git a/packages/orchestrator/pkg/sandbox/nbd/dispatch.go b/packages/orchestrator/pkg/sandbox/nbd/dispatch.go index d24b71c5b4..3235e6e4ef 100644 --- a/packages/orchestrator/pkg/sandbox/nbd/dispatch.go +++ b/packages/orchestrator/pkg/sandbox/nbd/dispatch.go @@ -14,7 +14,6 @@ import ( "go.uber.org/zap" "github.com/e2b-dev/infra/packages/shared/pkg/logger" - "github.com/e2b-dev/infra/packages/shared/pkg/storage" "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) @@ -43,7 +42,8 @@ var dispatchBufPool = sync.Pool{ } type Provider interface { - storage.SeekableReader + ReadAt(ctx context.Context, p []byte, off int64) (int, error) + Size(ctx context.Context) (int64, error) io.WriterAt } diff --git a/packages/orchestrator/pkg/sandbox/nbd/path_direct_slow_test.go b/packages/orchestrator/pkg/sandbox/nbd/path_direct_slow_test.go index 
d1316247dc..82a12dacee 100644 --- a/packages/orchestrator/pkg/sandbox/nbd/path_direct_slow_test.go +++ b/packages/orchestrator/pkg/sandbox/nbd/path_direct_slow_test.go @@ -57,6 +57,10 @@ func (s *SlowDevice) Header() *header.Header { return s.inner.Header() } +func (s *SlowDevice) SwapHeader(h *header.Header) { + s.inner.SwapHeader(h) +} + func (s *SlowDevice) Close() error { return s.inner.Close() } diff --git a/packages/orchestrator/pkg/sandbox/nbd/testutils/build_device.go b/packages/orchestrator/pkg/sandbox/nbd/testutils/build_device.go index 1041107be1..bc4e94ab00 100644 --- a/packages/orchestrator/pkg/sandbox/nbd/testutils/build_device.go +++ b/packages/orchestrator/pkg/sandbox/nbd/testutils/build_device.go @@ -37,6 +37,10 @@ func (m *BuildDevice) Header() *header.Header { return m.header } +func (m *BuildDevice) SwapHeader(h *header.Header) { + m.header = h +} + func (m *BuildDevice) Size(_ context.Context) (int64, error) { return int64(m.header.Metadata.Size), nil } diff --git a/packages/orchestrator/pkg/sandbox/nbd/testutils/logger_overlay.go b/packages/orchestrator/pkg/sandbox/nbd/testutils/logger_overlay.go index ea33af60cd..932e5fae23 100644 --- a/packages/orchestrator/pkg/sandbox/nbd/testutils/logger_overlay.go +++ b/packages/orchestrator/pkg/sandbox/nbd/testutils/logger_overlay.go @@ -61,6 +61,10 @@ func (l *LoggerOverlay) Header() *header.Header { return l.overlay.Header() } +func (l *LoggerOverlay) SwapHeader(h *header.Header) { + l.overlay.SwapHeader(h) +} + func (l *LoggerOverlay) Close() error { return l.overlay.Close() } diff --git a/packages/orchestrator/pkg/sandbox/nbd/testutils/zero_device.go b/packages/orchestrator/pkg/sandbox/nbd/testutils/zero_device.go index e4733ea20c..357e6e7a89 100644 --- a/packages/orchestrator/pkg/sandbox/nbd/testutils/zero_device.go +++ b/packages/orchestrator/pkg/sandbox/nbd/testutils/zero_device.go @@ -62,6 +62,10 @@ func (z *ZeroDevice) Header() *header.Header { return z.header } +func (z *ZeroDevice) 
SwapHeader(h *header.Header) { + z.header = h +} + func (z *ZeroDevice) Close() error { return nil } diff --git a/packages/orchestrator/pkg/sandbox/sandbox.go b/packages/orchestrator/pkg/sandbox/sandbox.go index f69100015d..94f2ed7467 100644 --- a/packages/orchestrator/pkg/sandbox/sandbox.go +++ b/packages/orchestrator/pkg/sandbox/sandbox.go @@ -1125,6 +1125,8 @@ func (s *Sandbox) Pause( RootfsDiff: rootfsDiff, RootfsDiffHeader: rootfsDiffHeader, + BuildID: buildID, + cleanup: cleanup, }, nil } diff --git a/packages/orchestrator/pkg/sandbox/snapshot.go b/packages/orchestrator/pkg/sandbox/snapshot.go index 2633e330fb..37ab988e19 100644 --- a/packages/orchestrator/pkg/sandbox/snapshot.go +++ b/packages/orchestrator/pkg/sandbox/snapshot.go @@ -4,9 +4,10 @@ import ( "context" "fmt" + "github.com/google/uuid" + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/build" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/template" - "github.com/e2b-dev/infra/packages/shared/pkg/storage" "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) @@ -17,63 +18,11 @@ type Snapshot struct { RootfsDiffHeader *header.Header Snapfile template.File Metafile template.File + BuildID uuid.UUID cleanup *Cleanup } -// Upload writes snapshot artifacts to persistence under paths. objectMetadata -// is attached to every uploaded object; pass nil to skip. 
-func (s *Snapshot) Upload( - ctx context.Context, - persistence storage.StorageProvider, - paths storage.Paths, - objectMetadata storage.ObjectMetadata, -) error { - var memfilePath *string - switch r := s.MemfileDiff.(type) { - case *build.NoDiff: - default: - memfileLocalPath, err := r.CachePath() - if err != nil { - return fmt.Errorf("error getting memfile diff path: %w", err) - } - - memfilePath = &memfileLocalPath - } - - var rootfsPath *string - switch r := s.RootfsDiff.(type) { - case *build.NoDiff: - default: - rootfsLocalPath, err := r.CachePath() - if err != nil { - return fmt.Errorf("error getting rootfs diff path: %w", err) - } - - rootfsPath = &rootfsLocalPath - } - - templateBuild := NewTemplateBuild( - s.MemfileDiffHeader, - s.RootfsDiffHeader, - persistence, - paths, - objectMetadata, - ) - - if err := templateBuild.Upload( - ctx, - s.Metafile.Path(), - s.Snapfile.Path(), - memfilePath, - rootfsPath, - ); err != nil { - return fmt.Errorf("error uploading template files: %w", err) - } - - return nil -} - func (s *Snapshot) Close(ctx context.Context) error { err := s.cleanup.Run(ctx) if err != nil { diff --git a/packages/orchestrator/pkg/sandbox/template/peerclient/blob_test.go b/packages/orchestrator/pkg/sandbox/template/peerclient/blob_test.go index 350d0d972d..a27babc852 100644 --- a/packages/orchestrator/pkg/sandbox/template/peerclient/blob_test.go +++ b/packages/orchestrator/pkg/sandbox/template/peerclient/blob_test.go @@ -15,8 +15,6 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/grpc/orchestrator" orchestratormocks "github.com/e2b-dev/infra/packages/shared/pkg/grpc/orchestrator/mocks" "github.com/e2b-dev/infra/packages/shared/pkg/storage" - storagemocks "github.com/e2b-dev/infra/packages/shared/pkg/storage/mocks" - providermocks "github.com/e2b-dev/infra/packages/shared/pkg/storage/mocks/provider" ) func TestPeerBlob_WriteTo_PeerSucceeds(t *testing.T) { @@ -55,14 +53,13 @@ func TestPeerBlob_WriteTo_PeerNotAvailable_FallsBackToBase(t 
*testing.T) { client := orchestratormocks.NewMockChunkServiceClient(t) client.EXPECT().GetBuildBlob(mock.Anything, mock.Anything).Return(stream, nil) - baseBlob := storagemocks.NewMockBlob(t) + baseBlob := storage.NewMockBlob(t) baseBlob.EXPECT().WriteTo(mock.Anything, mock.Anything).RunAndReturn(func(_ context.Context, dst io.Writer) (int64, error) { n, err := dst.Write([]byte("from gcs")) return int64(n), err }) - - base := providermocks.NewMockStorageProvider(t) + base := storage.NewMockStorageProvider(t) base.EXPECT().OpenBlob(mock.Anything, "build-1/snapfile", storage.SnapfileObjectType).Return(baseBlob, nil) blob := &peerBlob{peerHandle: peerHandle[storage.Blob]{ @@ -88,14 +85,13 @@ func TestPeerBlob_WriteTo_PeerError_FallsBackToBase(t *testing.T) { client := orchestratormocks.NewMockChunkServiceClient(t) client.EXPECT().GetBuildBlob(mock.Anything, mock.Anything).Return(nil, errors.New("connection refused")) - baseBlob := storagemocks.NewMockBlob(t) + baseBlob := storage.NewMockBlob(t) baseBlob.EXPECT().WriteTo(mock.Anything, mock.Anything).RunAndReturn(func(_ context.Context, dst io.Writer) (int64, error) { n, err := dst.Write([]byte("from gcs")) return int64(n), err }) - - base := providermocks.NewMockStorageProvider(t) + base := storage.NewMockStorageProvider(t) base.EXPECT().OpenBlob(mock.Anything, "build-1/snapfile", storage.SnapfileObjectType).Return(baseBlob, nil) blob := &peerBlob{peerHandle: peerHandle[storage.Blob]{ @@ -134,14 +130,13 @@ func TestPeerBlob_WriteTo_UploadedSetMidStream_CompletesFromPeerThenFallsBack(t client := orchestratormocks.NewMockChunkServiceClient(t) client.EXPECT().GetBuildBlob(mock.Anything, mock.Anything).Return(stream, nil).Once() - baseBlob := storagemocks.NewMockBlob(t) + baseBlob := storage.NewMockBlob(t) baseBlob.EXPECT().WriteTo(mock.Anything, mock.Anything).RunAndReturn(func(_ context.Context, dst io.Writer) (int64, error) { n, err := dst.Write([]byte("from storage")) return int64(n), err }) - - base := 
providermocks.NewMockStorageProvider(t) + base := storage.NewMockStorageProvider(t) base.EXPECT().OpenBlob(mock.Anything, "build-1/snapfile", storage.SnapfileObjectType).Return(baseBlob, nil) blob := &peerBlob{peerHandle: peerHandle[storage.Blob]{ @@ -190,10 +185,9 @@ func TestPeerBlob_Exists_PeerNotAvailable_FallsBackToBase(t *testing.T) { client := orchestratormocks.NewMockChunkServiceClient(t) client.EXPECT().GetBuildFileExists(mock.Anything, mock.Anything).Return(&orchestrator.GetBuildFileExistsResponse{Availability: &orchestrator.PeerAvailability{NotAvailable: true}}, nil) - baseBlob := storagemocks.NewMockBlob(t) + baseBlob := storage.NewMockBlob(t) baseBlob.EXPECT().Exists(mock.Anything).Return(true, nil) - - base := providermocks.NewMockStorageProvider(t) + base := storage.NewMockStorageProvider(t) base.EXPECT().OpenBlob(mock.Anything, "build-1/snapfile", storage.SnapfileObjectType).Return(baseBlob, nil) blob := &peerBlob{peerHandle: peerHandle[storage.Blob]{ @@ -217,10 +211,9 @@ func TestPeerBlob_Exists_UseStorage_FallsBackToBase(t *testing.T) { client := orchestratormocks.NewMockChunkServiceClient(t) client.EXPECT().GetBuildFileExists(mock.Anything, mock.Anything).Return(&orchestrator.GetBuildFileExistsResponse{Availability: &orchestrator.PeerAvailability{UseStorage: true}}, nil) - baseBlob := storagemocks.NewMockBlob(t) + baseBlob := storage.NewMockBlob(t) baseBlob.EXPECT().Exists(mock.Anything).Return(true, nil) - - base := providermocks.NewMockStorageProvider(t) + base := storage.NewMockStorageProvider(t) base.EXPECT().OpenBlob(mock.Anything, "build-1/snapfile", storage.SnapfileObjectType).Return(baseBlob, nil) uploaded := &atomic.Bool{} diff --git a/packages/orchestrator/pkg/sandbox/template/peerclient/seekable.go b/packages/orchestrator/pkg/sandbox/template/peerclient/seekable.go index eacfc5c488..7d09b34a4f 100644 --- a/packages/orchestrator/pkg/sandbox/template/peerclient/seekable.go +++ 
b/packages/orchestrator/pkg/sandbox/template/peerclient/seekable.go @@ -20,6 +20,14 @@ var _ storage.Seekable = (*peerSeekable)(nil) // calls (e.g. ReadAt then OpenRangeReader) do not re-open the underlying GCS object. type peerSeekable struct { peerHandle[storage.Seekable] + + // transitionEmitted ensures we signal PeerTransitionedError at most once + // after the peer flips uploaded=true. The caller (build.File) reacts by + // loading the post-upload header from storage; whether that ends up V4 + // (compressed) or V3 (no upgrade) determines how subsequent reads route. + // Either way, after the first emission we fall through to base so V3 + // builds don't loop forever against PeerTransitionedError. + transitionEmitted atomic.Bool } func (s *peerSeekable) Size(ctx context.Context) (int64, error) { @@ -45,53 +53,7 @@ func (s *peerSeekable) Size(ctx context.Context) (int64, error) { ) } -func (s *peerSeekable) ReadAt(ctx context.Context, buf []byte, off int64) (int, error) { - return withPeerFallback(ctx, &s.peerHandle, "read-at peer-seekable", attrOpReadAt, - func(ctx context.Context) (peerAttempt[int], error) { - streamCtx, cancel := context.WithCancel(ctx) - defer cancel() - - recv, err := openPeerSeekableStream(streamCtx, s.client, &orchestrator.ReadAtBuildSeekableRequest{ - BuildId: s.buildID, - FileName: s.fileName, - Offset: off, - Length: int64(len(buf)), - }, s.uploaded) - if err != nil { - logger.L().Warn(ctx, "failed to read build file from peer", logger.WithBuildID(s.buildID), zap.Int64("off", off), zap.Int("buf_len", len(buf)), zap.Error(err)) - - return peerAttempt[int]{}, nil - } - - n := 0 - - for n < len(buf) { - data, recvErr := recv() - if errors.Is(recvErr, io.EOF) { - break - } - - if recvErr != nil { - return peerAttempt[int]{value: n, bytes: int64(n), hit: true}, - fmt.Errorf("failed to receive chunk from peer: %w", recvErr) - } - - n += copy(buf[n:], data) - } - - if n < len(buf) { - return peerAttempt[int]{value: n, bytes: int64(n), hit: 
true}, io.ErrUnexpectedEOF - } - - return peerAttempt[int]{value: n, bytes: int64(n), hit: true}, nil - }, - func(ctx context.Context, base storage.Seekable) (int, error) { - return base.ReadAt(ctx, buf, off) - }, - ) -} - -func (s *peerSeekable) OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { +func (s *peerSeekable) OpenRangeReader(ctx context.Context, off int64, length int64, frameTable *storage.FrameTable) (io.ReadCloser, error) { return withPeerFallback(ctx, &s.peerHandle, "peer-seekable-open-range-reader", attrOpRangeReader, func(ctx context.Context) (peerAttempt[io.ReadCloser], error) { streamCtx, cancel := context.WithCancel(ctx) @@ -115,16 +77,23 @@ func (s *peerSeekable) OpenRangeReader(ctx context.Context, off, length int64) ( }, nil }, func(ctx context.Context, base storage.Seekable) (io.ReadCloser, error) { - return base.OpenRangeReader(ctx, off, length) + // Signal the caller once to fetch the post-upload header from storage; + // thereafter fall through so V3 builds (no V4 to upgrade to) don't + // loop against PeerTransitionedError. + if s.uploaded != nil && s.uploaded.Load() && s.transitionEmitted.CompareAndSwap(false, true) { + return nil, &storage.PeerTransitionedError{} + } + + return base.OpenRangeReader(ctx, off, length, frameTable) }, ) } -func (s *peerSeekable) StoreFile(ctx context.Context, path string, opts ...storage.PutOption) error { +func (s *peerSeekable) StoreFile(ctx context.Context, path string, opts ...storage.PutOption) (*storage.FrameTable, [32]byte, error) { // Writes always go to the base provider (GCS/S3); the peer is read-only. fallback, err := s.getOrOpenBase(ctx) if err != nil { - return err + return nil, [32]byte{}, err } return fallback.StoreFile(ctx, path, opts...) 
diff --git a/packages/orchestrator/pkg/sandbox/template/peerclient/seekable_test.go b/packages/orchestrator/pkg/sandbox/template/peerclient/seekable_test.go index 2c0c913a07..60ae758604 100644 --- a/packages/orchestrator/pkg/sandbox/template/peerclient/seekable_test.go +++ b/packages/orchestrator/pkg/sandbox/template/peerclient/seekable_test.go @@ -15,8 +15,6 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/grpc/orchestrator" orchestratormocks "github.com/e2b-dev/infra/packages/shared/pkg/grpc/orchestrator/mocks" "github.com/e2b-dev/infra/packages/shared/pkg/storage" - storagemocks "github.com/e2b-dev/infra/packages/shared/pkg/storage/mocks" - providermocks "github.com/e2b-dev/infra/packages/shared/pkg/storage/mocks/provider" ) func TestPeerSeekable_Size_PeerSucceeds(t *testing.T) { @@ -39,10 +37,10 @@ func TestPeerSeekable_Size_PeerNotAvailable_FallsBackToBase(t *testing.T) { client := orchestratormocks.NewMockChunkServiceClient(t) client.EXPECT().GetBuildFileSize(mock.Anything, mock.Anything).Return(&orchestrator.GetBuildFileSizeResponse{Availability: &orchestrator.PeerAvailability{NotAvailable: true}}, nil) - baseSeekable := storagemocks.NewMockSeekable(t) + baseSeekable := storage.NewMockSeekable(t) baseSeekable.EXPECT().Size(mock.Anything).Return(int64(8192), nil) - base := providermocks.NewMockStorageProvider(t) + base := storage.NewMockStorageProvider(t) base.EXPECT().OpenSeekable(mock.Anything, "build-1/memfile", storage.MemfileObjectType).Return(baseSeekable, nil) s := &peerSeekable{peerHandle: peerHandle[storage.Seekable]{ @@ -59,45 +57,41 @@ func TestPeerSeekable_Size_PeerNotAvailable_FallsBackToBase(t *testing.T) { assert.Equal(t, int64(8192), size) } -func TestPeerSeekable_ReadAt_PeerSucceeds(t *testing.T) { +func TestPeerSeekable_OpenRangeReader_PeerSucceeds(t *testing.T) { t.Parallel() - data := []byte("block data") + data := []byte("range data") stream := orchestratormocks.NewMockChunkService_ReadAtBuildSeekableClient(t) - // ReadAt copies 
the first message directly into buf; the inner loop is skipped when buf is full. + // OpenRangeReader reads the first message; peerStreamReader.Read calls Recv once more for EOF. stream.EXPECT().Recv().Return(&orchestrator.ReadAtBuildSeekableResponse{Data: data}, nil).Once() + stream.EXPECT().Recv().Return(nil, io.EOF).Once() client := orchestratormocks.NewMockChunkServiceClient(t) client.EXPECT().ReadAtBuildSeekable(mock.Anything, mock.MatchedBy(func(req *orchestrator.ReadAtBuildSeekableRequest) bool { - return req.GetOffset() == 0 && req.GetLength() == int64(len(data)) + return req.GetOffset() == 10 && req.GetLength() == int64(len(data)) })).Return(stream, nil) s := &peerSeekable{peerHandle: peerHandle[storage.Seekable]{client: client, buildID: "build-1", fileName: storage.MemfileName, uploaded: &atomic.Bool{}}} - buf := make([]byte, len(data)) - n, err := s.ReadAt(t.Context(), buf, 0) + rc, err := s.OpenRangeReader(t.Context(), 10, int64(len(data)), nil) + require.NoError(t, err) + defer rc.Close() + + got, err := io.ReadAll(rc) require.NoError(t, err) - assert.Equal(t, len(data), n) - assert.Equal(t, data, buf) + assert.Equal(t, data, got) } -func TestPeerSeekable_ReadAt_PeerNotAvailable_FallsBackToBase(t *testing.T) { +func TestPeerSeekable_OpenRangeReader_PeerError_FallsBackToBase(t *testing.T) { t.Parallel() - baseData := []byte("base data") - stream := orchestratormocks.NewMockChunkService_ReadAtBuildSeekableClient(t) - stream.EXPECT().Recv().Return(&orchestrator.ReadAtBuildSeekableResponse{Availability: &orchestrator.PeerAvailability{NotAvailable: true}}, nil).Once() - + baseData := []byte("base range") client := orchestratormocks.NewMockChunkServiceClient(t) - client.EXPECT().ReadAtBuildSeekable(mock.Anything, mock.Anything).Return(stream, nil) - - baseSeekable := storagemocks.NewMockSeekable(t) - baseSeekable.EXPECT().ReadAt(mock.Anything, mock.Anything, int64(0)).RunAndReturn(func(_ context.Context, buf []byte, _ int64) (int, error) { - n := copy(buf, 
baseData) + client.EXPECT().ReadAtBuildSeekable(mock.Anything, mock.Anything).Return(nil, errors.New("peer unavailable")) - return n, nil - }) + baseSeekable := storage.NewMockSeekable(t) + baseSeekable.EXPECT().OpenRangeReader(mock.Anything, int64(0), int64(len(baseData)), (*storage.FrameTable)(nil)).Return(io.NopCloser(bytes.NewReader(baseData)), nil) - base := providermocks.NewMockStorageProvider(t) + base := storage.NewMockStorageProvider(t) base.EXPECT().OpenSeekable(mock.Anything, "build-1/memfile", storage.MemfileObjectType).Return(baseSeekable, nil) s := &peerSeekable{peerHandle: peerHandle[storage.Seekable]{ @@ -109,64 +103,40 @@ func TestPeerSeekable_ReadAt_PeerNotAvailable_FallsBackToBase(t *testing.T) { return base.OpenSeekable(ctx, "build-1/memfile", storage.MemfileObjectType) }, }} - buf := make([]byte, len(baseData)) - n, err := s.ReadAt(t.Context(), buf, 0) - require.NoError(t, err) - assert.Equal(t, len(baseData), n) - assert.Equal(t, baseData, buf) -} - -func TestPeerSeekable_OpenRangeReader_PeerSucceeds(t *testing.T) { - t.Parallel() - - data := []byte("range data") - stream := orchestratormocks.NewMockChunkService_ReadAtBuildSeekableClient(t) - // OpenRangeReader reads the first message; peerStreamReader.Read calls Recv once more for EOF. 
- stream.EXPECT().Recv().Return(&orchestrator.ReadAtBuildSeekableResponse{Data: data}, nil).Once() - stream.EXPECT().Recv().Return(nil, io.EOF).Once() - - client := orchestratormocks.NewMockChunkServiceClient(t) - client.EXPECT().ReadAtBuildSeekable(mock.Anything, mock.MatchedBy(func(req *orchestrator.ReadAtBuildSeekableRequest) bool { - return req.GetOffset() == 10 && req.GetLength() == int64(len(data)) - })).Return(stream, nil) - - s := &peerSeekable{peerHandle: peerHandle[storage.Seekable]{client: client, buildID: "build-1", fileName: storage.MemfileName, uploaded: &atomic.Bool{}}} - rc, err := s.OpenRangeReader(t.Context(), 10, int64(len(data))) + rc, err := s.OpenRangeReader(t.Context(), 0, int64(len(baseData)), nil) require.NoError(t, err) defer rc.Close() got, err := io.ReadAll(rc) require.NoError(t, err) - assert.Equal(t, data, got) + assert.Equal(t, baseData, got) } -func TestPeerSeekable_OpenRangeReader_PeerError_FallsBackToBase(t *testing.T) { +func TestPeerSeekable_OpenRangeReader_Uploaded_ReturnsPeerTransitionedError(t *testing.T) { t.Parallel() - baseData := []byte("base range") client := orchestratormocks.NewMockChunkServiceClient(t) - client.EXPECT().ReadAtBuildSeekable(mock.Anything, mock.Anything).Return(nil, errors.New("peer unavailable")) - baseSeekable := storagemocks.NewMockSeekable(t) - baseSeekable.EXPECT().OpenRangeReader(mock.Anything, int64(0), int64(len(baseData))).Return(io.NopCloser(bytes.NewReader(baseData)), nil) + uploaded := &atomic.Bool{} + uploaded.Store(true) - base := providermocks.NewMockStorageProvider(t) + baseSeekable := storage.NewMockSeekable(t) + base := storage.NewMockStorageProvider(t) base.EXPECT().OpenSeekable(mock.Anything, "build-1/memfile", storage.MemfileObjectType).Return(baseSeekable, nil) s := &peerSeekable{peerHandle: peerHandle[storage.Seekable]{ client: client, buildID: "build-1", fileName: storage.MemfileName, - uploaded: &atomic.Bool{}, + uploaded: uploaded, openFn: func(ctx context.Context) 
(storage.Seekable, error) { return base.OpenSeekable(ctx, "build-1/memfile", storage.MemfileObjectType) }, }} - rc, err := s.OpenRangeReader(t.Context(), 0, int64(len(baseData))) - require.NoError(t, err) - defer rc.Close() - got, err := io.ReadAll(rc) - require.NoError(t, err) - assert.Equal(t, baseData, got) + _, err := s.OpenRangeReader(t.Context(), 0, 100, nil) + require.Error(t, err) + + var transErr *storage.PeerTransitionedError + require.ErrorAs(t, err, &transErr) } diff --git a/packages/orchestrator/pkg/sandbox/template/peerclient/storage.go b/packages/orchestrator/pkg/sandbox/template/peerclient/storage.go index f683b4f2f6..ec9fc46945 100644 --- a/packages/orchestrator/pkg/sandbox/template/peerclient/storage.go +++ b/packages/orchestrator/pkg/sandbox/template/peerclient/storage.go @@ -34,7 +34,6 @@ var ( attrOpWriteTo = attribute.String("operation", "WriteTo") attrOpExists = attribute.String("operation", "Exists") attrOpSize = attribute.String("operation", "Size") - attrOpReadAt = attribute.String("operation", "ReadAt") attrOpRangeReader = attribute.String("operation", "OpenRangeReader") attrResolveRedisError = attribute.String("peer_resolve", "redis_error") diff --git a/packages/orchestrator/pkg/sandbox/template/peerclient/storage_test.go b/packages/orchestrator/pkg/sandbox/template/peerclient/storage_test.go index ca8ea8106d..7e9e895c3f 100644 --- a/packages/orchestrator/pkg/sandbox/template/peerclient/storage_test.go +++ b/packages/orchestrator/pkg/sandbox/template/peerclient/storage_test.go @@ -13,7 +13,6 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/grpc/orchestrator" orchestratormocks "github.com/e2b-dev/infra/packages/shared/pkg/grpc/orchestrator/mocks" "github.com/e2b-dev/infra/packages/shared/pkg/storage" - providermocks "github.com/e2b-dev/infra/packages/shared/pkg/storage/mocks/provider" ) func TestPeerStorageProvider_OpenBlob_ExtractsFileName(t *testing.T) { @@ -28,7 +27,7 @@ func TestPeerStorageProvider_OpenBlob_ExtractsFileName(t 
*testing.T) { return req.GetBuildId() == "build-1" && req.GetFileName() == "snapfile" })).Return(stream, nil) - base := providermocks.NewMockStorageProvider(t) + base := storage.NewMockStorageProvider(t) p := newPeerStorageProvider(base, client, &atomic.Bool{}) blob, err := p.OpenBlob(t.Context(), "build-1/snapfile", storage.SnapfileObjectType) @@ -48,13 +47,13 @@ func TestPeerStorageProvider_OpenSeekable_ExtractsFileName(t *testing.T) { return req.GetBuildId() == "build-1" && req.GetFileName() == "memfile" })).Return(&orchestrator.GetBuildFileSizeResponse{TotalSize: 512}, nil) - base := providermocks.NewMockStorageProvider(t) + base := storage.NewMockStorageProvider(t) p := newPeerStorageProvider(base, client, &atomic.Bool{}) - seekable, err := p.OpenSeekable(t.Context(), "build-1/memfile", storage.MemfileObjectType) + ff, err := p.OpenSeekable(t.Context(), "build-1/memfile", storage.MemfileObjectType) require.NoError(t, err) - size, err := seekable.Size(t.Context()) + size, err := ff.Size(t.Context()) require.NoError(t, err) assert.Equal(t, int64(512), size) } diff --git a/packages/orchestrator/pkg/sandbox/template/peerserver/header.go b/packages/orchestrator/pkg/sandbox/template/peerserver/header.go index 9691e217c0..835553d18f 100644 --- a/packages/orchestrator/pkg/sandbox/template/peerserver/header.go +++ b/packages/orchestrator/pkg/sandbox/template/peerserver/header.go @@ -35,7 +35,18 @@ func (f *headerSource) Stream(ctx context.Context, sender Sender) error { return ErrNotAvailable } - data, err := header.Serialize(h.Metadata, h.Mapping) + // V4 headers served via P2P are always for in-flight builds — peers stop + // being routed once the upload finalizes (peerStorageProvider switches to + // base/GCS via the uploaded flag). Force the wire bit on regardless of + // the in-memory state so consumers reliably treat these bytes as a + // pending diff and refresh from GCS once the upload lands. 
V3 has no + // in-flight notion on the wire, so it ships as-is and is treated as final. + wire := *h + if wire.Metadata.Version >= header.MetadataVersionV4 { + wire.IncompletePendingUpload = true + } + + data, err := header.SerializeHeader(&wire) if err != nil { span.RecordError(err) diff --git a/packages/orchestrator/pkg/sandbox/template/peerserver/resolve.go b/packages/orchestrator/pkg/sandbox/template/peerserver/resolve.go index 448526887b..72381a1e43 100644 --- a/packages/orchestrator/pkg/sandbox/template/peerserver/resolve.go +++ b/packages/orchestrator/pkg/sandbox/template/peerserver/resolve.go @@ -18,9 +18,10 @@ var ErrUnknownFile = errors.New("unknown file") // Returns ErrNotAvailable when the build is not in the local cache. // Returns ErrUnknownFile for unrecognised file names. func ResolveSeekable(cache Cache, buildID, fileName string) (SeekableSource, error) { - switch fileName { + stripped := storage.StripCompression(fileName) + switch stripped { case storage.MemfileName, storage.RootfsName: - diff, ok := cache.LookupDiff(buildID, build.DiffType(fileName)) + diff, ok := cache.LookupDiff(buildID, build.DiffType(stripped)) if !ok { return nil, ErrNotAvailable } diff --git a/packages/orchestrator/pkg/sandbox/template/peerserver/seekable.go b/packages/orchestrator/pkg/sandbox/template/peerserver/seekable.go index 319b9d3c99..a40e83bb30 100644 --- a/packages/orchestrator/pkg/sandbox/template/peerserver/seekable.go +++ b/packages/orchestrator/pkg/sandbox/template/peerserver/seekable.go @@ -33,7 +33,8 @@ func (f *seekableSource) Stream(ctx context.Context, offset, length int64, sende )) defer span.End() - data, err := f.diff.Slice(ctx, offset, length) + // P2P always serves uncompressed bytes — pass nil FrameTable. 
+ data, err := f.diff.Slice(ctx, offset, length, nil) if err != nil { span.RecordError(err) diff --git a/packages/orchestrator/pkg/sandbox/template/peerserver/seekable_test.go b/packages/orchestrator/pkg/sandbox/template/peerserver/seekable_test.go index 66724591bd..c883ea2cd4 100644 --- a/packages/orchestrator/pkg/sandbox/template/peerserver/seekable_test.go +++ b/packages/orchestrator/pkg/sandbox/template/peerserver/seekable_test.go @@ -36,7 +36,7 @@ func TestSeekableSource_Stream(t *testing.T) { data := []byte("diff bytes") diff := buildmocks.NewMockDiff(t) - diff.EXPECT().Slice(mock.Anything, int64(0), int64(len(data))).Return(data, nil) + diff.EXPECT().Slice(mock.Anything, int64(0), int64(len(data)), (*storage.FrameTable)(nil)).Return(data, nil) diff.EXPECT().BlockSize().Return(int64(len(data))) cache := peerservermocks.NewMockCache(t) diff --git a/packages/orchestrator/pkg/sandbox/template/storage.go b/packages/orchestrator/pkg/sandbox/template/storage.go index 605f00686d..3e0308e178 100644 --- a/packages/orchestrator/pkg/sandbox/template/storage.go +++ b/packages/orchestrator/pkg/sandbox/template/storage.go @@ -19,21 +19,9 @@ const ( ) type Storage struct { - header *header.Header source *build.File } -func storageHeaderObjectType(diffType build.DiffType) (storage.ObjectType, bool) { - switch diffType { - case build.Memfile: - return storage.MemfileHeaderObjectType, true - case build.Rootfs: - return storage.RootFSHeaderObjectType, true - default: - return storage.UnknownObjectType, false - } -} - func objectType(diffType build.DiffType) (storage.SeekableObjectType, bool) { switch diffType { case build.Memfile: @@ -57,11 +45,6 @@ func NewStorage( paths := storage.Paths{BuildID: buildId} if h == nil { - headerObjectType, ok := storageHeaderObjectType(fileType) - if !ok { - return nil, build.UnknownDiffTypeError{DiffType: fileType} - } - var hdrPath string switch fileType { case build.Memfile: @@ -72,20 +55,10 @@ func NewStorage( return nil, 
build.UnknownDiffTypeError{DiffType: fileType} } - headerObject, err := persistence.OpenBlob(ctx, hdrPath, headerObjectType) - if err != nil { - return nil, err - } - - diffHeader, err := header.Deserialize(ctx, headerObject) - - // If we can't find the diff header in storage, we switch to templates without a headers + var err error + h, err = header.LoadHeader(ctx, persistence, hdrPath) if err != nil && !errors.Is(err, storage.ErrObjectNotExist) { - return nil, fmt.Errorf("failed to deserialize header: %w", err) - } - - if err == nil { - h = diffHeader + return nil, err } } @@ -151,7 +124,6 @@ func NewStorage( return &Storage{ source: b, - header: h, }, nil } @@ -160,11 +132,11 @@ func (d *Storage) ReadAt(ctx context.Context, p []byte, off int64) (int, error) } func (d *Storage) Size(_ context.Context) (int64, error) { - return int64(d.header.Metadata.Size), nil + return int64(d.source.Header().Metadata.Size), nil } func (d *Storage) BlockSize() int64 { - return int64(d.header.Metadata.BlockSize) + return int64(d.source.Header().Metadata.BlockSize) } func (d *Storage) Slice(ctx context.Context, off, length int64) ([]byte, error) { @@ -172,7 +144,11 @@ func (d *Storage) Slice(ctx context.Context, off, length int64) ([]byte, error) } func (d *Storage) Header() *header.Header { - return d.header + return d.source.Header() +} + +func (d *Storage) SwapHeader(h *header.Header) { + d.source.SwapHeader(h) } func (d *Storage) Close() error { diff --git a/packages/orchestrator/pkg/sandbox/template_build.go b/packages/orchestrator/pkg/sandbox/template_build.go deleted file mode 100644 index efca1f2f0c..0000000000 --- a/packages/orchestrator/pkg/sandbox/template_build.go +++ /dev/null @@ -1,216 +0,0 @@ -package sandbox - -import ( - "context" - "fmt" - "io" - "os" - - "golang.org/x/sync/errgroup" - - "github.com/e2b-dev/infra/packages/shared/pkg/storage" - headers "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" -) - -type TemplateBuild struct { - paths 
storage.Paths - persistence storage.StorageProvider - objectMetadata storage.ObjectMetadata - - memfileHeader *headers.Header - rootfsHeader *headers.Header -} - -func NewTemplateBuild( - memfileHeader *headers.Header, - rootfsHeader *headers.Header, - persistence storage.StorageProvider, - paths storage.Paths, - objectMetadata storage.ObjectMetadata, -) *TemplateBuild { - return &TemplateBuild{ - persistence: persistence, - paths: paths, - objectMetadata: objectMetadata, - - memfileHeader: memfileHeader, - rootfsHeader: rootfsHeader, - } -} - -func (t *TemplateBuild) Remove(ctx context.Context) error { - err := t.persistence.DeleteObjectsWithPrefix(ctx, t.paths.StorageDir()) - if err != nil { - return fmt.Errorf("error when removing template build '%s': %w", t.paths.StorageDir(), err) - } - - return nil -} - -func (t *TemplateBuild) putOpts() []storage.PutOption { - if len(t.objectMetadata) == 0 { - return nil - } - - return []storage.PutOption{storage.WithMetadata(t.objectMetadata)} -} - -func (t *TemplateBuild) uploadMemfileHeader(ctx context.Context, h *headers.Header) error { - object, err := t.persistence.OpenBlob(ctx, t.paths.MemfileHeader(), storage.MemfileHeaderObjectType) - if err != nil { - return err - } - - serialized, err := headers.Serialize(h.Metadata, h.Mapping) - if err != nil { - return fmt.Errorf("error when serializing memfile header: %w", err) - } - - err = object.Put(ctx, serialized, t.putOpts()...) - if err != nil { - return fmt.Errorf("error when uploading memfile header: %w", err) - } - - return nil -} - -func (t *TemplateBuild) uploadMemfile(ctx context.Context, memfilePath string) error { - object, err := t.persistence.OpenSeekable(ctx, t.paths.Memfile(), storage.MemfileObjectType) - if err != nil { - return err - } - - err = object.StoreFile(ctx, memfilePath, t.putOpts()...) 
- if err != nil { - return fmt.Errorf("error when uploading memfile: %w", err) - } - - return nil -} - -func (t *TemplateBuild) uploadRootfsHeader(ctx context.Context, h *headers.Header) error { - object, err := t.persistence.OpenBlob(ctx, t.paths.RootfsHeader(), storage.RootFSHeaderObjectType) - if err != nil { - return err - } - - serialized, err := headers.Serialize(h.Metadata, h.Mapping) - if err != nil { - return fmt.Errorf("error when serializing memfile header: %w", err) - } - - err = object.Put(ctx, serialized, t.putOpts()...) - if err != nil { - return fmt.Errorf("error when uploading memfile header: %w", err) - } - - return nil -} - -func (t *TemplateBuild) uploadRootfs(ctx context.Context, rootfsPath string) error { - object, err := t.persistence.OpenSeekable(ctx, t.paths.Rootfs(), storage.RootFSObjectType) - if err != nil { - return err - } - - err = object.StoreFile(ctx, rootfsPath, t.putOpts()...) - if err != nil { - return fmt.Errorf("error when uploading rootfs: %w", err) - } - - return nil -} - -// Snap-file is small enough so we don't use composite upload. -func (t *TemplateBuild) uploadSnapfile(ctx context.Context, path string) error { - object, err := t.persistence.OpenBlob(ctx, t.paths.Snapfile(), storage.SnapfileObjectType) - if err != nil { - return err - } - - if err = uploadFileAsBlob(ctx, object, path, t.putOpts()...); err != nil { - return fmt.Errorf("error when uploading snapfile: %w", err) - } - - return nil -} - -// Metadata is small enough so we don't use composite upload. 
-func (t *TemplateBuild) uploadMetadata(ctx context.Context, path string) error { - object, err := t.persistence.OpenBlob(ctx, t.paths.Metadata(), storage.MetadataObjectType) - if err != nil { - return err - } - - if err := uploadFileAsBlob(ctx, object, path, t.putOpts()...); err != nil { - return fmt.Errorf("error when uploading metadata: %w", err) - } - - return nil -} - -func uploadFileAsBlob(ctx context.Context, b storage.Blob, path string, opts ...storage.PutOption) error { - f, err := os.Open(path) - if err != nil { - return fmt.Errorf("failed to open file %s: %w", path, err) - } - defer f.Close() - - data, err := io.ReadAll(f) - if err != nil { - return fmt.Errorf("failed to read file %s: %w", path, err) - } - - err = b.Put(ctx, data, opts...) - if err != nil { - return fmt.Errorf("failed to write data to object: %w", err) - } - - return nil -} - -func (t *TemplateBuild) Upload(ctx context.Context, metadataPath string, fcSnapfilePath string, memfilePath *string, rootfsPath *string) error { - eg, ctx := errgroup.WithContext(ctx) - - eg.Go(func() error { - if t.memfileHeader == nil { - return nil - } - - return t.uploadMemfileHeader(ctx, t.memfileHeader) - }) - - eg.Go(func() error { - if t.rootfsHeader == nil { - return nil - } - - return t.uploadRootfsHeader(ctx, t.rootfsHeader) - }) - - eg.Go(func() error { - if rootfsPath == nil { - return nil - } - - return t.uploadRootfs(ctx, *rootfsPath) - }) - - eg.Go(func() error { - if memfilePath == nil { - return nil - } - - return t.uploadMemfile(ctx, *memfilePath) - }) - - eg.Go(func() error { - return t.uploadSnapfile(ctx, fcSnapfilePath) - }) - - eg.Go(func() error { - return t.uploadMetadata(ctx, metadataPath) - }) - - return eg.Wait() -} diff --git a/packages/orchestrator/pkg/sandbox/uploads.go b/packages/orchestrator/pkg/sandbox/uploads.go new file mode 100644 index 0000000000..57e8a869f7 --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/uploads.go @@ -0,0 +1,219 @@ +package sandbox + +import ( + 
"context" + "errors" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/jellydator/ttlcache/v3" + "github.com/redis/go-redis/v9" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" + + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block" + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/build" + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/template" + "github.com/e2b-dev/infra/packages/shared/pkg/logger" + "github.com/e2b-dev/infra/packages/shared/pkg/storage" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" + "github.com/e2b-dev/infra/packages/shared/pkg/telemetry" + "github.com/e2b-dev/infra/packages/shared/pkg/utils" +) + +var ( + errUploadInFlight = errors.New("upload already in flight for build") + ErrBuildNotInCache = errors.New("build not in template cache") +) + +const ( + futureTTL = 1 * time.Hour + + // refreshHeaderBudget bounds how long an upload Wait polls remote storage + // for a parent's V4 header. Crosses orchestrators: A may still be uploading + // on a remote orch when B's runV4 calls Wait(A) here. Matches the + // per-upload bound in server.uploadTimeout — anything longer means the + // parent's upload is itself stuck and would have failed on its own. + refreshHeaderBudget = 20 * time.Minute + + // uploadDoneChannelPrefix is the Redis pub/sub channel prefix for per-build + // upload-finished signals. Empty payload = success; non-empty = upload error. + uploadDoneChannelPrefix = "orchestrator.upload.done." // followed by buildID String +) + +type templateLookup interface { + GetCachedTemplate(buildID string) (template.Template, bool) +} + +// Uploads is the in-flight upload table. Each entry's future fires when its +// build's V4 header has been swapped, gating child layers that depend on it. 
+// +// Cross-orch coordination uses Redis pub/sub on per-build channels: the +// uploader publishes on Finish, consumers subscribe inside Wait while polling +// remote storage. The Redis client is optional — nil falls back to ticker-only +// polling. +type Uploads struct { + tc templateLookup + persistence storage.StorageProvider + redis redis.UniversalClient + + futures *ttlcache.Cache[uuid.UUID, *utils.ErrorOnce] +} + +func NewUploads(tc *template.Cache, persistence storage.StorageProvider, redisClient redis.UniversalClient) *Uploads { + futures := ttlcache.New( + ttlcache.WithTTL[uuid.UUID, *utils.ErrorOnce](futureTTL), + ) + go futures.Start() + + return &Uploads{tc: tc, persistence: persistence, redis: redisClient, futures: futures} +} + +func (u *Uploads) Stop() { + u.futures.Stop() +} + +// Start replaces a finished future at the same key; rejects an in-flight one. +// Build IDs are unique per upload so concurrent Starts for the same key are +// not expected — the in-flight check only guards against accidental misuse. +func (u *Uploads) Start(buildID uuid.UUID) (*utils.ErrorOnce, error) { + if existing := u.futures.Get(buildID); existing != nil { + select { + case <-existing.Value().Done(): + default: + return nil, fmt.Errorf("%w: %s", errUploadInFlight, buildID) + } + } + + fut := utils.NewErrorOnce() + u.futures.Set(buildID, fut, ttlcache.DefaultTTL) + + return fut, nil +} + +// Wait returns the parent's post-upload V4 header. Same-orch waits on the local +// future; cross-orch refreshes from remote storage when the locally-cached +// header is stale, optionally accelerated by a per-call Redis subscription. 
+func (u *Uploads) Wait(ctx context.Context, buildID uuid.UUID, t build.DiffType) (*header.Header, error) { + ctx, span := tracer.Start(ctx, "wait-for-parent-upload", trace.WithAttributes( + telemetry.WithBuildID(buildID.String()), + attribute.String("file_type", string(t)), + )) + defer span.End() + + if item := u.futures.Get(buildID); item != nil { + if err := item.Value().WaitWithContext(ctx); err != nil { + return nil, fmt.Errorf("wait for upload %s: %w", buildID, err) + } + } + + d, err := u.find(ctx, buildID, t) + if errors.Is(err, ErrBuildNotInCache) { + // Ancestor never resumed locally (typical for grand-grandparents + // reached via mappings). It's necessarily finalized — load directly + // from remote storage without an in-memory device or future to track. + hdrPath := storage.Paths{BuildID: buildID.String()}.HeaderFile(string(t)) + + return header.LoadHeader(ctx, u.persistence, hdrPath) + } + if err != nil { + return nil, err + } + + h := d.Header() + if h.IncompletePendingUpload { + // The only way we can still have an incomplete header at this point is + // the P2P path. We already waited on the local upload future and it did + // not finalize the header. 
+ ctx, cancel := context.WithCancel(ctx) + defer cancel() + + h, err = build.PollRemoteStorageForHeader(ctx, u.persistence, buildID, t, u.subscribe(ctx, buildID), refreshHeaderBudget) + if err != nil { + return nil, err + } + d.SwapHeader(h) + } + + return h, nil +} + +func (u *Uploads) find(ctx context.Context, buildID uuid.UUID, t build.DiffType) (block.ReadonlyDevice, error) { + tpl, ok := u.tc.GetCachedTemplate(buildID.String()) + if !ok { + return nil, fmt.Errorf("build %s: %w", buildID, ErrBuildNotInCache) + } + + switch t { + case build.Memfile: + return tpl.Memfile(ctx) + case build.Rootfs: + return tpl.Rootfs() + default: + return nil, fmt.Errorf("unsupported file type: %s", t) + } +} + +// --- Cross-orch upload-done signaling (Redis pub/sub on per-build channels) --- + +func uploadDoneChannel(buildID uuid.UUID) string { + return uploadDoneChannelPrefix + buildID.String() +} + +// publishUploadDoneToRedis broadcasts an upload-finished signal so cross-orch waiters can stop +// polling. Best-effort; failures fall through to the ticker poll. Empty +// payload = success; non-empty = the upload error message. +func (u *Uploads) publishUploadDoneToRedis(ctx context.Context, buildID uuid.UUID, uploadErr error) { + if u.redis == nil { + return + } + + payload := "" + if uploadErr != nil { + payload = uploadErr.Error() + } + + if err := u.redis.Publish(ctx, uploadDoneChannel(buildID), payload).Err(); err != nil { + logger.L().Warn(ctx, "failed to publish upload-done signal", + logger.WithBuildID(buildID.String()), + zap.Error(err), + ) + } +} + +// subscribe opens a per-call SUBSCRIBE on buildID's upload-done channel and +// returns a channel that fires once with the upload outcome. The subscription +// is torn down when ctx cancels (caller must use a derived context). Returns +// a nil channel when Redis is not configured — nil channels never fire, so +// LoadV4 cleanly degrades to ticker-only polling. 
+func (u *Uploads) subscribe(ctx context.Context, buildID uuid.UUID) <-chan error { + if u.redis == nil { + return nil + } + + out := make(chan error, 1) + + go func() { + ps := u.redis.Subscribe(ctx, uploadDoneChannel(buildID)) + defer ps.Close() + + msg, err := ps.ReceiveMessage(ctx) + if err != nil { + return // ctx cancelled or connection error: silent (ticker covers) + } + + var uploadErr error + if msg.Payload != "" { + uploadErr = errors.New(msg.Payload) + } + + select { + case out <- uploadErr: + case <-ctx.Done(): + } + }() + + return out +} diff --git a/packages/orchestrator/pkg/sandbox/uploads_test.go b/packages/orchestrator/pkg/sandbox/uploads_test.go new file mode 100644 index 0000000000..d502364b33 --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/uploads_test.go @@ -0,0 +1,231 @@ +package sandbox + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/uuid" + "github.com/jellydator/ttlcache/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + blockmocks "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block/mocks" + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/build" + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/template" + templatemocks "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/template/mocks" + headers "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" + "github.com/e2b-dev/infra/packages/shared/pkg/utils" +) + +type fakeCache struct { + mu sync.Mutex + m map[string]template.Template +} + +func newFakeCache() *fakeCache { + return &fakeCache{m: make(map[string]template.Template)} +} + +func (f *fakeCache) GetCachedTemplate(buildID string) (template.Template, bool) { + f.mu.Lock() + defer f.mu.Unlock() + t, ok := f.m[buildID] + + return t, ok +} + +func (f *fakeCache) put(buildID string, tpl template.Template) { + f.mu.Lock() + defer f.mu.Unlock() + 
f.m[buildID] = tpl +} + +func newUploads(t *testing.T) (*Uploads, *fakeCache) { + t.Helper() + cache := newFakeCache() + futures := ttlcache.New( + ttlcache.WithTTL[uuid.UUID, *utils.ErrorOnce](futureTTL), + ) + go futures.Start() + t.Cleanup(futures.Stop) + + return &Uploads{ + tc: cache, + futures: futures, + }, cache +} + +func putFinalHeader(t *testing.T, cache *fakeCache, buildID uuid.UUID, fileType build.DiffType) { + t.Helper() + tpl := templatemocks.NewMockTemplate(t) + dev := blockmocks.NewMockReadonlyDevice(t) + dev.EXPECT().Header().Return(&headers.Header{ + Metadata: &headers.Metadata{Version: headers.MetadataVersionV4}, + Builds: map[uuid.UUID]headers.BuildData{buildID: {}}, // self-entry → not stale + }).Maybe() + + switch fileType { + case build.Memfile: + tpl.EXPECT().Memfile(mock.Anything).Return(dev, nil).Maybe() + case build.Rootfs: + tpl.EXPECT().Rootfs().Return(dev, nil).Maybe() + } + + cache.put(buildID.String(), tpl) +} + +func TestUploads_BeginDistinctIDsAreIndependent(t *testing.T) { + t.Parallel() + c, _ := newUploads(t) + + a := uuid.New() + b := uuid.New() + + futA, err := c.Start(a) + require.NoError(t, err) + futB, err := c.Start(b) + require.NoError(t, err) + + require.NotSame(t, futA, futB) + require.NoError(t, futA.SetSuccess()) + + select { + case <-futB.Done(): + t.Fatal("futB should not be done after only futA fires") + default: + } +} + +func TestUploads_Wait_BlocksUntilSet(t *testing.T) { + t.Parallel() + c, cache := newUploads(t) + + id := uuid.New() + putFinalHeader(t, cache, id, build.Memfile) + fut, err := c.Start(id) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + _, _ = c.Wait(context.Background(), id, build.Memfile) + close(done) + }() + + select { + case <-done: + t.Fatal("Wait should block until the future fires") + case <-time.After(50 * time.Millisecond): + } + + require.NoError(t, fut.SetSuccess()) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("Wait should return 
after future fires") + } +} + +func TestUploads_Wait_PropagatesUploadError(t *testing.T) { + t.Parallel() + c, cache := newUploads(t) + + id := uuid.New() + putFinalHeader(t, cache, id, build.Memfile) + fut, err := c.Start(id) + require.NoError(t, err) + + uploadErr := errors.New("upload exploded") + require.NoError(t, fut.SetError(uploadErr)) + + _, err = c.Wait(context.Background(), id, build.Memfile) + require.ErrorIs(t, err, uploadErr) +} + +func TestUploads_Wait_ContextCancellation(t *testing.T) { + t.Parallel() + c, _ := newUploads(t) + + id := uuid.New() + _, err := c.Start(id) // never signaled + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + + errCh := make(chan error, 1) + go func() { + _, err := c.Wait(ctx, id, build.Memfile) + errCh <- err + }() + + cancel() + + select { + case err := <-errCh: + require.ErrorIs(t, err, context.Canceled) + case <-time.After(time.Second): + t.Fatal("Wait should return on context cancel") + } +} + +func TestUploads_Wait_NoFuture_ReadsFromCache(t *testing.T) { + t.Parallel() + c, cache := newUploads(t) + + id := uuid.New() + want := &headers.Header{ + Metadata: &headers.Metadata{Version: headers.MetadataVersionV4}, + Builds: map[uuid.UUID]headers.BuildData{id: {}}, + } + + tpl := templatemocks.NewMockTemplate(t) + dev := blockmocks.NewMockReadonlyDevice(t) + dev.EXPECT().Header().Return(want) + tpl.EXPECT().Rootfs().Return(dev, nil) + cache.put(id.String(), tpl) + + got, err := c.Wait(context.Background(), id, build.Rootfs) + require.NoError(t, err) + require.Same(t, want, got) +} + +func TestUploads_ConcurrentBeginsAndWaits(t *testing.T) { + t.Parallel() + c, cache := newUploads(t) + + const n = 10 + + ids := make([]uuid.UUID, n) + futs := make([]*utils.ErrorOnce, n) + for i := range n { + ids[i] = uuid.New() + putFinalHeader(t, cache, ids[i], build.Memfile) + fut, err := c.Start(ids[i]) + require.NoError(t, err) + futs[i] = fut + } + + var done atomic.Int32 + var wg sync.WaitGroup + 
for i := range n { + wg.Add(1) + go func(i int) { + defer wg.Done() + if _, err := c.Wait(context.Background(), ids[i], build.Memfile); err == nil { + done.Add(1) + } + }(i) + } + + for i := range n { + require.NoError(t, futs[i].SetSuccess()) + } + + wg.Wait() + assert.Equal(t, int32(n), done.Load()) +} diff --git a/packages/orchestrator/pkg/server/main.go b/packages/orchestrator/pkg/server/main.go index 68277c3660..f483ce051c 100644 --- a/packages/orchestrator/pkg/server/main.go +++ b/packages/orchestrator/pkg/server/main.go @@ -45,6 +45,7 @@ type Server struct { startingSandboxes *semaphore.Weighted peerRegistry peerclient.Registry uploadedBuilds *ttlcache.Cache[string, struct{}] + uploads *sandbox.Uploads sandboxCreateDuration metric.Int64Histogram } @@ -61,6 +62,7 @@ type ServiceConfig struct { FeatureFlags *featureflags.Client SbxEventsService *events.EventsService PeerRegistry peerclient.Registry + Uploads *sandbox.Uploads } func New(cfg ServiceConfig) (*Server, error) { @@ -83,6 +85,7 @@ func New(cfg ServiceConfig) (*Server, error) { startingSandboxes: semaphore.NewWeighted(maxStartingInstancesPerNode), peerRegistry: cfg.PeerRegistry, uploadedBuilds: uploadedBuilds, + uploads: cfg.Uploads, } meter := cfg.Tel.MeterProvider.Meter("github.com/e2b-dev/infra/packages/orchestrator/pkg/server") @@ -104,3 +107,9 @@ func New(cfg ServiceConfig) (*Server, error) { return server, nil } + +func (s *Server) Close() error { + s.uploadedBuilds.Stop() + + return nil +} diff --git a/packages/orchestrator/pkg/server/sandboxes.go b/packages/orchestrator/pkg/server/sandboxes.go index 46da167d48..725008676f 100644 --- a/packages/orchestrator/pkg/server/sandboxes.go +++ b/packages/orchestrator/pkg/server/sandboxes.go @@ -45,7 +45,8 @@ const ( acquireTimeout = 15 * time.Second maxStartingInstancesPerNode = 3 - // uploadTimeout is the max time allowed for uploading snapshot files to GCS. + // uploadTimeout is the max time allowed for uploading snapshot files to + // remote storage. 
uploadTimeout = 20 * time.Minute // redisPeerKeyTTL is slightly longer than uploadTimeout so the key is still // valid for the entire upload window before being cleaned up. @@ -665,9 +666,11 @@ func (s *Server) Checkpoint(ctx context.Context, in *orchestrator.SandboxCheckpo // be paused or resumed later. uploadCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), uploadTimeout) defer cancel() - defer res.completeUpload(uploadCtx) - if err := res.snapshot.Upload(uploadCtx, s.persistence, res.paths, res.objectMetadata); err != nil { + err := res.upload.Run(uploadCtx) + defer res.completeUpload(uploadCtx, err) + + if err != nil { telemetry.ReportCriticalError(ctx, "error uploading snapshot for checkpoint", err, telemetry.WithSandboxID(in.GetSandboxId())) s.sandboxFactory.Sandboxes.MarkStopping(ctx, resumedSbx.Runtime.SandboxID, resumedSbx.LifecycleID) @@ -715,19 +718,17 @@ func (s *Server) getSandboxExecutionData(sbx *sandbox.Sandbox) map[string]any { } } -// snapshotResult holds the data produced by snapshotAndCacheSandbox that callers -// need to start the background GCS upload. +// snapshotResult holds the data produced by snapshotAndCacheSandbox that +// callers need to start the background remote storage upload. type snapshotResult struct { meta metadata.Template - snapshot *sandbox.Snapshot - paths storage.Paths - objectMetadata storage.ObjectMetadata - completeUpload func(ctx context.Context) + upload *sandbox.Upload + completeUpload func(ctx context.Context, uploadErr error) } -// snapshotAndCacheSandbox creates a snapshot of a sandbox and adds it to the local -// template cache. The caller is responsible for starting the GCS upload via -// startSnapshotUploadAsync or uploadSnapshotWithPrefetchAsync. +// snapshotAndCacheSandbox creates a snapshot of a sandbox and adds it to the +// local template cache. The caller is responsible for starting the remote +// storage upload via uploadSnapshotAsync. 
func (s *Server) snapshotAndCacheSandbox( ctx context.Context, sbx *sandbox.Sandbox, @@ -763,65 +764,70 @@ func (s *Server) snapshotAndCacheSandbox( return nil, fmt.Errorf("error adding snapshot to template cache: %w", err) } - telemetry.ReportEvent(ctx, "added snapshot to template cache") - - paths := storage.Paths{BuildID: meta.Template.BuildID} objectMetadata := storage.ObjectMetadata{ storage.ObjectMetadataTeamID: sbx.Runtime.TeamID, } - // Register in Redis so other orchestrators can find us for peer routing. - if s.featureFlags.BoolFlag(ctx, featureflags.PeerToPeerChunkTransferFlag) { - if err := s.peerRegistry.Register(ctx, meta.Template.BuildID, redisPeerKeyTTL); err != nil { - logger.L().Warn(ctx, "failed to register peer address for routing", zap.String("build_id", meta.Template.BuildID), zap.Error(err)) + // Register the upload only after the snapshot is in the local cache, so a + // failed AddSnapshot doesn't leave an orphan future blocking re-registration. + upload, err := sandbox.NewUpload(ctx, s.uploads, snapshot, s.persistence, s.config.StorageConfig.CompressConfig, s.featureFlags, storage.UseCasePause, objectMetadata) + if err != nil { + return nil, fmt.Errorf("register upload: %w", err) + } + + telemetry.ReportEvent(ctx, "added snapshot to template cache") + + // Capture once so Register and the symmetric Unregister inside + // completeUpload don't drift if the flag flips mid-upload. + peerEnabled := s.featureFlags.BoolFlag(ctx, featureflags.PeerToPeerChunkTransferFlag) + + completeUpload := func(ctx context.Context, uploadErr error) { + upload.Finish(ctx, uploadErr) + + if !peerEnabled { + return } - completeUpload := func(ctx context.Context) { - // Signal in-flight peer streams to switch to GCS. - s.uploadedBuilds.Set(meta.Template.BuildID, struct{}{}, ttlcache.DefaultTTL) + s.uploadedBuilds.Set(meta.Template.BuildID, struct{}{}, ttlcache.DefaultTTL) - // Remove from Redis so new nodes go directly to GCS. 
- if err := s.peerRegistry.Unregister(ctx, meta.Template.BuildID); err != nil { - logger.L().Warn(ctx, "failed to unregister peer address from routing", zap.String("build_id", meta.Template.BuildID), zap.Error(err)) - } + if err := s.peerRegistry.Unregister(ctx, meta.Template.BuildID); err != nil { + logger.L().Warn(ctx, "failed to unregister peer address from routing", zap.String("build_id", meta.Template.BuildID), zap.Error(err)) } + } - return &snapshotResult{ - meta: meta, - snapshot: snapshot, - paths: paths, - objectMetadata: objectMetadata, - completeUpload: completeUpload, - }, nil + if peerEnabled { + if err := s.peerRegistry.Register(ctx, meta.Template.BuildID, redisPeerKeyTTL); err != nil { + logger.L().Warn(ctx, "failed to register peer address for routing", zap.String("build_id", meta.Template.BuildID), zap.Error(err)) + } } return &snapshotResult{ meta: meta, - snapshot: snapshot, - paths: paths, - objectMetadata: objectMetadata, - completeUpload: func(context.Context) {}, + upload: upload, + completeUpload: completeUpload, }, nil } -// uploadSnapshotAsync uploads snapshot files to GCS in the background and -// cleans up the Redis peer key once done. Used by the Pause handler where no -// prefetch data is available. +// uploadSnapshotAsync uploads snapshot files to remote storage in the +// background and cleans up the Redis peer key once done. Used by the Pause +// handler where no prefetch data is available. 
func (s *Server) uploadSnapshotAsync(ctx context.Context, sbx *sandbox.Sandbox, res *snapshotResult) { ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), uploadTimeout) go func() { defer cancel() - defer res.completeUpload(ctx) - err := res.snapshot.Upload(ctx, s.persistence, res.paths, res.objectMetadata) + ctx, span := tracer.Start(ctx, "upload snapshot") + defer span.End() + + err := res.upload.Run(ctx) if err != nil { sbxlogger.I(sbx).Error(ctx, "error uploading snapshot files", zap.Error(err)) - - return + } else { + sbxlogger.I(sbx).Info(ctx, "snapshot finished uploading successfully") } - sbxlogger.E(sbx).Info(ctx, "Snapshot files uploaded to GCS") + res.completeUpload(ctx, err) }() } diff --git a/packages/orchestrator/pkg/template/build/builder.go b/packages/orchestrator/pkg/template/build/builder.go index 3390ccd8cc..425ab7501c 100644 --- a/packages/orchestrator/pkg/template/build/builder.go +++ b/packages/orchestrator/pkg/template/build/builder.go @@ -61,6 +61,7 @@ type Builder struct { templateCache *sbxtemplate.Cache metrics *metrics.BuildMetrics featureFlags *featureflags.Client + uploads *sandbox.Uploads } func NewBuilder( @@ -76,6 +77,7 @@ func NewBuilder( sandboxes *sandbox.Map, templateCache *sbxtemplate.Cache, buildMetrics *metrics.BuildMetrics, + uploads *sandbox.Uploads, ) *Builder { return &Builder{ config: config, @@ -90,6 +92,7 @@ func NewBuilder( sandboxes: sandboxes, templateCache: templateCache, metrics: buildMetrics, + uploads: uploads, } } @@ -259,8 +262,6 @@ func runBuild( index := cache.NewHashIndex(bc.CacheScope, builder.buildStorage, templateStorage) - uploadTracker := layer.NewUploadTracker() - layerExecutor := layer.NewLayerExecutor( bc, builder.logger, @@ -270,7 +271,9 @@ func runBuild( templateStorage, builder.buildStorage, index, - uploadTracker, + builder.uploads, + builder.config.StorageConfig.CompressConfig, + builder.featureFlags, ) baseBuilder := base.New( diff --git 
a/packages/orchestrator/pkg/template/build/layer/layer_executor.go b/packages/orchestrator/pkg/template/build/layer/layer_executor.go index 45772476d0..bd2f2426be 100644 --- a/packages/orchestrator/pkg/template/build/layer/layer_executor.go +++ b/packages/orchestrator/pkg/template/build/layer/layer_executor.go @@ -16,6 +16,7 @@ import ( "github.com/e2b-dev/infra/packages/orchestrator/pkg/template/build/sandboxtools" "github.com/e2b-dev/infra/packages/orchestrator/pkg/template/build/storage/cache" "github.com/e2b-dev/infra/packages/orchestrator/pkg/template/metadata" + "github.com/e2b-dev/infra/packages/shared/pkg/featureflags" "github.com/e2b-dev/infra/packages/shared/pkg/logger" "github.com/e2b-dev/infra/packages/shared/pkg/storage" ) @@ -33,7 +34,9 @@ type LayerExecutor struct { templateStorage storage.StorageProvider buildStorage storage.StorageProvider index cache.Index - uploadTracker *UploadTracker + uploads *sandbox.Uploads + compressConfig storage.CompressConfig + ff *featureflags.Client } func NewLayerExecutor( @@ -45,7 +48,9 @@ func NewLayerExecutor( templateStorage storage.StorageProvider, buildStorage storage.StorageProvider, index cache.Index, - uploadTracker *UploadTracker, + uploads *sandbox.Uploads, + compressConfig storage.CompressConfig, + ff *featureflags.Client, ) *LayerExecutor { return &LayerExecutor{ BuildContext: buildContext, @@ -58,7 +63,9 @@ func NewLayerExecutor( templateStorage: templateStorage, buildStorage: buildStorage, index: index, - uploadTracker: uploadTracker, + uploads: uploads, + compressConfig: compressConfig, + ff: ff, } } @@ -276,45 +283,32 @@ func (lb *LayerExecutor) PauseAndUpload( // Upload snapshot async, it's added to the template cache immediately userLogger.Debug(ctx, fmt.Sprintf("Saving: %s", meta.Template.BuildID)) - // Register this upload and get functions to signal completion and wait for previous uploads - completeUpload, waitForPreviousUploads := lb.uploadTracker.StartUpload() + objectMetadata := 
storage.ObjectMetadata{ + storage.ObjectMetadataTeamID: lb.BuildContext.Config.TeamID, + } + + upload, err := sandbox.NewUpload(ctx, lb.uploads, snapshot, lb.templateStorage, lb.compressConfig, lb.ff, storage.UseCaseBuild, objectMetadata) + if err != nil { + return fmt.Errorf("register upload: %w", err) + } - lb.UploadErrGroup.Go(func() error { + lb.UploadErrGroup.Go(func() (uploadErr error) { ctx := context.WithoutCancel(ctx) ctx, span := tracer.Start(ctx, "upload snapshot") defer span.End() - // Always signal completion to unblock waiting goroutines, even on error. - // This prevents deadlocks when an earlier layer fails - later layers can - // still unblock and the errgroup can properly collect all errors. - defer completeUpload() + // Signal even on error so child layers waiting on this build can abort. + defer func() { upload.Finish(ctx, uploadErr) }() - err := snapshot.Upload( - ctx, - lb.templateStorage, - storage.Paths{BuildID: meta.Template.BuildID}, - storage.ObjectMetadata{ - storage.ObjectMetadataTeamID: lb.BuildContext.Config.TeamID, - }, - ) - if err != nil { + if err := upload.Run(ctx); err != nil { return fmt.Errorf("error uploading snapshot: %w", err) } - // Wait for all previous layer uploads to complete before saving the cache entry. - // This prevents race conditions where another build hits this cache entry - // before its dependencies (previous layers) are available in storage. - err = waitForPreviousUploads(ctx) - if err != nil { - return fmt.Errorf("error waiting for previous uploads: %w", err) - } - - err = lb.index.SaveLayerMeta(ctx, hash, cache.LayerMetadata{ + if err := lb.index.SaveLayerMeta(ctx, hash, cache.LayerMetadata{ Template: cache.Template{ BuildID: meta.Template.BuildID, }, - }) - if err != nil { + }); err != nil { // Since the data should be basically identical, this is safe to skip. 
if !errors.Is(err, storage.ErrObjectRateLimited) { return fmt.Errorf("error saving UUID to hash mapping: %w", err) diff --git a/packages/orchestrator/pkg/template/build/layer/upload_tracker.go b/packages/orchestrator/pkg/template/build/layer/upload_tracker.go deleted file mode 100644 index 213938f147..0000000000 --- a/packages/orchestrator/pkg/template/build/layer/upload_tracker.go +++ /dev/null @@ -1,54 +0,0 @@ -package layer - -import ( - "context" - "sync" -) - -// UploadTracker tracks in-flight uploads and allows waiting for all previous uploads to complete. -// This prevents race conditions where a layer's cache entry is saved before its -// dependencies (previous layers) are fully uploaded. -type UploadTracker struct { - mu sync.Mutex - waitChs []chan struct{} -} - -func NewUploadTracker() *UploadTracker { - return &UploadTracker{ - waitChs: make([]chan struct{}, 0), - } -} - -// StartUpload registers that a new upload has started. -// Returns a function that should be called when the upload completes. 
-func (t *UploadTracker) StartUpload() (complete func(), waitForPrevious func(context.Context) error) { - t.mu.Lock() - defer t.mu.Unlock() - - // Create a channel for this upload - ch := make(chan struct{}) - t.waitChs = append(t.waitChs, ch) - - // Capture the channels we need to wait for (all previous uploads) - previousChs := make([]chan struct{}, len(t.waitChs)-1) - copy(previousChs, t.waitChs[:len(t.waitChs)-1]) - - complete = func() { - close(ch) - } - - waitForPrevious = func(ctx context.Context) error { - for _, prevCh := range previousChs { - select { - case <-prevCh: - // Previous upload completed - case <-ctx.Done(): - return ctx.Err() - } - } - - return nil - } - - return complete, waitForPrevious -} diff --git a/packages/orchestrator/pkg/template/build/layer/upload_tracker_test.go b/packages/orchestrator/pkg/template/build/layer/upload_tracker_test.go deleted file mode 100644 index 8b0923c6cb..0000000000 --- a/packages/orchestrator/pkg/template/build/layer/upload_tracker_test.go +++ /dev/null @@ -1,281 +0,0 @@ -package layer - -import ( - "context" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestUploadTracker_SingleUpload(t *testing.T) { - t.Parallel() - - tracker := NewUploadTracker() - - complete, waitForPrevious := tracker.StartUpload() - - // First upload has no previous uploads to wait for - ctx := context.Background() - err := waitForPrevious(ctx) - require.NoError(t, err) - - complete() -} - -func TestUploadTracker_SequentialUploads(t *testing.T) { - t.Parallel() - - tracker := NewUploadTracker() - - // Start first upload - complete1, waitForPrevious1 := tracker.StartUpload() - - // Start second upload - complete2, waitForPrevious2 := tracker.StartUpload() - - // Start third upload - complete3, waitForPrevious3 := tracker.StartUpload() - - ctx := context.Background() - - // First upload has no dependencies - err := waitForPrevious1(ctx) - 
require.NoError(t, err) - complete1() - - // Second upload waits for first - err = waitForPrevious2(ctx) - require.NoError(t, err) - complete2() - - // Third upload waits for first and second - err = waitForPrevious3(ctx) - require.NoError(t, err) - complete3() -} - -func TestUploadTracker_WaitBlocksUntilComplete(t *testing.T) { - t.Parallel() - - tracker := NewUploadTracker() - - // Start first upload - complete1, _ := tracker.StartUpload() - - // Start second upload - _, waitForPrevious2 := tracker.StartUpload() - - // Second upload should block until first completes - done := make(chan struct{}) - go func() { - ctx := context.Background() - _ = waitForPrevious2(ctx) - close(done) - }() - - // Should not complete immediately - select { - case <-done: - t.Fatal("waitForPrevious should have blocked") - case <-time.After(50 * time.Millisecond): - // Expected - still waiting - } - - // Complete first upload - complete1() - - // Now second should complete - select { - case <-done: - // Expected - case <-time.After(time.Second): - t.Fatal("waitForPrevious should have completed after first upload finished") - } -} - -func TestUploadTracker_ContextCancellation(t *testing.T) { - t.Parallel() - - tracker := NewUploadTracker() - - // Start first upload (don't complete it) - _, _ = tracker.StartUpload() - - // Start second upload - _, waitForPrevious2 := tracker.StartUpload() - - // Create a cancellable context - ctx, cancel := context.WithCancel(context.Background()) - - // Start waiting in a goroutine - errCh := make(chan error, 1) - go func() { - errCh <- waitForPrevious2(ctx) - }() - - // Cancel the context - cancel() - - // Should return context error - select { - case err := <-errCh: - require.ErrorIs(t, err, context.Canceled) - case <-time.After(time.Second): - t.Fatal("waitForPrevious should have returned after context cancellation") - } -} - -func TestUploadTracker_ConcurrentUploads(t *testing.T) { - t.Parallel() - - tracker := NewUploadTracker() - - const 
numUploads = 10 - var completeFuncs []func() - var waitFuncs []func(context.Context) error - - // Start all uploads - for range numUploads { - complete, wait := tracker.StartUpload() - completeFuncs = append(completeFuncs, complete) - waitFuncs = append(waitFuncs, wait) - } - - // Track completion order and errors - var completionOrder []int - var mu sync.Mutex - var wg sync.WaitGroup - errCh := make(chan error, numUploads) - - // Start all waits concurrently - for i := range numUploads { - wg.Add(1) - go func(idx int) { - defer wg.Done() - ctx := context.Background() - err := waitFuncs[idx](ctx) - if err != nil { - errCh <- err - - return - } - - mu.Lock() - completionOrder = append(completionOrder, idx) - mu.Unlock() - }(i) - } - - // Complete uploads in order - for i := range numUploads { - completeFuncs[i]() - // Small delay to allow goroutines to process - time.Sleep(10 * time.Millisecond) - } - - wg.Wait() - close(errCh) - - // Check for errors - for err := range errCh { - require.NoError(t, err) - } - - // Verify all completed - assert.Len(t, completionOrder, numUploads) -} - -func TestUploadTracker_OutOfOrderCompletion(t *testing.T) { - t.Parallel() - - tracker := NewUploadTracker() - - // Start three uploads - complete1, waitForPrevious1 := tracker.StartUpload() - complete2, waitForPrevious2 := tracker.StartUpload() - complete3, waitForPrevious3 := tracker.StartUpload() - - ctx := context.Background() - - // Track when each wait completes - var wait1Done, wait2Done, wait3Done atomic.Bool - - var wg sync.WaitGroup - wg.Add(3) - - go func() { - defer wg.Done() - _ = waitForPrevious1(ctx) - wait1Done.Store(true) - }() - - go func() { - defer wg.Done() - _ = waitForPrevious2(ctx) - wait2Done.Store(true) - }() - - go func() { - defer wg.Done() - _ = waitForPrevious3(ctx) - wait3Done.Store(true) - }() - - // Wait 1 should complete immediately (no dependencies) - time.Sleep(50 * time.Millisecond) - assert.True(t, wait1Done.Load(), "wait1 should complete 
immediately") - assert.False(t, wait2Done.Load(), "wait2 should still be waiting") - assert.False(t, wait3Done.Load(), "wait3 should still be waiting") - - // Complete upload 1 - complete1() - time.Sleep(50 * time.Millisecond) - - // Wait 2 should now complete - assert.True(t, wait2Done.Load(), "wait2 should complete after upload1") - assert.False(t, wait3Done.Load(), "wait3 should still be waiting for upload2") - - // Complete upload 2 - complete2() - time.Sleep(50 * time.Millisecond) - - // Wait 3 should now complete - assert.True(t, wait3Done.Load(), "wait3 should complete after upload2") - - // Complete upload 3 for cleanup - complete3() - - wg.Wait() -} - -func TestUploadTracker_CompleteBeforeWait(t *testing.T) { - t.Parallel() - - tracker := NewUploadTracker() - - // Start and complete first upload before second even starts waiting - complete1, _ := tracker.StartUpload() - complete1() - - // Start second upload - _, waitForPrevious2 := tracker.StartUpload() - - // Should not block since first is already complete - ctx := context.Background() - done := make(chan struct{}) - go func() { - _ = waitForPrevious2(ctx) - close(done) - }() - - select { - case <-done: - // Expected - should complete immediately - case <-time.After(time.Second): - t.Fatal("waitForPrevious should have completed immediately since previous upload is done") - } -} diff --git a/packages/orchestrator/pkg/template/server/main.go b/packages/orchestrator/pkg/template/server/main.go index 30e240eb0c..1ece7defda 100644 --- a/packages/orchestrator/pkg/template/server/main.go +++ b/packages/orchestrator/pkg/template/server/main.go @@ -61,6 +61,7 @@ func New( templateCache *sbxtemplate.Cache, templatePersistence storage.StorageProvider, buildPersistence storage.StorageProvider, + uploads *sandbox.Uploads, ) (s *ServerStore, e error) { logger.Info(ctx, "Initializing template manager") @@ -107,6 +108,7 @@ func New( sandboxFactory.Sandboxes, templateCache, buildMetrics, + uploads, ) store := 
&ServerStore{ diff --git a/packages/shared/go.mod b/packages/shared/go.mod index 7b8570ae42..750963c991 100644 --- a/packages/shared/go.mod +++ b/packages/shared/go.mod @@ -31,11 +31,13 @@ require ( github.com/hashicorp/go-retryablehttp v0.7.7 github.com/hashicorp/nomad/api v0.0.0-20251216171439-1dee0671280e github.com/jellydator/ttlcache/v3 v3.4.0 + github.com/klauspost/compress v1.18.5 github.com/launchdarkly/go-sdk-common/v3 v3.3.0 github.com/launchdarkly/go-server-sdk/v7 v7.13.0 github.com/ngrok/firewall_toolkit v0.0.18 github.com/oapi-codegen/runtime v1.4.0 github.com/orcaman/concurrent-map/v2 v2.0.1 + github.com/pierrec/lz4/v4 v4.1.22 github.com/redis/go-redis/extra/redisotel/v9 v9.17.3 github.com/redis/go-redis/v9 v9.17.3 github.com/stretchr/testify v1.11.1 @@ -231,7 +233,6 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/julienschmidt/httprouter v1.3.0 // indirect github.com/kamstrup/intmap v0.5.1 // indirect - github.com/klauspost/compress v1.18.5 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/knadh/koanf/maps v0.1.2 // indirect github.com/knadh/koanf/providers/confmap v1.0.0 // indirect @@ -285,7 +286,6 @@ require ( github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/pelletier/go-toml/v2 v2.3.0 // indirect - github.com/pierrec/lz4/v4 v4.1.22 // indirect github.com/pires/go-proxyproto v0.7.0 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pkg/errors v0.9.1 // indirect diff --git a/packages/shared/pkg/featureflags/context.go b/packages/shared/pkg/featureflags/context.go index 4f0e957ff0..79e52b1557 100644 --- a/packages/shared/pkg/featureflags/context.go +++ b/packages/shared/pkg/featureflags/context.go @@ -164,6 +164,14 @@ func VolumeContext(volumeName string) ldcontext.Context { return ldcontext.NewWithKind(VolumeKind, volumeName) } +func 
CompressFileTypeContext(fileType string) ldcontext.Context { + return ldcontext.NewWithKind(CompressFileTypeKind, fileType) +} + +func CompressUseCaseContext(useCase string) ldcontext.Context { + return ldcontext.NewWithKind(CompressUseCaseKind, useCase) +} + func VersionContext(orchestratorID, commit string) ldcontext.Context { return ldcontext.NewBuilder(orchestratorID). Kind(OrchestratorKind). diff --git a/packages/shared/pkg/featureflags/flags.go b/packages/shared/pkg/featureflags/flags.go index 497eaeb481..f390e4fe90 100644 --- a/packages/shared/pkg/featureflags/flags.go +++ b/packages/shared/pkg/featureflags/flags.go @@ -18,14 +18,16 @@ const ( SandboxKernelVersionAttribute string = "kernel-version" SandboxFirecrackerVersionAttribute string = "firecracker-version" - TeamKind ldcontext.Kind = "team" - UserKind ldcontext.Kind = "user" - ClusterKind ldcontext.Kind = "cluster" - deploymentKind ldcontext.Kind = "deployment" - TierKind ldcontext.Kind = "tier" - ServiceKind ldcontext.Kind = "service" - TemplateKind ldcontext.Kind = "template" - VolumeKind ldcontext.Kind = "volume" + TeamKind ldcontext.Kind = "team" + UserKind ldcontext.Kind = "user" + ClusterKind ldcontext.Kind = "cluster" + deploymentKind ldcontext.Kind = "deployment" + TierKind ldcontext.Kind = "tier" + ServiceKind ldcontext.Kind = "service" + TemplateKind ldcontext.Kind = "template" + VolumeKind ldcontext.Kind = "volume" + CompressFileTypeKind ldcontext.Kind = "compress-file-type" + CompressUseCaseKind ldcontext.Kind = "compress-use-case" OrchestratorKind ldcontext.Kind = "orchestrator" OrchestratorCommitAttribute string = "commit" @@ -207,6 +209,8 @@ var ( // MaxConcurrentSnapshotBuildQueries limits concurrent GetSnapshotBuilds calls (e.g. sandbox delete). // 0 or negative disables throttling (unlimited concurrency). 
MaxConcurrentSnapshotBuildQueries = NewIntFlag("max-concurrent-snapshot-build-queries", 0) + + MinChunkerReadSizeKB = NewIntFlag("min-chunker-read-size-kb", 16) ) type StringFlag struct { @@ -318,17 +322,17 @@ func GetTrackedTemplatesSet(ctx context.Context, ff *Client) map[string]struct{} return result } -// ChunkerConfigFlag is a JSON flag controlling the chunker implementation and tuning. -// -// NOTE: Changing useStreaming has no effect on chunkers already created for -// cached templates. A service restart (redeploy) is required for that change -// to take effect. minReadBatchSizeKB is checked just-in-time on each fetch, -// so it takes effect immediately. -// -// JSON format: {"useStreaming": false, "minReadBatchSizeKB": 16} -var ChunkerConfigFlag = NewJSONFlag("chunker-config", ldvalue.FromJSONMarshal(map[string]any{ - "useStreaming": false, - "minReadBatchSizeKB": 16, +// CompressConfigFlag controls compression during template builds. +// When compressBuilds is true, builds upload exclusively compressed data +// (no uncompressed fallback). When false, exclusively uncompressed with V3 headers. +var CompressConfigFlag = NewJSONFlag("compress-config", ldvalue.FromJSONMarshal(map[string]any{ + "compressBuilds": false, + "compressionType": "", + "compressionLevel": 0, + "frameSizeKB": 0, + "minPartSizeMB": 0, + "frameEncodeWorkers": 0, + "encoderConcurrency": 0, })) // TCPFirewallEgressThrottleConfig controls per-sandbox egress throttling via Firecracker's diff --git a/packages/shared/pkg/grpc/orchestrator/chunks.pb.go b/packages/shared/pkg/grpc/orchestrator/chunks.pb.go index 388c9bd808..70aee86388 100644 --- a/packages/shared/pkg/grpc/orchestrator/chunks.pb.go +++ b/packages/shared/pkg/grpc/orchestrator/chunks.pb.go @@ -27,11 +27,12 @@ type PeerAvailability struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - // not_available is true when the file is not in the local cache. - // The caller should fall back to GCS. 
+ // not_available is true when the file is not in the local cache. The caller + // should fall back to remote storage. NotAvailable bool `protobuf:"varint,1,opt,name=not_available,json=notAvailable,proto3" json:"not_available,omitempty"` - // use_storage is true when the GCS upload has completed and the caller - // should switch to reading from GCS/NFS directly instead of this peer. + // use_storage is true when the remote storage upload has completed and the + // caller should switch to reading from remote storage directly instead of + // this peer. UseStorage bool `protobuf:"varint,2,opt,name=use_storage,json=useStorage,proto3" json:"use_storage,omitempty"` } diff --git a/packages/shared/pkg/storage/compress_config.go b/packages/shared/pkg/storage/compress_config.go new file mode 100644 index 0000000000..ed413fecdb --- /dev/null +++ b/packages/shared/pkg/storage/compress_config.go @@ -0,0 +1,63 @@ +package storage + +const ( + // DefaultCompressFrameSize is the default uncompressed size of each + // compression frame (2 MiB). Overridable via CompressConfig.FrameSizeKB. + // The last frame in a file may be shorter. + // + // The chunker fetches one frame at a time from storage on a cache miss. + // Larger frame sizes mean more data cached per fetch (faster warm-up and + // fewer remote storage round-trips), but higher memory and I/O cost per + // miss. + // + // This MUST be multiple of every block/page size: + // - header.HugepageSize (2 MiB) — UFFD huge-page size, also used by prefetch + // - header.RootfsBlockSize (4 KiB) — NBD / rootfs block size + DefaultCompressFrameSize = 2 * 1024 * 1024 + + // Use case identifiers for per-use-case compression targeting via LaunchDarkly. + UseCaseBuild = "build" + UseCasePause = "pause" +) + +// CompressConfig is the base compression configuration, loaded from environment +// variables at startup. Feature flags may override individual fields at runtime +// at the upload boundary. Zero value means compression disabled. 
+type CompressConfig struct { + Enabled bool `env:"COMPRESS_ENABLED" envDefault:"false"` + Type string `env:"COMPRESS_TYPE" envDefault:""` + Level int `env:"COMPRESS_LEVEL" envDefault:"0"` + FrameSizeKB int `env:"COMPRESS_FRAME_SIZE_KB" envDefault:"0"` + MinPartSizeMB int `env:"COMPRESS_MIN_PART_SIZE_MB" envDefault:"0"` + FrameEncodeWorkers int `env:"COMPRESS_FRAME_ENCODE_WORKERS" envDefault:"0"` + EncoderConcurrency int `env:"COMPRESS_ENCODER_CONCURRENCY" envDefault:"0"` +} + +// CompressionType returns the parsed CompressionType. +func (c CompressConfig) CompressionType() CompressionType { + return parseCompressionType(c.Type) +} + +// FrameSize returns the frame size in bytes. +func (c CompressConfig) FrameSize() int { + if c.FrameSizeKB <= 0 { + return DefaultCompressFrameSize + } + + return c.FrameSizeKB * 1024 +} + +// MinPartSize returns the minimum compressed part size in bytes. +// Parts accumulate frames until they reach this threshold. +func (c CompressConfig) MinPartSize() int64 { + if c.MinPartSizeMB <= 0 { + return int64(gcpMultipartUploadChunkSize) + } + + return int64(c.MinPartSizeMB) * (1 << 20) +} + +// IsCompressionEnabled reports whether compression is configured and active. 
+func (c CompressConfig) IsCompressionEnabled() bool { + return c.Enabled && c.CompressionType() != CompressionNone +} diff --git a/packages/shared/pkg/storage/compress_decode.go b/packages/shared/pkg/storage/compress_decode.go new file mode 100644 index 0000000000..01e40ed6a5 --- /dev/null +++ b/packages/shared/pkg/storage/compress_decode.go @@ -0,0 +1,128 @@ +package storage + +import ( + "fmt" + "io" + "sync" + + "github.com/klauspost/compress/zstd" + lz4 "github.com/pierrec/lz4/v4" +) + +var lz4DecoderPool sync.Pool + +func getLZ4Decoder(r io.Reader) *lz4.Reader { + if v := lz4DecoderPool.Get(); v != nil { + dec := v.(*lz4.Reader) + dec.Reset(r) + + return dec + } + + return lz4.NewReader(r) +} + +func putLZ4Decoder(dec *lz4.Reader) { + dec.Reset(nil) + lz4DecoderPool.Put(dec) +} + +// zstd concurrency is hardcoded to 1: benchmarks show higher values hurt +// throughput for single 2MiB frame decodes. +var zstdDecoderPool sync.Pool + +func getZstdDecoder(r io.Reader) (*zstd.Decoder, error) { + if v := zstdDecoderPool.Get(); v != nil { + dec := v.(*zstd.Decoder) + if err := dec.Reset(r); err != nil { + dec.Close() + + return nil, err + } + + return dec, nil + } + + return zstd.NewReader(r) +} + +func putZstdDecoder(dec *zstd.Decoder) { + dec.Reset(nil) + zstdDecoderPool.Put(dec) +} + +// NewDecompressingReader wraps a reader with the appropriate decompressor. +// Close releases the decompressor back to its pool but does NOT close the +// underlying reader — the caller is responsible for closing it. 
+func NewDecompressingReader(raw io.Reader, ct CompressionType) (io.ReadCloser, error) { + switch ct { + case CompressionLZ4: + dec := getLZ4Decoder(raw) + + return &pooledDecoder{ + Reader: dec, + close: func() { putLZ4Decoder(dec) }, + }, nil + + case CompressionZstd: + dec, err := getZstdDecoder(raw) + if err != nil { + return nil, fmt.Errorf("failed to create zstd decoder: %w", err) + } + + return &pooledDecoder{ + Reader: dec, + close: func() { putZstdDecoder(dec) }, + }, nil + + default: + return nil, fmt.Errorf("unsupported compression type: %s", ct) + } +} + +// pooledDecoder wraps a decompressor from a sync.Pool. +// Close returns the decompressor to the pool. +type pooledDecoder struct { + io.Reader + + close func() +} + +func (r *pooledDecoder) Close() error { + r.close() + + return nil +} + +// newDecompressingReadCloser wraps raw with the appropriate decompressor and +// takes ownership: Close releases the decompressor back to the pool AND closes raw. +func newDecompressingReadCloser(raw io.ReadCloser, ct CompressionType) (io.ReadCloser, error) { + dec, err := NewDecompressingReader(raw, ct) + if err != nil { + return nil, err + } + + return &decompressingReadCloser{dec: dec, raw: raw}, nil +} + +// decompressingReadCloser reads from the decompressor and closes both the +// decompressor (returning it to the pool) and the underlying raw stream. 
+type decompressingReadCloser struct { + dec io.ReadCloser // decompressor — reads from raw + raw io.Closer // underlying stream +} + +func (c *decompressingReadCloser) Read(p []byte) (int, error) { + return c.dec.Read(p) +} + +func (c *decompressingReadCloser) Close() error { + decErr := c.dec.Close() + rawErr := c.raw.Close() + + if decErr != nil { + return decErr + } + + return rawErr +} diff --git a/packages/shared/pkg/storage/compress_encode.go b/packages/shared/pkg/storage/compress_encode.go new file mode 100644 index 0000000000..3515a9d765 --- /dev/null +++ b/packages/shared/pkg/storage/compress_encode.go @@ -0,0 +1,122 @@ +package storage + +import ( + "bytes" + "context" + "fmt" + "sync" + + "github.com/klauspost/compress/zstd" + lz4 "github.com/pierrec/lz4/v4" +) + +// compressor compresses individual frames. Implementations are pooled and +// reused across frames within a single CompressStream call. +type compressor interface { + compress(src []byte) ([]byte, error) +} + +// lz4Compressor wraps a pooled lz4.Writer. The writer is reused via Reset +// between frames to avoid re-allocating internal hash tables (~64KB). +type lz4Compressor struct { + w *lz4.Writer +} + +func (c *lz4Compressor) compress(src []byte) ([]byte, error) { + var buf bytes.Buffer + buf.Grow(lz4.CompressBlockBound(len(src))) + c.w.Reset(&buf) + + if _, err := c.w.Write(src); err != nil { + return nil, fmt.Errorf("lz4 compress: %w", err) + } + + if err := c.w.Close(); err != nil { + return nil, fmt.Errorf("lz4 compress close: %w", err) + } + + return buf.Bytes(), nil +} + +// zstdCompressor wraps a pooled zstd.Encoder using EncodeAll. +type zstdCompressor struct { + enc *zstd.Encoder +} + +func (z *zstdCompressor) compress(src []byte) ([]byte, error) { //nolint:unparam // satisfies compressor interface + return z.enc.EncodeAll(src, make([]byte, 0, len(src))), nil +} + +// newCompressorPool returns a pool of compressors for the given config. 
+// Both LZ4 and zstd encoders are pooled and reused via Reset/EncodeAll. +// The config is validated eagerly — if zstd options are invalid, an error +// is returned immediately rather than deferred to pool.Get(). +func newCompressorPool(cfg CompressConfig) (*sync.Pool, error) { + pool := &sync.Pool{} + + switch cfg.CompressionType() { + case CompressionZstd: + zstdOpts := []zstd.EOption{ + zstd.WithEncoderLevel(zstd.EncoderLevel(cfg.Level)), + zstd.WithEncoderCRC(true), + } + if cfg.FrameSize() > 0 { + zstdOpts = append(zstdOpts, zstd.WithWindowSize(cfg.FrameSize())) + } + if cfg.EncoderConcurrency > 0 { + zstdOpts = append(zstdOpts, zstd.WithEncoderConcurrency(cfg.EncoderConcurrency)) + } + + // Validate options by creating one encoder upfront. + first, err := zstd.NewWriter(nil, zstdOpts...) + if err != nil { + return nil, fmt.Errorf("zstd encoder: %w", err) + } + pool.Put(&zstdCompressor{enc: first}) + + pool.New = func() any { + // Options are already validated; NewWriter won't fail. + enc, _ := zstd.NewWriter(nil, zstdOpts...) + + return &zstdCompressor{enc: enc} + } + case CompressionLZ4: + lz4Opts := []lz4.Option{ + lz4.BlockSizeOption(lz4.Block4Mb), + lz4.BlockChecksumOption(true), + lz4.ChecksumOption(false), + lz4.ConcurrencyOption(1), + lz4.CompressionLevelOption(lz4.Fast), + } + + // Validate options by creating one encoder upfront. + first := lz4.NewWriter(nil) + if err := first.Apply(lz4Opts...); err != nil { + return nil, fmt.Errorf("lz4 encoder: %w", err) + } + pool.Put(&lz4Compressor{w: first}) + + pool.New = func() any { + w := lz4.NewWriter(nil) + _ = w.Apply(lz4Opts...) 
//nolint:errcheck // options validated above + + return &lz4Compressor{w: w} + } + default: + return nil, fmt.Errorf("unsupported compression type: %s", cfg.CompressionType()) + } + + return pool, nil +} + +func CompressBytes(ctx context.Context, data []byte, cfg CompressConfig) (*FrameTable, []byte, [32]byte, error) { + up := &memPartUploader{} + + const compressBytesConcurrency = 1 + ft, checksum, err := compressStream(ctx, bytes.NewReader(data), cfg, up, compressBytesConcurrency) + if err != nil { + return nil, nil, [32]byte{}, err + } + + return ft, up.Assemble(), checksum, nil +} diff --git a/packages/shared/pkg/storage/compress_frame_table.go b/packages/shared/pkg/storage/compress_frame_table.go new file mode 100644 index 0000000000..9b8079ae7c --- /dev/null +++ b/packages/shared/pkg/storage/compress_frame_table.go @@ -0,0 +1,359 @@ +package storage + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "sort" +) + +type CompressionType byte + +const ( + CompressionNone = CompressionType(iota) + CompressionZstd + CompressionLZ4 + + // maxDeserializedFrames caps the number of frames read from a serialized + // FrameTable to prevent OOM from corrupted headers. 1M frames = 2 TiB + // uncompressed at 2 MiB frame size. + maxDeserializedFrames = 1024 * 1024 +) + +// FrameTable is a decompression index for compressed diff files. +// +// Dirty blocks are grouped into frames and each frame is compressed +// independently. Two address spaces describe the same data: +// +// U-space (uncompressed): |-- frame 0 (2M) --|-- frame 1 (2M) --| ... +// C-space (compressed): |-- f0 (.6M) --|-- f1 (.7M) --| ... +// +// Each frame is a frameEntry with absolute offsets (StartU, StartC) and +// sizes (SizeU, SizeC). Lookups are a binary search on StartU. + +// FrameSize holds the uncompressed (U) and compressed (C) byte size of a +// single frame. 
+type FrameSize struct { + U int32 + C int32 +} + +func (s FrameSize) String() string { + return fmt.Sprintf("U:%d/C:%d", s.U, s.C) +} + +type Range struct { + Offset int64 + Length int +} + +func (r Range) String() string { + return fmt.Sprintf("%d/%d", r.Offset, r.Length) +} + +// frameEntry stores one frame as an absolute start offset plus size in both +// address spaces. Fields must be exported for encoding/binary (Read/Write use reflection). +// Field order chosen for optimal alignment: two int64 then two uint32 = 24 bytes, no padding. +type frameEntry struct { + StartU int64 + StartC int64 + SizeU int32 + SizeC int32 +} + +// FrameTable is the decompression index for a compressed diff file. +// Immutable after construction; safe to share across goroutines. +// Sparse tables (gaps between entries) are supported. +type FrameTable struct { + compressionType CompressionType + entries []frameEntry // sorted by StartU +} + +// newFrameTableFromEntries creates a FrameTable from pre-computed absolute-offset entries. +func newFrameTableFromEntries(ct CompressionType, entries []frameEntry) *FrameTable { + return &FrameTable{compressionType: ct, entries: entries} +} + +// NewFrameTable creates a FrameTable from consecutive frame sizes, computing +// absolute offsets starting from zero. +func NewFrameTable(ct CompressionType, sizes []FrameSize) *FrameTable { + if len(sizes) == 0 { + return newFrameTableFromEntries(ct, nil) + } + + entries := make([]frameEntry, len(sizes)) + + var u, c int64 + for i, s := range sizes { + entries[i] = frameEntry{ + StartU: u, + StartC: c, + SizeU: s.U, + SizeC: s.C, + } + u += int64(s.U) + c += int64(s.C) + } + + return newFrameTableFromEntries(ct, entries) +} + +// CompressionType returns the compression type. Nil-safe: returns CompressionNone for nil. 
+func (ft *FrameTable) CompressionType() CompressionType { + if ft == nil { + return CompressionNone + } + + return ft.compressionType +} + +// IsCompressed reports whether ft is non-nil and has a compression type set. +func (ft *FrameTable) IsCompressed() bool { + return ft != nil && ft.compressionType != CompressionNone +} + +func (ft *FrameTable) NumFrames() int { + if ft == nil { + return 0 + } + + return len(ft.entries) +} + +func (e frameEntry) endU() int64 { return e.StartU + int64(e.SizeU) } +func (e frameEntry) endC() int64 { return e.StartC + int64(e.SizeC) } + +func (ft *FrameTable) FrameAt(i int) (startU, endU, startC, endC int64) { + e := ft.entries[i] + + return e.StartU, e.endU(), e.StartC, e.endC() +} + +// UncompressedSize returns the total uncompressed size across all frames. +// Nil-safe: returns 0 for nil (uncompressed leg in mixed-mode V4 upload). +func (ft *FrameTable) UncompressedSize() int64 { + if ft == nil { + return 0 + } + + var total int64 + for _, e := range ft.entries { + total += int64(e.SizeU) + } + + return total +} + +// CompressedSize returns the total compressed size across all frames. +// Nil-safe: returns 0 for nil (uncompressed leg in mixed-mode V4 upload). +func (ft *FrameTable) CompressedSize() int64 { + if ft == nil { + return 0 + } + + var total int64 + for _, e := range ft.entries { + total += int64(e.SizeC) + } + + return total +} + +// locate finds the frame containing the given uncompressed offset. +func (ft *FrameTable) locate(offset int64) (frameEntry, error) { + if ft == nil { + return frameEntry{}, errors.New("locate called with nil frame table — data is not compressed") + } + + // Binary search: find the last entry whose StartU <= offset. 
+ i := sort.Search(len(ft.entries), func(i int) bool { + return ft.entries[i].StartU > offset + }) - 1 + + if i < 0 { + return frameEntry{}, fmt.Errorf("offset %d not found in frame table", offset) + } + + e := ft.entries[i] + if offset >= e.endU() { + return frameEntry{}, fmt.Errorf("offset %d is in a gap (not covered by any frame)", offset) + } + + return e, nil +} + +// LocateCompressed maps a U-space offset to its C-space byte range. +// This is the final step of the read path: after GetShiftedMapping resolves +// the virtual offset to a build-local U-offset, this locates the compressed +// bytes to fetch from storage. +func (ft *FrameTable) LocateCompressed(offset int64) (Range, error) { + e, err := ft.locate(offset) + if err != nil { + return Range{}, err + } + + return Range{Offset: e.StartC, Length: int(e.SizeC)}, nil +} + +// LocateUncompressed returns the uncompressed byte range for the frame +// containing the given uncompressed offset. +func (ft *FrameTable) LocateUncompressed(offset int64) (Range, error) { + e, err := ft.locate(offset) + if err != nil { + return Range{}, err + } + + return Range{Offset: e.StartU, Length: int(e.SizeU)}, nil +} + +// Serialize writes the frame table to w in binary little-endian format. +// Nil-safe: writes zeros for type and count. +func (ft *FrameTable) Serialize(w io.Writer) error { + var ct CompressionType + var n int + if ft != nil && ft.compressionType != CompressionNone { + ct = ft.compressionType + n = len(ft.entries) + } + + if err := binary.Write(w, binary.LittleEndian, uint32(ct)); err != nil { + return err + } + + if err := binary.Write(w, binary.LittleEndian, uint32(n)); err != nil { + return err + } + + if n > 0 { + if err := binary.Write(w, binary.LittleEndian, ft.entries); err != nil { + return err + } + } + + return nil +} + +// DeserializeFrameTable reads a FrameTable from r. Returns nil for +// uncompressed builds (compressionType=0 or numFrames=0). 
+func DeserializeFrameTable(r io.Reader) (*FrameTable, error) { + var ct uint32 + + if err := binary.Read(r, binary.LittleEndian, &ct); err != nil { + return nil, fmt.Errorf("read compression type: %w", err) + } + + var n uint32 + + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return nil, fmt.Errorf("read frame count: %w", err) + } + + if ct == 0 && n > 0 { + return nil, fmt.Errorf("compression type is 0 but frame count is %d: corrupted header", n) + } + if ct == 0 || n == 0 { + return nil, nil + } + + if n > maxDeserializedFrames { + return nil, fmt.Errorf("frame count %d exceeds maximum %d", n, maxDeserializedFrames) + } + + entries := make([]frameEntry, n) + if err := binary.Read(r, binary.LittleEndian, entries); err != nil { + return nil, fmt.Errorf("read frame entries: %w", err) + } + + for i := range entries { + if entries[i].SizeU <= 0 || entries[i].SizeC <= 0 { + return nil, fmt.Errorf("frame %d has zero or negative size: SizeU=%d SizeC=%d", i, entries[i].SizeU, entries[i].SizeC) + } + if i > 0 && entries[i].StartU < entries[i-1].endU() { + return nil, fmt.Errorf("frame %d StartU %d < previous endU %d: U-entries not sorted", i, entries[i].StartU, entries[i-1].endU()) + } + if i > 0 && entries[i].StartC < entries[i-1].endC() { + return nil, fmt.Errorf("frame %d StartC %d < previous endC %d: C-entries not sorted", i, entries[i].StartC, entries[i-1].endC()) + } + } + + return newFrameTableFromEntries(CompressionType(ct), entries), nil +} + +// TrimToRanges returns a new FrameTable containing only the frames that +// overlap with at least one of the given U-space byte ranges. +// Used during V4 header serialization to keep headers compact when a build +// has many frames but only a few are referenced in the current layer. +// Nil-safe: returns ft unchanged when ft is nil or ranges is empty. 
+func (ft *FrameTable) TrimToRanges(ranges []Range) *FrameTable { + if ft == nil || len(ft.entries) == 0 || len(ranges) == 0 { + return ft + } + + keep := make([]bool, len(ft.entries)) + kept := 0 + + for _, r := range ranges { + startU, endU := r.Offset, r.Offset+int64(r.Length) + + // Binary search: first frame whose EndU > startU. + lo := sort.Search(len(ft.entries), func(i int) bool { + return ft.entries[i].endU() > startU + }) + + for i := lo; i < len(ft.entries) && ft.entries[i].StartU < endU; i++ { + if !keep[i] { + keep[i] = true + kept++ + } + } + } + + if kept == len(ft.entries) { + return ft // nothing trimmed + } + + trimmed := make([]frameEntry, 0, kept) + for i, e := range ft.entries { + if keep[i] { + trimmed = append(trimmed, e) + } + } + + return newFrameTableFromEntries(ft.compressionType, trimmed) +} + +func (ct CompressionType) Suffix() string { + switch ct { + case CompressionZstd: + return ".zstd" + case CompressionLZ4: + return ".lz4" + default: + return "" + } +} + +func (ct CompressionType) String() string { + switch ct { + case CompressionZstd: + return "zstd" + case CompressionLZ4: + return "lz4" + default: + return "none" + } +} + +// parseCompressionType converts a string to CompressionType. +// Returns CompressionNone for unrecognised values. +func parseCompressionType(s string) CompressionType { + switch s { + case "lz4": + return CompressionLZ4 + case "zstd": + return CompressionZstd + default: + return CompressionNone + } +} diff --git a/packages/shared/pkg/storage/compress_frame_table_test.go b/packages/shared/pkg/storage/compress_frame_table_test.go new file mode 100644 index 0000000000..4a10af5732 --- /dev/null +++ b/packages/shared/pkg/storage/compress_frame_table_test.go @@ -0,0 +1,225 @@ +package storage + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +// threeFrameFT returns a FrameTable with three 1MB uncompressed frames +// and varying compressed sizes, starting at the given offset. 
func threeFrameFT(startU, startC int64) *FrameTable {
	return &FrameTable{
		compressionType: CompressionLZ4,
		entries: []frameEntry{
			{StartU: startU, StartC: startC, SizeU: 1 << 20, SizeC: 500_000},                     // frame 0
			{StartU: startU + 1<<20, StartC: startC + 500_000, SizeU: 1 << 20, SizeC: 600_000},   // frame 1
			{StartU: startU + 2<<20, StartC: startC + 1_100_000, SizeU: 1 << 20, SizeC: 400_000}, // frame 2
		},
	}
}

// TestLocate exercises LocateUncompressed/LocateCompressed across frame
// boundaries, gaps, nil tables, and tables with a non-zero start offset.
func TestLocate(t *testing.T) {
	t.Parallel()
	ft := threeFrameFT(0, 0)

	t.Run("first byte of each frame uncompressed", func(t *testing.T) {
		t.Parallel()
		for i, wantU := range []int64{0, 1 << 20, 2 << 20} {
			r, err := ft.LocateUncompressed(wantU)
			require.NoError(t, err, "frame %d", i)
			require.Equal(t, wantU, r.Offset)
			require.Equal(t, 1<<20, r.Length)
		}
	})

	t.Run("first byte of each frame compressed", func(t *testing.T) {
		t.Parallel()
		wantC := []int64{0, 500_000, 1_100_000}
		wantLen := []int{500_000, 600_000, 400_000}
		for i, offsetU := range []int64{0, 1 << 20, 2 << 20} {
			r, err := ft.LocateCompressed(offsetU)
			require.NoError(t, err, "frame %d", i)
			require.Equal(t, wantC[i], r.Offset, "frame %d C start", i)
			require.Equal(t, wantLen[i], r.Length, "frame %d C length", i)
		}
	})

	t.Run("last byte of frame", func(t *testing.T) {
		t.Parallel()
		r, err := ft.LocateUncompressed((1 << 20) - 1)
		require.NoError(t, err)
		require.Equal(t, int64(0), r.Offset)
	})

	t.Run("returns correct C offset", func(t *testing.T) {
		t.Parallel()
		r, err := ft.LocateCompressed(2 << 20)
		require.NoError(t, err)
		require.Equal(t, int64(1_100_000), r.Offset) // 500k + 600k
	})

	t.Run("beyond end errors", func(t *testing.T) {
		t.Parallel()
		_, err := ft.LocateCompressed(3 << 20)
		require.Error(t, err)
	})

	t.Run("nil table errors", func(t *testing.T) {
		t.Parallel()
		_, err := (*FrameTable)(nil).LocateCompressed(0)
		require.Error(t, err)
	})

	t.Run("non-zero start offset", func(t *testing.T) {
		t.Parallel()
		sub := threeFrameFT(1<<20, 500_000)

		r, err := sub.LocateUncompressed(1 << 20)
		require.NoError(t, err)
		require.Equal(t, int64(1<<20), r.Offset)

		r, err = sub.LocateCompressed(1 << 20)
		require.NoError(t, err)
		require.Equal(t, int64(500_000), r.Offset)

		// Before first entry — no frame should contain offset 0.
		_, err = sub.LocateUncompressed(0)
		require.Error(t, err)
	})
}

// TestNewFrameTable verifies absolute offsets are accumulated from sizes.
func TestNewFrameTable(t *testing.T) {
	t.Parallel()

	ft := NewFrameTable(CompressionZstd, []FrameSize{
		{U: 1 << 20, C: 500_000},
		{U: 1 << 20, C: 600_000},
	})

	require.Equal(t, 2, ft.NumFrames())
	require.Equal(t, CompressionZstd, ft.CompressionType())
	require.True(t, ft.IsCompressed())
	require.Equal(t, int64(2<<20), ft.UncompressedSize())
	require.Equal(t, int64(1_100_000), ft.CompressedSize())

	startU, endU, startC, endC := ft.FrameAt(0)
	require.Equal(t, int64(0), startU)
	require.Equal(t, int64(1<<20), endU)
	require.Equal(t, int64(0), startC)
	require.Equal(t, int64(500_000), endC)

	startU, _, startC, _ = ft.FrameAt(1)
	require.Equal(t, int64(1<<20), startU)
	require.Equal(t, int64(500_000), startC)
}

// TestFrameTable_TrimToRanges covers the all-kept fast path (same pointer),
// subsetting, disjoint ranges, nil tables, and lookups in trimmed gaps.
func TestFrameTable_TrimToRanges(t *testing.T) {
	t.Parallel()

	ft := NewFrameTable(CompressionLZ4, []FrameSize{
		{U: 1 << 20, C: 500_000},
		{U: 1 << 20, C: 600_000},
		{U: 1 << 20, C: 400_000},
		{U: 1 << 20, C: 700_000},
	})

	t.Run("all frames retained", func(t *testing.T) {
		t.Parallel()
		trimmed := ft.TrimToRanges([]Range{{Offset: 0, Length: 4 << 20}})
		require.Same(t, ft, trimmed)
	})

	t.Run("single range trims to subset", func(t *testing.T) {
		t.Parallel()
		trimmed := ft.TrimToRanges([]Range{{Offset: 1 << 20, Length: 2 << 20}})
		require.Equal(t, 2, trimmed.NumFrames())

		startU, _, _, _ := trimmed.FrameAt(0)
		require.Equal(t, int64(1<<20), startU)

		startU, _, _, _ = trimmed.FrameAt(1)
		require.Equal(t, int64(2<<20), startU)
	})

	t.Run("two disjoint ranges", func(t *testing.T) {
		t.Parallel()
		trimmed := ft.TrimToRanges([]Range{
			{Offset: 0, Length: 1 << 20},
			{Offset: 3 << 20, Length: 1 << 20},
		})
		require.Equal(t, 2, trimmed.NumFrames())

		startU, _, _, _ := trimmed.FrameAt(0)
		require.Equal(t, int64(0), startU)

		startU, _, _, _ = trimmed.FrameAt(1)
		require.Equal(t, int64(3<<20), startU)
	})

	t.Run("nil table", func(t *testing.T) {
		t.Parallel()
		var nilFT *FrameTable
		require.Nil(t, nilFT.TrimToRanges([]Range{{Offset: 0, Length: 100}}))
	})

	t.Run("sparse lookup works", func(t *testing.T) {
		t.Parallel()
		trimmed := ft.TrimToRanges([]Range{
			{Offset: 0, Length: 1 << 20},
			{Offset: 3 << 20, Length: 1 << 20},
		})

		r, err := trimmed.LocateCompressed(0)
		require.NoError(t, err)
		require.Equal(t, int64(0), r.Offset)

		r, err = trimmed.LocateCompressed(3 << 20)
		require.NoError(t, err)
		require.Equal(t, int64(500_000+600_000+400_000), r.Offset)

		// Gap lookup fails
		_, err = trimmed.LocateCompressed(1 << 20)
		require.Error(t, err)
	})
}

// TestSerializeDeserializeFrameTable checks the binary round-trip and the
// nil-table encoding (zero type, zero count → nil on read).
func TestSerializeDeserializeFrameTable(t *testing.T) {
	t.Parallel()

	t.Run("round-trip", func(t *testing.T) {
		t.Parallel()
		ft := NewFrameTable(CompressionZstd, []FrameSize{
			{U: 2048, C: 1024},
			{U: 4096, C: 3500},
		})

		var buf bytes.Buffer
		require.NoError(t, ft.Serialize(&buf))

		got, err := DeserializeFrameTable(&buf)
		require.NoError(t, err)
		require.Equal(t, ft.NumFrames(), got.NumFrames())
		require.Equal(t, ft.CompressionType(), got.CompressionType())

		for i := range ft.NumFrames() {
			wSU, wEU, wSC, wEC := ft.FrameAt(i)
			gSU, gEU, gSC, gEC := got.FrameAt(i)
			require.Equal(t, wSU, gSU, "frame %d StartU", i)
			require.Equal(t, wEU, gEU, "frame %d EndU", i)
			require.Equal(t, wSC, gSC, "frame %d StartC", i)
			require.Equal(t, wEC, gEC, "frame %d EndC", i)
		}
	})

	t.Run("nil writes zeros", func(t *testing.T) {
		t.Parallel()
		var buf bytes.Buffer
		require.NoError(t, (*FrameTable)(nil).Serialize(&buf))

		got, err := DeserializeFrameTable(&buf)
		require.NoError(t, err)
		require.Nil(t, got)
	})
}

package storage

import (
	"bytes"
	"context"
	"crypto/sha256"
	"errors"
	"fmt"
	"io"
	"slices"
	"sync"
	"sync/atomic"

	"golang.org/x/sync/errgroup"
)

// partUploader abstracts a multipart upload destination used by
// compressStream: Start begins the upload, UploadPart writes one part,
// Complete finalizes, Close releases resources.
type partUploader interface {
	Start(ctx context.Context) error
	UploadPart(ctx context.Context, partIndex int, data ...[]byte) error
	Complete(ctx context.Context) error
	Close() error
}

// memPartUploader is an in-memory partUploader; parts are kept in a map
// keyed by part index and concatenated by Assemble.
type memPartUploader struct {
	mu    sync.Mutex
	parts map[int][]byte
}

func (m *memPartUploader) Start(context.Context) error {
	m.parts = make(map[int][]byte)

	return nil
}

// UploadPart concatenates the data chunks and stores them under partIndex.
// The mutex guards the map: parts may arrive concurrently and out of order.
func (m *memPartUploader) UploadPart(_ context.Context, partIndex int, data ...[]byte) error {
	var buf bytes.Buffer
	for _, d := range data {
		buf.Write(d)
	}
	m.mu.Lock()
	m.parts[partIndex] = buf.Bytes()
	m.mu.Unlock()

	return nil
}

func (m *memPartUploader) Complete(context.Context) error { return nil }
func (m *memPartUploader) Close() error                   { return nil }

// Assemble concatenates the stored parts in ascending part-index order.
func (m *memPartUploader) Assemble() []byte {
	keys := make([]int, 0, len(m.parts))
	for k := range m.parts {
		keys = append(keys, k)
	}
	slices.Sort(keys)

	var buf bytes.Buffer
	for _, k := range keys {
		buf.Write(m.parts[k])
	}

	return buf.Bytes()
}

// frame holds one frame's uncompressed size plus its compressed bytes
// (filled in asynchronously by the part's encode workers).
type frame struct {
	uncompressedSize int
	compressed       []byte
}

// part groups consecutive frames destined for one upload part. compressedSize
// is atomic because encode workers add to it concurrently.
type part struct {
	index          int
	frames         []*frame
	compressedSize atomic.Int64
	compress       *errgroup.Group
}

// newPart creates a part whose frame-encode workers are limited to workers.
func newPart(index int, parentCtx context.Context, workers int) (*part, context.Context) {
	p := &part{index: index}
	var ctx context.Context
	p.compress, ctx = errgroup.WithContext(parentCtx)
	p.compress.SetLimit(workers)

	return p, ctx
}

func (p *part) addFrame(ctx
 context.Context, uncompressedData []byte, pool *sync.Pool) {
	frameInPart := &frame{uncompressedSize: len(uncompressedData)}
	p.frames = append(p.frames, frameInPart)

	p.compress.Go(func() error {
		// Bail out early if the pipeline was already cancelled.
		if err := ctx.Err(); err != nil {
			return err
		}
		c := pool.Get().(compressor)
		out, err := c.compress(uncompressedData)
		pool.Put(c)
		if err != nil {
			return err
		}
		frameInPart.compressed = out
		p.compressedSize.Add(int64(len(out)))

		return nil
	})
}

// compressStream reads in frame-by-frame, compresses frames concurrently,
// uploads accumulated parts via uploader, and returns the resulting frame
// table plus the SHA-256 of the uncompressed input.
func compressStream(ctx context.Context, in io.Reader, cfg CompressConfig, uploader partUploader, maxUploadConcurrency int) (*FrameTable, [32]byte, error) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	if err := uploader.Start(ctx); err != nil {
		return nil, [32]byte{}, fmt.Errorf("start upload: %w", err)
	}
	defer uploader.Close()

	// The read loop goroutine holds one slot for the duration of the stream;
	// at least one additional slot is required for uploaders to make progress.
	if maxUploadConcurrency < 1 {
		maxUploadConcurrency = 1
	}
	work, workCtx := errgroup.WithContext(ctx)
	work.SetLimit(maxUploadConcurrency + 1)

	// Start the read loop.
	q := make(chan *part, 2)
	hasher := sha256.New()
	work.Go(func() error {
		defer close(q)

		return readLoop(workCtx, in, cfg, hasher, q)
	})

	// Upload loop.
	var frameSizes []FrameSize
	var loopErr error
	for p := range q {
		// Wait for all of this part's frames to finish encoding before
		// uploading; a failed encode cancels the whole pipeline.
		if err := p.compress.Wait(); err != nil {
			loopErr = fmt.Errorf("compress frames: %w", err)
			cancel()

			break
		}

		var compressed [][]byte
		for _, f := range p.frames {
			frameSizes = append(frameSizes, FrameSize{U: int32(f.uncompressedSize), C: int32(len(f.compressed))})
			compressed = append(compressed, f.compressed)
		}

		pi := p.index
		work.Go(func() error {
			return uploader.UploadPart(workCtx, pi, compressed...)
		})
	}

	// Drain q so the read loop can exit and close it, then wait for all
	// in-flight uploads to finish before the deferred uploader.Close().
	for range q { //nolint:revive // intentional drain
	}
	workErr := work.Wait()

	if err := errors.Join(loopErr, workErr); err != nil {
		return nil, [32]byte{}, err
	}

	if err := uploader.Complete(ctx); err != nil {
		return nil, [32]byte{}, fmt.Errorf("complete upload: %w", err)
	}

	var checksum [32]byte
	copy(checksum[:], hasher.Sum(nil))
	ft := NewFrameTable(cfg.CompressionType(), frameSizes)

	return ft, checksum, nil
}

// readLoop slices in into fixed-size frames, hashes the raw bytes, schedules
// each frame for compression, and emits a part to q whenever the accumulated
// compressed size reaches cfg.MinPartSize(). Part indices start at 1.
func readLoop(ctx context.Context, in io.Reader, cfg CompressConfig, hasher io.Writer, q chan<- *part) error {
	compressors, err := newCompressorPool(cfg)
	if err != nil {
		return err
	}

	frameSize := cfg.FrameSize()
	minPartSize := cfg.MinPartSize()
	workers := max(cfg.FrameEncodeWorkers, 1)
	p, compressCtx := newPart(1, ctx, workers)

	for {
		if err := ctx.Err(); err != nil {
			return err
		}

		// A fresh buffer per frame: ownership passes to the encode goroutine.
		buf := make([]byte, frameSize)
		n, err := io.ReadFull(in, buf)

		eof := errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
		if err != nil && !eof {
			return fmt.Errorf("read frame: %w", err)
		}

		if n > 0 {
			hasher.Write(buf[:n])
			p.addFrame(compressCtx, buf[:n], compressors)
		}

		if eof {
			// Flush the final (possibly short) part.
			if len(p.frames) > 0 {
				select {
				case q <- p:
				case <-ctx.Done():
					return ctx.Err()
				}
			}

			return nil
		}

		if p.compressedSize.Load() >= minPartSize {
			select {
			case q <- p:
			case <-ctx.Done():
				return ctx.Err()
			}

			p, compressCtx = newPart(p.index+1, ctx, workers)
		}
	}
}

package storage

import (
	"bytes"
	"context"
	crand "crypto/rand"
	"crypto/sha256"
	"fmt"
	"io"
	"math/rand/v2"
	"os"
	"path/filepath"
	"slices"
	"sync/atomic"
	"testing"
	"time"

	"github.com/klauspost/compress/zstd"
"github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +// generateSemiRandomData produces deterministic, compressible data. +// Random byte repeated 1-16 times — gives ~0.5-0.7 compression ratio. +func generateSemiRandomData(size int) []byte { + data := make([]byte, size) + rng := rand.New(rand.NewPCG(1, 2)) //nolint:gosec // deterministic + i := 0 + for i < size { + runLen := rng.IntN(16) + 1 + if i+runLen > size { + runLen = size - i + } + b := byte(rng.IntN(256)) + for j := range runLen { + data[i+j] = b + } + i += runLen + } + + return data +} + +// ThrottledPartUploader wraps memPartUploader with simulated upload bandwidth. +type ThrottledPartUploader struct { + memPartUploader + + bandwidth int64 // bytes/sec; 0 = unlimited +} + +func (t *ThrottledPartUploader) UploadPart(ctx context.Context, partIndex int, data ...[]byte) error { + if t.bandwidth > 0 { + total := 0 + for _, d := range data { + total += len(d) + } + time.Sleep(time.Duration(float64(total) / float64(t.bandwidth) * float64(time.Second))) + } + + return t.memPartUploader.UploadPart(ctx, partIndex, data...) +} + +// decompressAll walks the FrameTable and decompresses each frame from the +// concatenated compressed blob, returning the original uncompressed data. 
+func decompressAll(ft *FrameTable, compressed []byte) ([]byte, error) { + var result []byte + + for i := range ft.NumFrames() { + _, _, startC, endC := ft.FrameAt(i) + cLen := endC - startC + + if startC+cLen > int64(len(compressed)) { + return nil, fmt.Errorf("frame %d: compressed data truncated (need %d, have %d)", i, startC+cLen, len(compressed)) + } + + frameData := compressed[startC:endC] + + var frame []byte + var err error + + switch ft.CompressionType() { + case CompressionLZ4: + dec := getLZ4Decoder(bytes.NewReader(frameData)) + frame, err = io.ReadAll(dec) + putLZ4Decoder(dec) + case CompressionZstd: + var dec *zstd.Decoder + dec, err = getZstdDecoder(bytes.NewReader(frameData)) + if err == nil { + frame, err = io.ReadAll(dec) + putZstdDecoder(dec) + } + } + if err != nil { + return nil, fmt.Errorf("frame %d: %w", i, err) + } + result = append(result, frame...) + } + + return result, nil +} + +// defaultCfg returns a CompressConfig with the given overrides applied. +func defaultCfg(ct CompressionType, workers, frameSize int) CompressConfig { + level := 2 // zstd default + if ct == CompressionLZ4 { + level = 0 + } + + return CompressConfig{ + Enabled: true, + Type: ct.String(), + Level: level, + EncoderConcurrency: 1, + FrameEncodeWorkers: workers, + FrameSizeKB: frameSize / 1024, + MinPartSizeMB: 50, + } +} + +func TestCompressStreamRoundTrip(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + dataSize int + frameSize int + workers int + codec CompressionType + incompressible bool // use crypto/rand data that cannot be compressed + }{ + {"basic", 10 * megabyte, 2 * megabyte, 4, CompressionZstd, false}, + {"workers_1", 10 * megabyte, 2 * megabyte, 1, CompressionZstd, false}, + {"workers_2", 10 * megabyte, 2 * megabyte, 2, CompressionZstd, false}, + {"not_frame_aligned", 10*megabyte + 1, 2 * megabyte, 4, CompressionZstd, false}, + {"smaller_than_frame", 100 * 1024, 2 * megabyte, 4, CompressionZstd, false}, + {"smaller_than_part", 5 * 
megabyte, 2 * megabyte, 4, CompressionZstd, false}, + {"empty", 0, 2 * megabyte, 4, CompressionZstd, false}, + {"single_byte", 1, 2 * megabyte, 1, CompressionZstd, false}, + {"lz4", 10 * megabyte, 2 * megabyte, 4, CompressionLZ4, false}, + {"lz4_incompressible", 10 * megabyte, 2 * megabyte, 4, CompressionLZ4, true}, + {"zstd_incompressible", 10 * megabyte, 2 * megabyte, 4, CompressionZstd, true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var original []byte + if tc.dataSize > 0 { + if tc.incompressible { + original = make([]byte, tc.dataSize) + _, err := crand.Read(original) + require.NoError(t, err) + } else { + original = generateSemiRandomData(tc.dataSize) + } + } + + up := &memPartUploader{} + cfg := defaultCfg(tc.codec, tc.workers, tc.frameSize) + + ft, checksum, err := compressStream( + context.Background(), + bytes.NewReader(original), + cfg, + up, + 4, + ) + require.NoError(t, err) + + if tc.dataSize == 0 { + require.Equal(t, 0, ft.NumFrames()) + require.Equal(t, sha256.Sum256(nil), checksum) + + return + } + + // Verify frame count. + expectedFrames := (tc.dataSize + tc.frameSize - 1) / tc.frameSize + require.Equal(t, expectedFrames, ft.NumFrames()) + + // Verify checksum. + require.Equal(t, sha256.Sum256(original), checksum) + + // Round-trip: decompress and compare. 
+ compressed := up.Assemble() + decompressed, err := decompressAll(ft, compressed) + require.NoError(t, err) + require.Equal(t, original, decompressed) + }) + } +} + +func TestCompressStreamContextCancel(t *testing.T) { + t.Parallel() + + data := generateSemiRandomData(10 * megabyte) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(10 * time.Millisecond) + cancel() + }() + + up := &memPartUploader{} + cfg := defaultCfg(CompressionZstd, 4, 2*megabyte) + + _, _, err := compressStream(ctx, bytes.NewReader(data), cfg, up, 4) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) +} + +func TestCompressStreamPartSizeMinimum(t *testing.T) { + t.Parallel() + + // Generate once; subtests slice to their needed size. + sharedData := generateSemiRandomData(100 * megabyte) + + tests := []struct { + name string + dataSize int + frameSize int + minPartSizeMB int + }{ + {"large_file", 100 * megabyte, 2 * megabyte, 50}, + {"small_file_one_part", 5 * megabyte, 2 * megabyte, 50}, + {"small_target", 100 * megabyte, 2 * megabyte, 10}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + data := sharedData[:tc.dataSize] + up := &memPartUploader{} + cfg := defaultCfg(CompressionZstd, 4, tc.frameSize) + cfg.MinPartSizeMB = tc.minPartSizeMB + + _, _, err := compressStream(context.Background(), bytes.NewReader(data), cfg, up, 4) + require.NoError(t, err) + + // Verify: no non-final part is under 5 MiB. 
+ keys := make([]int, 0, len(up.parts)) + for k := range up.parts { + keys = append(keys, k) + } + slices.Sort(keys) + + for i, k := range keys { + isFinal := i == len(keys)-1 + if !isFinal { + require.GreaterOrEqual(t, len(up.parts[k]), 5*1024*1024, + "non-final part %d is under 5 MiB (%d bytes)", k, len(up.parts[k])) + } + } + + require.NotEmpty(t, up.parts, "should have at least one part") + }) + } +} + +// TestCompressStreamRace runs many concurrent CompressStream calls with high +// worker counts to shake out data races in the compressor pool, memPartUploader, +// and errgroup coordination. Run with -race. +func TestCompressStreamRace(t *testing.T) { + t.Parallel() + + const ( + streams = 8 // concurrent CompressStream calls + dataSize = 4 * megabyte // small enough to be fast, big enough to exercise batching + frameSize = 128 * 1024 // 128 KB — many frames per part + workers = 8 // high worker count to maximise contention + minPartSizeMB = 1 // small parts → many parts per stream + ) + + data := generateSemiRandomData(dataSize) + wantChecksum := sha256.Sum256(data) + + // Use an errgroup to run all streams concurrently. 
+ eg, ctx := errgroup.WithContext(context.Background()) + for i := range streams { + codec := CompressionZstd + if i%2 == 1 { + codec = CompressionLZ4 // mix codecs for more coverage + } + + eg.Go(func() error { + up := &memPartUploader{} + cfg := defaultCfg(codec, workers, frameSize) + cfg.MinPartSizeMB = minPartSizeMB + if codec == CompressionZstd { + cfg.EncoderConcurrency = 4 // multi-threaded zstd encoders for more contention + } + + ft, checksum, err := compressStream(ctx, bytes.NewReader(data), cfg, up, 4) + if err != nil { + return fmt.Errorf("stream %d: compress: %w", i, err) + } + + if checksum != wantChecksum { + return fmt.Errorf("stream %d: checksum mismatch", i) + } + + decompressed, err := decompressAll(ft, up.Assemble()) + if err != nil { + return fmt.Errorf("stream %d: decompress: %w", i, err) + } + + if !bytes.Equal(data, decompressed) { + return fmt.Errorf("stream %d: round-trip data mismatch", i) + } + + return nil + }) + } + + require.NoError(t, eg.Wait()) +} + +func BenchmarkCompress(b *testing.B) { + const dataSize = 256 * megabyte + data := generateSemiRandomData(dataSize) + + configs := []struct { + name string + workers int + bandwidth int64 // bytes/sec; 0 = unlimited + }{ + {"w1_unlimited", 1, 0}, + {"w2_unlimited", 2, 0}, + {"w4_unlimited", 4, 0}, + {"w1_200MBs", 1, 200 * megabyte}, + {"w4_200MBs", 4, 200 * megabyte}, + {"w4_100MBs", 4, 100 * megabyte}, + } + + for _, bcfg := range configs { + b.Run(bcfg.name, func(b *testing.B) { + compCfg := CompressConfig{ + Enabled: true, + Type: CompressionZstd.String(), + Level: 2, + EncoderConcurrency: 1, + FrameEncodeWorkers: bcfg.workers, + FrameSizeKB: 2 * 1024, + MinPartSizeMB: 50, + } + + var lastParts atomic.Int32 + + b.ResetTimer() + b.SetBytes(int64(dataSize)) + + for range b.N { + up := &ThrottledPartUploader{bandwidth: bcfg.bandwidth} + + _, _, err := compressStream( + context.Background(), + bytes.NewReader(data), + compCfg, + up, 4, + ) + if err != nil { + b.Fatal(err) + } + + 
lastParts.Store(int32(len(up.parts))) + } + + // Report after all iterations using last run's values. + // b.SetBytes already reports MB/s (uncompressed throughput). + b.ReportMetric(float64(lastParts.Load()), "parts") + }) + } +} + +func BenchmarkStoreFile(b *testing.B) { + const dataSize = 1024 * megabyte // 1 GB + + data := generateSemiRandomData(dataSize) + inputDir := b.TempDir() + inputPath := filepath.Join(inputDir, "input.bin") + require.NoError(b, os.WriteFile(inputPath, data, 0o644)) + data = nil //nolint:ineffassign,wastedassign // hint GC to free 1GB before benchmark loop + + codecs := []struct { + name string + codec CompressionType + level int + }{ + {"zstd1", CompressionZstd, 1}, + {"zstd2", CompressionZstd, 2}, + {"zstd3", CompressionZstd, 3}, + {"lz4", CompressionLZ4, 0}, + } + workerCounts := []int{1, 2, 4, 8} + + for _, codec := range codecs { + for _, workers := range workerCounts { + name := fmt.Sprintf("%s/w%d", codec.name, workers) + b.Run(name, func(b *testing.B) { + compCfg := CompressConfig{ + Enabled: true, + Type: codec.codec.String(), + Level: codec.level, + EncoderConcurrency: 1, + FrameEncodeWorkers: workers, + FrameSizeKB: 2 * 1024, + MinPartSizeMB: 50, + } + + b.SetBytes(int64(dataSize)) + b.ResetTimer() + + for range b.N { + outDir := b.TempDir() + outPath := filepath.Join(outDir, "output.dat") + obj := &fsObject{path: outPath} + + ft, _, err := obj.StoreFile(b.Context(), inputPath, WithCompressConfig(compCfg)) + if err != nil { + b.Fatal(err) + } + + b.ReportMetric(float64(ft.CompressedSize())/float64(ft.UncompressedSize()), "ratio") + } + }) + } + } + + b.Run("uncompressed", func(b *testing.B) { + b.SetBytes(int64(dataSize)) + b.ResetTimer() + + for range b.N { + outDir := b.TempDir() + outPath := filepath.Join(outDir, "output.dat") + + in, err := os.Open(inputPath) + if err != nil { + b.Fatal(err) + } + out, err := os.Create(outPath) + if err != nil { + in.Close() + b.Fatal(err) + } + if _, err := io.Copy(out, in); err != nil { 
+ in.Close() + out.Close() + b.Fatal(err) + } + in.Close() + out.Close() + } + }) +} diff --git a/packages/shared/pkg/storage/gcp_multipart.go b/packages/shared/pkg/storage/gcp_multipart.go index 94315a76fb..918b0501a0 100644 --- a/packages/shared/pkg/storage/gcp_multipart.go +++ b/packages/shared/pkg/storage/gcp_multipart.go @@ -140,6 +140,57 @@ type MultipartUploader struct { retryConfig RetryConfig metadata ObjectMetadata baseURL string // Allow overriding for testing + + // Fields for partUploader interface + uploadID string + mu sync.Mutex + parts []Part +} + +var _ partUploader = (*MultipartUploader)(nil) + +// Start initiates the GCS multipart upload. +func (m *MultipartUploader) Start(ctx context.Context) error { + uploadID, err := m.initiateUpload(ctx) + if err != nil { + return fmt.Errorf("failed to initiate multipart upload: %w", err) + } + + m.uploadID = uploadID + + return nil +} + +// UploadPart uploads a single part to GCS. Multiple data slices are hashed +// and uploaded without copying into a single contiguous buffer. +func (m *MultipartUploader) UploadPart(ctx context.Context, partIndex int, data ...[]byte) error { + etag, err := m.uploadPartSlices(ctx, m.uploadID, partIndex, data) + if err != nil { + return fmt.Errorf("failed to upload part %d: %w", partIndex, err) + } + + m.mu.Lock() + m.parts = append(m.parts, Part{ + PartNumber: partIndex, + ETag: etag, + }) + m.mu.Unlock() + + return nil +} + +// Complete finalizes the GCS multipart upload with all collected parts. 
+func (m *MultipartUploader) Complete(ctx context.Context) error { + m.mu.Lock() + parts := make([]Part, len(m.parts)) + copy(parts, m.parts) + m.mu.Unlock() + + return m.completeUpload(ctx, m.uploadID, parts) +} + +func (m *MultipartUploader) Close() error { + return nil } func NewMultipartUploaderWithRetryConfig(ctx context.Context, bucketName, objectName string, retryConfig RetryConfig, metadata ObjectMetadata) (*MultipartUploader, error) { @@ -238,6 +289,60 @@ func (m *MultipartUploader) uploadPart(ctx context.Context, uploadID string, par return etag, nil } +// uploadPartSlices uploads a part from multiple byte slices without concatenating them. +// It computes MD5 by hashing each slice and uses a ReaderFunc for retryable reads. +func (m *MultipartUploader) uploadPartSlices(ctx context.Context, uploadID string, partNumber int, slices [][]byte) (string, error) { + // Compute MD5 and total length without copying + hasher := md5.New() + totalLen := 0 + for _, s := range slices { + hasher.Write(s) + totalLen += len(s) + } + md5Sum := base64.StdEncoding.EncodeToString(hasher.Sum(nil)) + + url := fmt.Sprintf("%s/%s?partNumber=%d&uploadId=%s", + m.baseURL, m.objectName, partNumber, uploadID) + + // Use a ReaderFunc so the retryable client can replay the body on retries + bodyFn := func() (io.Reader, error) { + readers := make([]io.Reader, len(slices)) + for i, s := range slices { + readers[i] = bytes.NewReader(s) + } + + return io.MultiReader(readers...), nil + } + + req, err := retryablehttp.NewRequestWithContext(ctx, "PUT", url, retryablehttp.ReaderFunc(bodyFn)) + if err != nil { + return "", err + } + + req.Header.Set("Authorization", "Bearer "+m.token) + req.Header.Set("Content-Length", fmt.Sprintf("%d", totalLen)) + req.Header.Set("Content-MD5", md5Sum) + + resp, err := m.client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + + return "", fmt.Errorf("failed to 
upload part %d (status %d): %s", partNumber, resp.StatusCode, string(body)) + } + + etag := resp.Header.Get("ETag") + if etag == "" { + return "", fmt.Errorf("no ETag returned for part %d", partNumber) + } + + return etag, nil +} + func (m *MultipartUploader) completeUpload(ctx context.Context, uploadID string, parts []Part) error { // Sort parts by part number sort.Slice(parts, func(i, j int) bool { diff --git a/packages/shared/pkg/storage/gcp_multipart_test.go b/packages/shared/pkg/storage/gcp_multipart_test.go index 7fe4d397ce..49ed8bbc51 100644 --- a/packages/shared/pkg/storage/gcp_multipart_test.go +++ b/packages/shared/pkg/storage/gcp_multipart_test.go @@ -1,6 +1,8 @@ package storage import ( + "crypto/md5" + "encoding/base64" "encoding/xml" "fmt" "io" @@ -115,6 +117,42 @@ func TestMultipartUploader_UploadPart_Success(t *testing.T) { require.Equal(t, expectedETag, etag) } +func TestMultipartUploader_UploadPartSlices_Success(t *testing.T) { + t.Parallel() + expectedETag := `"slice-etag"` + slices := [][]byte{[]byte("hello "), []byte("world"), []byte("!")} + + // Compute expected MD5 over all slices. + h := md5.New() + for _, s := range slices { + h.Write(s) + } + expectedMD5 := base64.StdEncoding.EncodeToString(h.Sum(nil)) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "PUT", r.Method) + assert.Contains(t, r.URL.RawQuery, "partNumber=3") + assert.Contains(t, r.URL.RawQuery, "uploadId=test-upload-id") + + // Verify MD5 matches the expected hash of all slices. + assert.Equal(t, expectedMD5, r.Header.Get("Content-MD5")) + + // Verify body is the concatenation of all slices. 
+ body, err := io.ReadAll(r.Body) + assert.NoError(t, err) + assert.Equal(t, []byte("hello world!"), body) + + w.Header().Set("ETag", expectedETag) + w.WriteHeader(http.StatusOK) + }) + + uploader := createTestMultipartUploader(t, handler) + etag, err := uploader.uploadPartSlices(t.Context(), "test-upload-id", 3, slices) + + require.NoError(t, err) + require.Equal(t, expectedETag, etag) +} + func TestMultipartUploader_UploadPart_MissingETag(t *testing.T) { t.Parallel() handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -170,7 +208,6 @@ func TestMultipartUploader_UploadFileInParallel_Success(t *testing.T) { err := os.WriteFile(testFile, []byte(testContent), 0o644) require.NoError(t, err) - var uploadID string var initiateCount, uploadPartCount, completeCount int32 receivedParts := sync.Map{} @@ -179,11 +216,10 @@ func TestMultipartUploader_UploadFileInParallel_Success(t *testing.T) { case r.URL.RawQuery == uploadsPath: // Initiate upload atomic.AddInt32(&initiateCount, 1) - uploadID = "test-upload-id-123" response := InitiateMultipartUploadResult{ Bucket: testBucketName, Key: testObjectName, - UploadID: uploadID, + UploadID: "test-upload-id-123", } xmlData, _ := xml.Marshal(response) w.Header().Set("Content-Type", "application/xml") @@ -524,7 +560,7 @@ func TestMultipartUploader_EdgeCases_VerySmallFile(t *testing.T) { err := os.WriteFile(smallFile, []byte(smallContent), 0o644) require.NoError(t, err) - var receivedData string + var receivedParts sync.Map handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { @@ -540,7 +576,8 @@ func TestMultipartUploader_EdgeCases_VerySmallFile(t *testing.T) { case strings.Contains(r.URL.RawQuery, "partNumber"): body, _ := io.ReadAll(r.Body) - receivedData = string(body) + partNum := r.URL.Query().Get("partNumber") + receivedParts.Store(partNum, string(body)) w.Header().Set("ETag", `"small-etag"`) w.WriteHeader(http.StatusOK) @@ -553,7 +590,18 @@ func 
TestMultipartUploader_EdgeCases_VerySmallFile(t *testing.T) { uploader := createTestMultipartUploader(t, handler) _, err = uploader.UploadFileInParallel(t.Context(), smallFile, 10) // High concurrency for small file require.NoError(t, err) - require.Equal(t, smallContent, receivedData) + + // Small file should produce exactly one part + var partCount int + receivedParts.Range(func(_, _ any) bool { + partCount++ + + return true + }) + require.Equal(t, 1, partCount) + data, ok := receivedParts.Load("1") + require.True(t, ok) + require.Equal(t, smallContent, data.(string)) } type repeatReader struct { @@ -692,8 +740,9 @@ func TestMultipartUploader_BoundaryConditions_ExactChunkSize(t *testing.T) { // Should have exactly 2 parts, each of ChunkSize require.Len(t, partSizes, 2) - require.Equal(t, gcpMultipartUploadChunkSize, partSizes[0]) - require.Equal(t, gcpMultipartUploadChunkSize, partSizes[1]) + for _, size := range partSizes { + require.Equal(t, gcpMultipartUploadChunkSize, size) + } } func TestMultipartUploader_FileNotFound_Error(t *testing.T) { diff --git a/packages/shared/pkg/storage/header/header.go b/packages/shared/pkg/storage/header/header.go index 32067b7887..ecc30e39ca 100644 --- a/packages/shared/pkg/storage/header/header.go +++ b/packages/shared/pkg/storage/header/header.go @@ -1,22 +1,76 @@ package header import ( + "cmp" "context" "errors" "fmt" + "maps" + "slices" "sort" "github.com/google/uuid" "go.uber.org/zap" "github.com/e2b-dev/infra/packages/shared/pkg/logger" + "github.com/e2b-dev/infra/packages/shared/pkg/storage" ) +// BuildData holds per-build metadata stored in V4 headers. +// Each layer's header carries a Builds map; child headers inherit parent +// entries for still-referenced build IDs via newDiffHeader. 
+type BuildData struct { + Size int64 // uncompressed file size + Checksum [32]byte // SHA-256 of uncompressed data; zero value means unknown + FrameData *storage.FrameTable // nil for uncompressed builds +} + const NormalizeFixVersion = 3 type Header struct { Metadata *Metadata - Mapping []BuildMap + // Builds maps build IDs to per-build metadata (size, checksum, FrameTable). + // nil for V3 (uncompressed) headers; the read path falls back to a Size() + // RPC and reads uncompressed data when nil. + Builds map[uuid.UUID]BuildData + + Mapping []BuildMap + + // IncompletePendingUpload is set on diff headers produced by ToDiffHeader and + // cleared on the finalized headers swapped in by the upload pipeline. It + // is in-memory only (never serialized), and signals that the build's data + // has not yet reached object storage — readers must serve from the local + // cache and skip FrameTable lookups for the still-missing self entry. + IncompletePendingUpload bool +} + +// CloneForUpload returns a clone with copied Mapping and Builds, safe to +// mutate for serialization without racing with concurrent readers of the +// original. The version is set on the clone. +func (t *Header) CloneForUpload(version uint64) *Header { + metaCopy := *t.Metadata + metaCopy.Version = version + + clone := &Header{ + Metadata: &metaCopy, + Mapping: slices.Clone(t.Mapping), + } + + if t.Builds != nil { + clone.Builds = make(map[uuid.UUID]BuildData, len(t.Builds)) + maps.Copy(clone.Builds, t.Builds) + } + + return clone +} + +// SetBuild adds or replaces build metadata for the given build ID. 
+func (t *Header) SetBuild(buildID uuid.UUID, bd BuildData) { + if t.Builds == nil { + t.Builds = make(map[uuid.UUID]BuildData) + } + + t.Builds[buildID] = bd } func NewHeader(metadata *Metadata, mapping []BuildMap) (*Header, error) { @@ -39,27 +93,74 @@ func NewHeader(metadata *Metadata, mapping []BuildMap) (*Header, error) { }, nil } +func newDiffHeader(metadata *Metadata, mapping []BuildMap, sourceBuilds map[uuid.UUID]BuildData) (*Header, error) { + h, err := NewHeader(metadata, mapping) + if err != nil { + return nil, err + } + + if sourceBuilds != nil { + referenced := make(map[uuid.UUID]struct{}, len(h.Mapping)) + for _, m := range h.Mapping { + referenced[m.BuildId] = struct{}{} + } + + h.Builds = make(map[uuid.UUID]BuildData, len(referenced)) + for id := range referenced { + if bd, ok := sourceBuilds[id]; ok { + h.Builds[id] = bd + } + } + } + + h.IncompletePendingUpload = true + + return h, nil +} + +func (t *Header) String() string { + if t == nil { + return "[nil Header]" + } + + return fmt.Sprintf("[Header: version=%d, size=%d, blockSize=%d, generation=%d, buildId=%s, mappings=%d]", + t.Metadata.Version, + t.Metadata.Size, + t.Metadata.BlockSize, + t.Metadata.Generation, + t.Metadata.BuildId.String(), + len(t.Mapping), + ) +} + // IsNormalizeFixApplied is a helper method to soft fail for older versions of the header where fix for normalization was not applied. // This should be removed in the future. func (t *Header) IsNormalizeFixApplied() bool { return t.Metadata.Version >= NormalizeFixVersion } -func (t *Header) GetShiftedMapping(ctx context.Context, offset int64) (mappedOffset int64, mappedLength int64, buildID *uuid.UUID, err error) { +// GetShiftedMapping resolves a virtual offset to a build-local range. +// The read path uses this to find which build owns the data, then calls +// GetBuildFrameData to get the FrameTable for C-space lookup. 
+func (t *Header) GetShiftedMapping(ctx context.Context, offset int64) (BuildMap, error) { mapping, shift, err := t.getMapping(ctx, offset) if err != nil { - return 0, 0, nil, err + return BuildMap{}, err } + mappedLength := int64(mapping.Length) - shift - mappedOffset = int64(mapping.BuildStorageOffset) + shift - mappedLength = int64(mapping.Length) - shift - buildID = &mapping.BuildId + b := BuildMap{ + Offset: mapping.BuildStorageOffset + uint64(shift), + Length: uint64(mappedLength), + BuildId: mapping.BuildId, + } if mappedLength < 0 { if t.IsNormalizeFixApplied() { - return 0, 0, nil, fmt.Errorf("mapped length for offset %d is negative: %d", offset, mappedLength) + return BuildMap{}, fmt.Errorf("mapped length for offset %d is negative: %d", offset, mappedLength) } + b.Length = 0 logger.L().Warn(ctx, "mapped length is negative, but normalize fix is not applied", zap.Int64("offset", offset), zap.Int64("mappedLength", mappedLength), @@ -67,7 +168,17 @@ func (t *Header) GetShiftedMapping(ctx context.Context, offset int64) (mappedOff ) } - return mappedOffset, mappedLength, buildID, nil + return b, nil +} + +// GetBuildFrameData returns the FrameTable for a build, or nil. +// nil means the build is uncompressed — the caller reads raw bytes instead. +func (t *Header) GetBuildFrameData(buildID uuid.UUID) *storage.FrameTable { + if t.Builds == nil { + return nil + } + + return t.Builds[buildID].FrameData } func (t *Header) getMapping(ctx context.Context, offset int64) (*BuildMap, int64, error) { @@ -122,3 +233,81 @@ func (t *Header) getMapping(ctx context.Context, offset int64) (*BuildMap, int64 return mapping, shift, nil } + +// ValidateHeader checks header integrity and returns an error if corruption is detected. +// This verifies: +// 1. Header and metadata are valid +// 2. Mappings cover the entire file [0, Size) with no gaps +// 3. 
Mappings don't extend beyond file size (with block alignment tolerance) +func ValidateHeader(h *Header) error { + if h == nil { + return errors.New("header is nil") + } + if h.Metadata == nil { + return errors.New("header metadata is nil") + } + if h.Metadata.BlockSize == 0 { + return errors.New("header has zero block size") + } + if h.Metadata.Size == 0 { + return errors.New("header has zero size") + } + if len(h.Mapping) == 0 { + return errors.New("header has no mappings") + } + + // Sort mappings by offset to check for gaps/overlaps + sortedMappings := slices.Clone(h.Mapping) + slices.SortFunc(sortedMappings, func(a, b BuildMap) int { + return cmp.Compare(a.Offset, b.Offset) + }) + + // Check that first mapping starts at 0 + if sortedMappings[0].Offset != 0 { + return fmt.Errorf("mappings don't start at 0: first mapping starts at %d for buildId %s", + sortedMappings[0].Offset, h.Metadata.BuildId.String()) + } + + // Check for gaps and overlaps between consecutive mappings + for i := range len(sortedMappings) - 1 { + currentEnd := sortedMappings[i].Offset + sortedMappings[i].Length + nextStart := sortedMappings[i+1].Offset + + if currentEnd < nextStart { + return fmt.Errorf("gap in mappings: mapping[%d] ends at %d but mapping[%d] starts at %d (gap=%d bytes) for buildId %s", + i, currentEnd, i+1, nextStart, nextStart-currentEnd, h.Metadata.BuildId.String()) + } + if currentEnd > nextStart { + return fmt.Errorf("overlap in mappings: mapping[%d] ends at %d but mapping[%d] starts at %d (overlap=%d bytes) for buildId %s", + i, currentEnd, i+1, nextStart, currentEnd-nextStart, h.Metadata.BuildId.String()) + } + } + + // Check that last mapping covers up to (at least) Size + lastMapping := sortedMappings[len(sortedMappings)-1] + lastEnd := lastMapping.Offset + lastMapping.Length + if lastEnd < h.Metadata.Size { + return fmt.Errorf("mappings don't cover entire file: last mapping ends at %d but file size is %d (missing %d bytes) for buildId %s", + lastEnd, 
h.Metadata.Size, h.Metadata.Size-lastEnd, h.Metadata.BuildId.String()) + } + + // Allow last mapping to extend up to one block past size (for alignment) + if lastEnd > h.Metadata.Size+h.Metadata.BlockSize { + return fmt.Errorf("last mapping extends too far: ends at %d but file size is %d (overhang=%d bytes, max allowed=%d) for buildId %s", + lastEnd, h.Metadata.Size, lastEnd-h.Metadata.Size, h.Metadata.BlockSize, h.Metadata.BuildId.String()) + } + + // Validate individual mapping bounds + for i, m := range h.Mapping { + if m.Offset > h.Metadata.Size { + return fmt.Errorf("mapping[%d] has Offset %d beyond header size %d for buildId %s", + i, m.Offset, h.Metadata.Size, m.BuildId.String()) + } + if m.Length == 0 { + return fmt.Errorf("mapping[%d] has zero length at offset %d for buildId %s", + i, m.Offset, m.BuildId.String()) + } + } + + return nil +} diff --git a/packages/shared/pkg/storage/header/mapping_test.go b/packages/shared/pkg/storage/header/mapping_test.go index 2ca2eab247..c32d9a1dbc 100644 --- a/packages/shared/pkg/storage/header/mapping_test.go +++ b/packages/shared/pkg/storage/header/mapping_test.go @@ -704,3 +704,181 @@ func TestNormalizeMappingsDoesNotModifyInput(t *testing.T) { err := ValidateMappings(m, 6*blockSize, blockSize) require.NoError(t, err) } + +// TestMergeMappings_Splits verifies that MergeMappings preserves +// BuildStorageOffset through splits. When a diff lands in the middle of a +// base mapping the base is split into left/right pieces; each piece must +// keep the correct BuildStorageOffset so the read path fetches data from +// the right position within each build's data blob. Without this, +// compressed builds whose frame tables are keyed by BuildStorageOffset +// would decompress the wrong frames. 
+func TestMergeMappings_Splits(t *testing.T) { + t.Parallel() + + compBaseID := uuid.New() + compDiffID := uuid.New() + plainID := uuid.New() + + tests := map[string]struct { + base []BuildMap + diff []BuildMap + validate func(t *testing.T, merged []BuildMap) + }{ + "diff inside base — left and right split correctly": { + base: []BuildMap{{ + Offset: 0, Length: 6 * blockSize, + BuildId: compBaseID, BuildStorageOffset: 0, + }}, + diff: []BuildMap{{ + Offset: 2 * blockSize, Length: 2 * blockSize, + BuildId: compDiffID, + }}, + validate: func(t *testing.T, m []BuildMap) { + t.Helper() + require.Len(t, m, 3) + + assert.Equal(t, uint64(0), m[0].Offset) + assert.Equal(t, 2*blockSize, m[0].Length) + assert.Equal(t, compBaseID, m[0].BuildId) + + assert.Equal(t, compDiffID, m[1].BuildId) + + assert.Equal(t, 4*blockSize, m[2].Offset) + assert.Equal(t, 4*blockSize, m[2].BuildStorageOffset) + assert.Equal(t, compBaseID, m[2].BuildId) + }, + }, + + "base after diff with overlap — right split keeps tail": { + base: []BuildMap{ + {Offset: 0, Length: 1 * blockSize, BuildId: plainID}, + { + Offset: 1 * blockSize, Length: 4 * blockSize, + BuildId: compBaseID, BuildStorageOffset: 0, + }, + }, + diff: []BuildMap{{ + Offset: 0, Length: 3 * blockSize, + BuildId: compDiffID, + }}, + validate: func(t *testing.T, m []BuildMap) { + t.Helper() + require.Len(t, m, 2) + + assert.Equal(t, compDiffID, m[0].BuildId) + + assert.Equal(t, 3*blockSize, m[1].Offset) + assert.Equal(t, 2*blockSize, m[1].BuildStorageOffset) + assert.Equal(t, compBaseID, m[1].BuildId) + }, + }, + + "diff after base with overlap — left split keeps head": { + base: []BuildMap{ + { + Offset: 0, Length: 4 * blockSize, + BuildId: compBaseID, BuildStorageOffset: 0, + }, + {Offset: 4 * blockSize, Length: 2 * blockSize, BuildId: plainID}, + }, + diff: []BuildMap{{ + Offset: 2 * blockSize, Length: 4 * blockSize, + BuildId: compDiffID, + }}, + validate: func(t *testing.T, m []BuildMap) { + t.Helper() + require.Len(t, m, 2) + + 
assert.Equal(t, uint64(0), m[0].Offset) + assert.Equal(t, 2*blockSize, m[0].Length) + assert.Equal(t, compBaseID, m[0].BuildId) + + assert.Equal(t, compDiffID, m[1].BuildId) + }, + }, + + "two diffs split same base into three pieces": { + base: []BuildMap{{ + Offset: 0, Length: 6 * blockSize, + BuildId: compBaseID, BuildStorageOffset: 0, + }}, + diff: []BuildMap{ + {Offset: 1 * blockSize, Length: 1 * blockSize, BuildId: compDiffID}, + {Offset: 4 * blockSize, Length: 1 * blockSize, BuildId: compDiffID}, + }, + validate: func(t *testing.T, m []BuildMap) { + t.Helper() + require.Len(t, m, 5) + + assert.Equal(t, compBaseID, m[0].BuildId) + assert.Equal(t, 1*blockSize, m[0].Length) + + assert.Equal(t, compDiffID, m[1].BuildId) + + assert.Equal(t, compBaseID, m[2].BuildId) + assert.Equal(t, 2*blockSize, m[2].Length) + assert.Equal(t, 2*blockSize, m[2].BuildStorageOffset) + + assert.Equal(t, compDiffID, m[3].BuildId) + + assert.Equal(t, compBaseID, m[4].BuildId) + assert.Equal(t, 1*blockSize, m[4].Length) + assert.Equal(t, 5*blockSize, m[4].BuildStorageOffset) + }, + }, + + "multi-layer base — diff splits middle build": { + base: func() []BuildMap { + buildA := uuid.New() + buildB := compBaseID + buildC := uuid.New() + + return []BuildMap{ + { + Offset: 0, Length: 2 * blockSize, + BuildId: buildA, BuildStorageOffset: 0, + }, + { + Offset: 2 * blockSize, Length: 4 * blockSize, + BuildId: buildB, BuildStorageOffset: 0, + }, + { + Offset: 6 * blockSize, Length: 2 * blockSize, + BuildId: buildC, BuildStorageOffset: 0, + }, + } + }(), + diff: []BuildMap{{ + Offset: 3 * blockSize, Length: 2 * blockSize, + BuildId: compDiffID, + }}, + validate: func(t *testing.T, m []BuildMap) { + t.Helper() + require.Len(t, m, 5) + + // Build B left + assert.Equal(t, 2*blockSize, m[1].Offset) + assert.Equal(t, 1*blockSize, m[1].Length) + assert.Equal(t, uint64(0), m[1].BuildStorageOffset) + + // Diff + assert.Equal(t, compDiffID, m[2].BuildId) + + // Build B right + assert.Equal(t, 5*blockSize, 
m[3].Offset) + assert.Equal(t, 1*blockSize, m[3].Length) + assert.Equal(t, 3*blockSize, m[3].BuildStorageOffset) + }, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + merged := MergeMappings(tc.base, tc.diff) + + tc.validate(t, merged) + }) + } +} diff --git a/packages/shared/pkg/storage/header/metadata.go b/packages/shared/pkg/storage/header/metadata.go index 5bba169396..4570fbf2f3 100644 --- a/packages/shared/pkg/storage/header/metadata.go +++ b/packages/shared/pkg/storage/header/metadata.go @@ -1,7 +1,9 @@ package header import ( + "bytes" "context" + "encoding/binary" "fmt" "io" @@ -15,6 +17,59 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/telemetry" ) +const ( + // metadataVersion is used by template-manager for uncompressed builds (V3 headers). + metadataVersion = 3 + // MetadataVersionV4 is used for compressed builds (V4 headers with FrameTables). + MetadataVersionV4 = 4 +) + +type Metadata struct { + Version uint64 + BlockSize uint64 + Size uint64 + Generation uint64 + BuildId uuid.UUID + // TODO: Use the base build id when setting up the snapshot rootfs + BaseBuildId uuid.UUID +} + +func NewTemplateMetadata(buildId uuid.UUID, blockSize, size uint64) *Metadata { + return &Metadata{ + Version: metadataVersion, + Generation: 0, + BlockSize: blockSize, + Size: size, + BuildId: buildId, + BaseBuildId: buildId, + } +} + +func (m *Metadata) NextGeneration(buildID uuid.UUID) *Metadata { + return &Metadata{ + Version: m.Version, + Generation: m.Generation + 1, + BlockSize: m.BlockSize, + Size: m.Size, + BuildId: buildID, + BaseBuildId: m.BaseBuildId, + } +} + +// metadataSize is the binary size of the Metadata struct, computed from the struct layout. 
+var metadataSize = binary.Size(Metadata{}) + +func deserializeMetadata(data []byte) (*Metadata, error) { + var metadata Metadata + + err := binary.Read(bytes.NewReader(data), binary.LittleEndian, &metadata) + if err != nil { + return nil, fmt.Errorf("failed to read metadata: %w", err) + } + + return &metadata, nil +} + var ignoreBuildID = uuid.Nil type DiffMetadata struct { @@ -96,7 +151,7 @@ func (d *DiffMetadata) ToDiffHeader( attribute.String("snapshot.metadata.base_build_id", metadata.BaseBuildId.String()), ) - header, err := NewHeader(metadata, m) + header, err := newDiffHeader(metadata, m, originalHeader.Builds) if err != nil { return nil, fmt.Errorf("failed to create header: %w", err) } diff --git a/packages/shared/pkg/storage/header/serialization.go b/packages/shared/pkg/storage/header/serialization.go index 724d72e6e9..c9c86fe0b9 100644 --- a/packages/shared/pkg/storage/header/serialization.go +++ b/packages/shared/pkg/storage/header/serialization.go @@ -1,102 +1,88 @@ package header import ( - "bytes" "context" - "encoding/binary" - "errors" "fmt" - "io" - - "github.com/google/uuid" "github.com/e2b-dev/infra/packages/shared/pkg/storage" ) -const metadataVersion = 3 - -type Metadata struct { - Version uint64 - BlockSize uint64 - Size uint64 - Generation uint64 - BuildId uuid.UUID - // TODO: Use the base build id when setting up the snapshot rootfs - BaseBuildId uuid.UUID -} - -func NewTemplateMetadata(buildId uuid.UUID, blockSize, size uint64) *Metadata { - return &Metadata{ - Version: metadataVersion, - Generation: 0, - BlockSize: blockSize, - Size: size, - BuildId: buildId, - BaseBuildId: buildId, +// SerializeHeader serializes a header, dispatching to the version-specific format. 
+// +// V3 (Version <= 3): [Metadata] [v3 mappings…] +// V4 (Version >= 4): [Metadata] [uint8 flags] [uint32 uncompressedSize] [LZ4( Builds + v4 mappings )] +func SerializeHeader(h *Header) ([]byte, error) { + if h.Metadata.Version <= 3 { + return serializeV3(h.Metadata, h.Mapping) } -} -func (m *Metadata) NextGeneration(buildID uuid.UUID) *Metadata { - return &Metadata{ - Version: m.Version, - Generation: m.Generation + 1, - BlockSize: m.BlockSize, - Size: m.Size, - BuildId: buildID, - BaseBuildId: m.BaseBuildId, - } + return serializeV4(h.Metadata, h.Builds, h.Mapping, h.IncompletePendingUpload) } -func Serialize(metadata *Metadata, mappings []BuildMap) ([]byte, error) { - var buf bytes.Buffer +// DeserializeBytes auto-detects the header version and deserializes accordingly. +// See SerializeHeader for the binary layout. +func DeserializeBytes(data []byte) (*Header, error) { + if len(data) < metadataSize { + return nil, fmt.Errorf("header too short: %d bytes", len(data)) + } - err := binary.Write(&buf, binary.LittleEndian, metadata) + metadata, err := deserializeMetadata(data[:metadataSize]) if err != nil { - return nil, fmt.Errorf("failed to write metadata: %w", err) + return nil, err } - for i := range mappings { - err := binary.Write(&buf, binary.LittleEndian, &mappings[i]) - if err != nil { - return nil, fmt.Errorf("failed to write block mapping: %w", err) - } + blockData := data[metadataSize:] + + if metadata.Version >= 4 { + return deserializeV4(metadata, blockData) } - return buf.Bytes(), nil + return deserializeV3(metadata, blockData) } -func Deserialize(ctx context.Context, in storage.Blob) (*Header, error) { - data, err := storage.GetBlob(ctx, in) +// LoadHeader fetches a serialized header from storage and deserializes it. +// Errors (including storage.ErrObjectNotExist) are returned as-is. 
+func LoadHeader(ctx context.Context, s storage.StorageProvider, path string) (*Header, error) { + blob, err := s.OpenBlob(ctx, path, storage.MetadataObjectType) if err != nil { - return nil, fmt.Errorf("failed to write to buffer: %w", err) + return nil, fmt.Errorf("open blob %s: %w", path, err) + } + + data, err := storage.GetBlob(ctx, blob) + if err != nil { + return nil, err } return DeserializeBytes(data) } -func DeserializeBytes(data []byte) (*Header, error) { - reader := bytes.NewReader(data) - var metadata Metadata - err := binary.Read(reader, binary.LittleEndian, &metadata) - if err != nil { - return nil, fmt.Errorf("failed to read metadata: %w", err) +// StoreHeader serializes a header and uploads it to long-term storage. +// Refuses to persist a header still flagged as in-flight — the upload pipeline +// must clear IncompletePendingUpload before reaching here. +func StoreHeader(ctx context.Context, s storage.StorageProvider, path string, h *Header) error { + if h.IncompletePendingUpload { + return fmt.Errorf("refusing to persist incomplete header for %s", path) } - mappings := make([]BuildMap, 0) + data, err := SerializeHeader(h) + if err != nil { + return fmt.Errorf("serialize header: %w", err) + } - for { - var m BuildMap - err := binary.Read(reader, binary.LittleEndian, &m) - if errors.Is(err, io.EOF) { - break - } + blob, err := s.OpenBlob(ctx, path, storage.MetadataObjectType) + if err != nil { + return fmt.Errorf("open blob %s: %w", path, err) + } - if err != nil { - return nil, fmt.Errorf("failed to read block mapping: %w", err) - } + return blob.Put(ctx, data) +} - mappings = append(mappings, m) +// Deserialize reads a header from a storage Blob (legacy API). 
+func Deserialize(ctx context.Context, in storage.Blob) (*Header, error) { + data, err := storage.GetBlob(ctx, in) + if err != nil { + return nil, fmt.Errorf("failed to write to buffer: %w", err) } - return NewHeader(&metadata, mappings) + return DeserializeBytes(data) } diff --git a/packages/shared/pkg/storage/header/serialization_test.go b/packages/shared/pkg/storage/header/serialization_test.go new file mode 100644 index 0000000000..c791e801f8 --- /dev/null +++ b/packages/shared/pkg/storage/header/serialization_test.go @@ -0,0 +1,696 @@ +package header + +import ( + "crypto/sha256" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage" +) + +func TestSerializeDeserialize_V3_RoundTrip(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + baseID := uuid.New() + metadata := &Metadata{ + Version: 3, + BlockSize: 4096, + Size: 8192, + Generation: 7, + BuildId: buildID, + BaseBuildId: baseID, + } + + mappings := []BuildMap{ + { + Offset: 0, + Length: 4096, + BuildId: buildID, + BuildStorageOffset: 0, + }, + { + Offset: 4096, + Length: 4096, + BuildId: baseID, + BuildStorageOffset: 123, + }, + } + + data, err := serializeV3(metadata, mappings) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + require.Equal(t, metadata, got.Metadata) + require.Len(t, got.Mapping, 2) + require.Equal(t, uint64(0), got.Mapping[0].Offset) + require.Equal(t, uint64(4096), got.Mapping[0].Length) + require.Equal(t, buildID, got.Mapping[0].BuildId) + require.Equal(t, uint64(0), got.Mapping[0].BuildStorageOffset) + + require.Equal(t, uint64(4096), got.Mapping[1].Offset) + require.Equal(t, uint64(4096), got.Mapping[1].Length) + require.Equal(t, baseID, got.Mapping[1].BuildId) + require.Equal(t, uint64(123), got.Mapping[1].BuildStorageOffset) + + // V3 headers have no Builds + require.Nil(t, got.Builds) +} + +func TestDeserialize_TruncatedMetadata(t *testing.T) { + 
t.Parallel() + + _, err := DeserializeBytes([]byte{0x01, 0x02, 0x03}) + require.Error(t, err) + require.Contains(t, err.Error(), "header too short") +} + +func TestSerializeDeserialize_EmptyMappings_Defaults(t *testing.T) { + t.Parallel() + + metadata := &Metadata{ + Version: 3, + BlockSize: 4096, + Size: 8192, + Generation: 0, + BuildId: uuid.New(), + BaseBuildId: uuid.New(), + } + + data, err := serializeV3(metadata, nil) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + // NewHeader creates a default mapping when none provided + require.Len(t, got.Mapping, 1) + require.Equal(t, uint64(0), got.Mapping[0].Offset) + require.Equal(t, metadata.Size, got.Mapping[0].Length) + require.Equal(t, metadata.BuildId, got.Mapping[0].BuildId) +} + +func TestDeserialize_BlockSizeZero(t *testing.T) { + t.Parallel() + + metadata := &Metadata{ + Version: 3, + BlockSize: 0, + Size: 4096, + Generation: 0, + BuildId: uuid.New(), + BaseBuildId: uuid.New(), + } + + data, err := serializeV3(metadata, nil) + require.NoError(t, err) + + _, err = DeserializeBytes(data) + require.Error(t, err) + require.Contains(t, err.Error(), "block size cannot be zero") +} + +func TestSerializeDeserialize_V4_WithFrameTable(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + baseID := uuid.New() + metadata := &Metadata{ + Version: 4, + BlockSize: 4096, + Size: 8192, + Generation: 1, + BuildId: buildID, + BaseBuildId: baseID, + } + + mappings := []BuildMap{ + { + Offset: 0, + Length: 4096, + BuildId: buildID, + BuildStorageOffset: 0, + }, + { + Offset: 4096, + Length: 4096, + BuildId: baseID, + BuildStorageOffset: 0, + }, + } + + checksum := sha256.Sum256([]byte("test-data")) + + h, err := NewHeader(metadata, mappings) + require.NoError(t, err) + h.Builds = map[uuid.UUID]BuildData{ + buildID: { + Size: 12345, Checksum: checksum, + FrameData: storage.NewFrameTable(storage.CompressionLZ4, []storage.FrameSize{ + {U: 2048, C: 1024}, + {U: 2048, C: 900}, + }), 
+ }, + baseID: {Size: 67890}, + } + + data, err := SerializeHeader(h) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + require.Equal(t, uint64(4), got.Metadata.Version) + require.Len(t, got.Mapping, 2) + require.Equal(t, buildID, got.Mapping[0].BuildId) + require.Equal(t, baseID, got.Mapping[1].BuildId) + + // Builds round-trip + require.Len(t, got.Builds, 2) + require.Equal(t, int64(12345), got.Builds[buildID].Size) + require.Equal(t, checksum, got.Builds[buildID].Checksum) + require.Equal(t, int64(67890), got.Builds[baseID].Size) + + // Frame data round-trip + fd := got.Builds[buildID].FrameData + require.NotNil(t, fd) + require.Equal(t, storage.CompressionLZ4, fd.CompressionType()) + require.Equal(t, 2, fd.NumFrames()) + + r, err := fd.LocateCompressed(0) + require.NoError(t, err) + require.Equal(t, int64(0), r.Offset) + require.Equal(t, 1024, r.Length) + + r, err = fd.LocateCompressed(2048) + require.NoError(t, err) + require.Equal(t, int64(1024), r.Offset) + require.Equal(t, 900, r.Length) + + // baseID has no frames + require.Nil(t, got.Builds[baseID].FrameData) +} + +func TestSerializeDeserialize_V4_Zstd(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + metadata := &Metadata{ + Version: 4, + BlockSize: 4096, + Size: 4096, + Generation: 0, + BuildId: buildID, + BaseBuildId: buildID, + } + + mappings := []BuildMap{ + { + Offset: 0, + Length: 4096, + BuildId: buildID, + BuildStorageOffset: 8192, + }, + } + + h, err := NewHeader(metadata, mappings) + require.NoError(t, err) + // 3 frames; only the third [8192, 12288) overlaps the mapping. 
+ h.Builds = map[uuid.UUID]BuildData{ + buildID: { + FrameData: storage.NewFrameTable(storage.CompressionZstd, []storage.FrameSize{ + {U: 4096, C: 2000}, + {U: 4096, C: 3000}, + {U: 4096, C: 3500}, + }), + }, + } + + data, err := SerializeHeader(h) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + require.Len(t, got.Mapping, 1) + require.Equal(t, uint64(8192), got.Mapping[0].BuildStorageOffset) + + require.Len(t, got.Builds, 1) + fd := got.Builds[buildID].FrameData + require.NotNil(t, fd) + require.Equal(t, storage.CompressionZstd, fd.CompressionType()) + require.Equal(t, 1, fd.NumFrames()) + + r, err := fd.LocateCompressed(8192) + require.NoError(t, err) + require.Equal(t, int64(2000+3000), r.Offset) + require.Equal(t, 3500, r.Length) +} + +func TestSerializeDeserialize_V4_NoFrames(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + baseID := uuid.New() + metadata := &Metadata{ + Version: 4, + BlockSize: 4096, + Size: 8192, + Generation: 0, + BuildId: buildID, + BaseBuildId: buildID, + } + + mappings := []BuildMap{ + { + Offset: 0, + Length: 4096, + BuildId: buildID, + BuildStorageOffset: 0, + }, + { + Offset: 4096, + Length: 4096, + BuildId: baseID, + BuildStorageOffset: 0, + }, + } + + h, err := NewHeader(metadata, mappings) + require.NoError(t, err) + + data, err := SerializeHeader(h) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + require.Len(t, got.Mapping, 2) + require.Nil(t, got.Builds) +} + +func TestSerializeDeserialize_V4_ManyFrames(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + const numFrames = 1000 + frames := make([]storage.FrameSize, numFrames) + for i := range frames { + frames[i] = storage.FrameSize{U: 4096, C: int32(2000 + i)} + } + + metadata := &Metadata{ + Version: 4, + BlockSize: 4096, + Size: 4096 * numFrames, + Generation: 0, + BuildId: buildID, + BaseBuildId: buildID, + } + + mappings := []BuildMap{ + { + Offset: 0, + Length: 4096 
* numFrames, + BuildId: buildID, + BuildStorageOffset: 0, + }, + } + + h, err := NewHeader(metadata, mappings) + require.NoError(t, err) + h.Builds = map[uuid.UUID]BuildData{ + buildID: {FrameData: storage.NewFrameTable(storage.CompressionLZ4, frames)}, + } + + data, err := SerializeHeader(h) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + require.Len(t, got.Mapping, 1) + require.NotNil(t, got.Builds) + fd := got.Builds[buildID].FrameData + require.NotNil(t, fd) + require.Equal(t, numFrames, fd.NumFrames()) + + r, err := fd.LocateCompressed(0) + require.NoError(t, err) + require.Equal(t, int64(0), r.Offset) + require.Equal(t, 2000, r.Length) + + r, err = fd.LocateCompressed(int64(4096 * (numFrames - 1))) + require.NoError(t, err) + require.Equal(t, 2000+numFrames-1, r.Length) +} + +func TestSerializeDeserialize_V4_NoBuilds(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + metadata := &Metadata{ + Version: 4, + BlockSize: 4096, + Size: 4096, + Generation: 0, + BuildId: buildID, + BaseBuildId: buildID, + } + + mappings := []BuildMap{ + { + Offset: 0, + Length: 4096, + BuildId: buildID, + }, + } + + h, err := NewHeader(metadata, mappings) + require.NoError(t, err) + // No Builds set (nil map) + + data, err := SerializeHeader(h) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + require.Len(t, got.Mapping, 1) + require.Nil(t, got.Builds) +} + +func TestSerializeDeserialize_V4_MultiBuild_LocateCompressed(t *testing.T) { + t.Parallel() + + buildA := uuid.New() + buildB := uuid.New() + + // Build A: 3 frames, each 4096 uncompressed. + // Frame 0: U=[0,4096) C=[0,1000) + // Frame 1: U=[4096,8192) C=[1000,2800) + // Frame 2: U=[8192,12288) C=[2800,5100) + ftA := storage.NewFrameTable(storage.CompressionZstd, []storage.FrameSize{ + {U: 4096, C: 1000}, + {U: 4096, C: 1800}, + {U: 4096, C: 2300}, + }) + + // Build B: 2 frames, each 4096 uncompressed. 
+ // Frame 0: U=[0,4096) C=[0,500) + // Frame 1: U=[4096,8192) C=[500,1700) + ftB := storage.NewFrameTable(storage.CompressionLZ4, []storage.FrameSize{ + {U: 4096, C: 500}, + {U: 4096, C: 1200}, + }) + + // Virtual layout (20480 bytes total): + // [0,4096) → buildA offset 0 + // [4096,12288) → buildB offset 0 (8192 bytes = both frames of B) + // [12288,20480)→ buildA offset 4096 (8192 bytes = frames 1..2 of A) + metadata := &Metadata{ + Version: 4, + BlockSize: 4096, + Size: 20480, + Generation: 2, + BuildId: buildA, + BaseBuildId: buildA, + } + + mappings := []BuildMap{ + {Offset: 0, Length: 4096, BuildId: buildA, BuildStorageOffset: 0}, + {Offset: 4096, Length: 8192, BuildId: buildB, BuildStorageOffset: 0}, + {Offset: 12288, Length: 8192, BuildId: buildA, BuildStorageOffset: 4096}, + } + + checksumA := sha256.Sum256([]byte("build-a")) + checksumB := sha256.Sum256([]byte("build-b")) + + h, err := NewHeader(metadata, mappings) + require.NoError(t, err) + h.Builds = map[uuid.UUID]BuildData{ + buildA: {Size: 12288, Checksum: checksumA, FrameData: ftA}, + buildB: {Size: 8192, Checksum: checksumB, FrameData: ftB}, + } + + data, err := SerializeHeader(h) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + require.Equal(t, uint64(4), got.Metadata.Version) + require.Len(t, got.Mapping, 3) + require.Len(t, got.Builds, 2) + + // Verify checksums round-trip. + require.Equal(t, checksumA, got.Builds[buildA].Checksum) + require.Equal(t, checksumB, got.Builds[buildB].Checksum) + + // --- Build A frame lookups via GetBuildFrameData --- + fdA := got.GetBuildFrameData(buildA) + require.NotNil(t, fdA) + require.Equal(t, storage.CompressionZstd, fdA.CompressionType()) + // All 3 frames should survive trimming: frame 0 referenced by mapping 0, + // frames 1-2 referenced by mapping 2. + require.Equal(t, 3, fdA.NumFrames()) + + // Frame 0 of A: U=0, C offset=0, C length=1000. 
+ r, err := fdA.LocateCompressed(0) + require.NoError(t, err) + require.Equal(t, int64(0), r.Offset) + require.Equal(t, 1000, r.Length) + + // Frame 1 of A: U=4096, C offset=1000, C length=1800. + r, err = fdA.LocateCompressed(4096) + require.NoError(t, err) + require.Equal(t, int64(1000), r.Offset) + require.Equal(t, 1800, r.Length) + + // Frame 2 of A: U=8192, C offset=2800, C length=2300. + r, err = fdA.LocateCompressed(8192) + require.NoError(t, err) + require.Equal(t, int64(2800), r.Offset) + require.Equal(t, 2300, r.Length) + + // --- Build B frame lookups via GetBuildFrameData --- + fdB := got.GetBuildFrameData(buildB) + require.NotNil(t, fdB) + require.Equal(t, storage.CompressionLZ4, fdB.CompressionType()) + require.Equal(t, 2, fdB.NumFrames()) + + // Frame 0 of B: U=0, C offset=0, C length=500. + r, err = fdB.LocateCompressed(0) + require.NoError(t, err) + require.Equal(t, int64(0), r.Offset) + require.Equal(t, 500, r.Length) + + // Frame 1 of B: U=4096, C offset=500, C length=1200. + r, err = fdB.LocateCompressed(4096) + require.NoError(t, err) + require.Equal(t, int64(500), r.Offset) + require.Equal(t, 1200, r.Length) + + // Beyond end of B's frames. + _, err = fdB.LocateCompressed(8192) + require.Error(t, err) +} + +func TestSerializeDeserialize_V4_TrimmedOffsets_Error(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + + // 4 frames, each 4096 uncompressed. + ft := storage.NewFrameTable(storage.CompressionZstd, []storage.FrameSize{ + {U: 4096, C: 2000}, + {U: 4096, C: 3000}, + {U: 4096, C: 3500}, + {U: 4096, C: 1800}, + }) + + metadata := &Metadata{ + Version: 4, + BlockSize: 4096, + Size: 4096, + Generation: 0, + BuildId: buildID, + BaseBuildId: buildID, + } + + // Mapping references only frame 2 (BuildStorageOffset=8192, Length=4096). + // Frames 0, 1, and 3 should be trimmed away. 
+ mappings := []BuildMap{ + { + Offset: 0, + Length: 4096, + BuildId: buildID, + BuildStorageOffset: 8192, + }, + } + + h, err := NewHeader(metadata, mappings) + require.NoError(t, err) + h.Builds = map[uuid.UUID]BuildData{ + buildID: {FrameData: ft}, + } + + data, err := SerializeHeader(h) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + fd := got.Builds[buildID].FrameData + require.NotNil(t, fd) + require.Equal(t, 1, fd.NumFrames(), "only frame 2 should survive trimming") + + // The surviving frame covers U=[8192,12288). Lookup should succeed. + r, err := fd.LocateCompressed(8192) + require.NoError(t, err) + require.Equal(t, int64(5000), r.Offset) + require.Equal(t, 3500, r.Length) + + // Trimmed offsets should return errors. + _, err = fd.LocateCompressed(0) + require.Error(t, err, "frame at offset 0 was trimmed, should error") + + _, err = fd.LocateCompressed(4096) + require.Error(t, err, "frame at offset 4096 was trimmed, should error") + + _, err = fd.LocateCompressed(12288) + require.Error(t, err, "frame at offset 12288 was trimmed, should error") + + // Completely beyond original range. 
+ _, err = fd.LocateCompressed(16384) + require.Error(t, err, "offset beyond all frames should error") +} + +func TestFrameTable_LocateCompressed(t *testing.T) { + t.Parallel() + + fd := storage.NewFrameTable(storage.CompressionZstd, []storage.FrameSize{ + {U: 2048, C: 1024}, + {U: 2048, C: 900}, + {U: 4096, C: 3500}, + }) + + // Frame 0: U=[0,2048), C=[0,1024) + r, err := fd.LocateCompressed(0) + require.NoError(t, err) + require.Equal(t, int64(0), r.Offset) + require.Equal(t, 1024, r.Length) + + r, err = fd.LocateCompressed(2047) + require.NoError(t, err) + require.Equal(t, int64(0), r.Offset) + require.Equal(t, 1024, r.Length) + + // Frame 1: U=[2048,4096), C=[1024,1924) + r, err = fd.LocateCompressed(2048) + require.NoError(t, err) + require.Equal(t, int64(1024), r.Offset) + require.Equal(t, 900, r.Length) + + // Frame 2: U=[4096,8192), C=[1924,5424) + r, err = fd.LocateCompressed(4096) + require.NoError(t, err) + require.Equal(t, int64(1924), r.Offset) + require.Equal(t, 3500, r.Length) + + // Beyond end + _, err = fd.LocateCompressed(8192) + require.Error(t, err) +} + +func TestFrameTable_LocateUncompressed(t *testing.T) { + t.Parallel() + + fd := storage.NewFrameTable(storage.CompressionZstd, []storage.FrameSize{ + {U: 2048, C: 1024}, + {U: 4096, C: 3500}, + }) + + // Frame 0: U=[0,2048) + r, err := fd.LocateUncompressed(0) + require.NoError(t, err) + require.Equal(t, int64(0), r.Offset) + require.Equal(t, 2048, r.Length) + + // Frame 1: U=[2048,6144) + r, err = fd.LocateUncompressed(2048) + require.NoError(t, err) + require.Equal(t, int64(2048), r.Offset) + require.Equal(t, 4096, r.Length) + + // Beyond end + _, err = fd.LocateUncompressed(6144) + require.Error(t, err) +} + +func TestSerializeDeserialize_V4_SparseTrimming(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + otherID := uuid.New() + + ft := storage.NewFrameTable(storage.CompressionLZ4, []storage.FrameSize{ + {U: 4096, C: 2000}, + {U: 4096, C: 3000}, + {U: 4096, C: 2500}, + {U: 4096, C: 
1800}, + }) + + metadata := &Metadata{ + Version: 4, + BlockSize: 4096, + Size: 4096 * 4, + Generation: 0, + BuildId: buildID, + BaseBuildId: buildID, + } + + // Mapping only references frames 0 and 3 (gap at 1,2 due to otherID). + mappings := []BuildMap{ + {Offset: 0, Length: 4096, BuildId: buildID, BuildStorageOffset: 0}, + {Offset: 4096, Length: 8192, BuildId: otherID, BuildStorageOffset: 0}, + {Offset: 12288, Length: 4096, BuildId: buildID, BuildStorageOffset: 12288}, + } + + h, err := NewHeader(metadata, mappings) + require.NoError(t, err) + h.Builds = map[uuid.UUID]BuildData{ + buildID: {FrameData: ft, Size: 16384}, + otherID: {Size: 8192}, + } + + data, err := SerializeHeader(h) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + gotFT := got.Builds[buildID].FrameData + require.NotNil(t, gotFT) + require.Equal(t, 2, gotFT.NumFrames()) + + // Frame 0 + r, err := gotFT.LocateCompressed(0) + require.NoError(t, err) + require.Equal(t, int64(0), r.Offset) + require.Equal(t, 2000, r.Length) + + // Frame 3 + r, err = gotFT.LocateCompressed(12288) + require.NoError(t, err) + require.Equal(t, int64(2000+3000+2500), r.Offset) + require.Equal(t, 1800, r.Length) + + // Gap + _, err = gotFT.LocateCompressed(4096) + require.Error(t, err) +} diff --git a/packages/shared/pkg/storage/header/serialization_v3.go b/packages/shared/pkg/storage/header/serialization_v3.go new file mode 100644 index 0000000000..5ee8d45f0a --- /dev/null +++ b/packages/shared/pkg/storage/header/serialization_v3.go @@ -0,0 +1,65 @@ +package header + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" +) + +type v3SerializableBuildMap struct { + Offset uint64 + Length uint64 + BuildId [16]byte // uuid.UUID + BuildStorageOffset uint64 +} + +// serializeV3 writes [Metadata] [v3 mappings…] with no length prefix. 
+func serializeV3(metadata *Metadata, mappings []BuildMap) ([]byte, error) { + var buf bytes.Buffer + + if err := binary.Write(&buf, binary.LittleEndian, metadata); err != nil { + return nil, fmt.Errorf("failed to write metadata: %w", err) + } + + for _, mapping := range mappings { + v3 := &v3SerializableBuildMap{ + Offset: mapping.Offset, + Length: mapping.Length, + BuildId: mapping.BuildId, + BuildStorageOffset: mapping.BuildStorageOffset, + } + if err := binary.Write(&buf, binary.LittleEndian, v3); err != nil { + return nil, fmt.Errorf("failed to write block mapping: %w", err) + } + } + + return buf.Bytes(), nil +} + +// deserializeV3 reads V3 mappings (read until EOF, no count prefix). +func deserializeV3(metadata *Metadata, blockData []byte) (*Header, error) { + reader := bytes.NewReader(blockData) + var mappings []BuildMap + + for { + var v3 v3SerializableBuildMap + err := binary.Read(reader, binary.LittleEndian, &v3) + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, fmt.Errorf("failed to read block mapping: %w", err) + } + + mappings = append(mappings, BuildMap{ + Offset: v3.Offset, + Length: v3.Length, + BuildId: v3.BuildId, + BuildStorageOffset: v3.BuildStorageOffset, + }) + } + + return NewHeader(metadata, mappings) +} diff --git a/packages/shared/pkg/storage/header/serialization_v4.go b/packages/shared/pkg/storage/header/serialization_v4.go new file mode 100644 index 0000000000..a88a772ede --- /dev/null +++ b/packages/shared/pkg/storage/header/serialization_v4.go @@ -0,0 +1,254 @@ +package header + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "slices" + + "github.com/google/uuid" + lz4 "github.com/pierrec/lz4/v4" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage" +) + +// v4SizePrefixLen is the length of the uint32 size prefix that precedes the +// LZ4-compressed block in the V4 header layout. +const v4SizePrefixLen = 4 + +// v4FlagsLen is the length of the V4 flags byte. Bit 0 = IncompletePendingUpload. 
+const v4FlagsLen = 1 + +// v4FlagIncomplete is bit 0 of the V4 flags byte: when set, the header +// describes a build whose upload has not yet finalized (an in-flight diff). +// StoreHeader refuses to persist headers carrying this flag; only the P2P +// peer-server path emits it. +const v4FlagIncomplete uint8 = 1 << 0 + +type v4SerializableBuildMap struct { + Offset uint64 + Length uint64 + BuildId [16]byte // uuid.UUID + BuildStorageOffset uint64 +} + +// v4SerializableBuildInfo is the on-disk format for a build's fixed fields, +// followed by a serialized FrameTable. +type v4SerializableBuildInfo struct { + BuildId uuid.UUID + FileSize int64 + Checksum [32]byte +} + +// serializeV4 writes [Metadata] [uint8 flags] [uint32 LZ4 size] [LZ4( Builds[] + Mappings[] )]. +// Frame tables are sparse-trimmed to only frames referenced by mappings. +func serializeV4(metadata *Metadata, builds map[uuid.UUID]BuildData, mappings []BuildMap, incomplete bool) ([]byte, error) { + var metaBuf bytes.Buffer + if err := binary.Write(&metaBuf, binary.LittleEndian, metadata); err != nil { + return nil, fmt.Errorf("failed to write metadata: %w", err) + } + + var block bytes.Buffer + + // Sort by UUID for deterministic serialization. 
+ buildIDs := make([]uuid.UUID, 0, len(builds)) + for id := range builds { + buildIDs = append(buildIDs, id) + } + slices.SortFunc(buildIDs, func(a, b uuid.UUID) int { + return bytes.Compare(a[:], b[:]) + }) + + if err := binary.Write(&block, binary.LittleEndian, uint32(len(buildIDs))); err != nil { + return nil, fmt.Errorf("failed to write build count: %w", err) + } + + buildRanges := extractRelevantRanges(mappings) + for _, id := range buildIDs { + bd := builds[id] + + entry := v4SerializableBuildInfo{ + BuildId: id, + FileSize: bd.Size, + Checksum: bd.Checksum, + } + + if err := binary.Write(&block, binary.LittleEndian, &entry); err != nil { + return nil, fmt.Errorf("failed to write build info: %w", err) + } + + trimmed := bd.FrameData.TrimToRanges(buildRanges[id]) + if err := trimmed.Serialize(&block); err != nil { + return nil, fmt.Errorf("failed to write build frame data: %w", err) + } + } + + if err := binary.Write(&block, binary.LittleEndian, uint32(len(mappings))); err != nil { + return nil, fmt.Errorf("failed to write mappings count: %w", err) + } + + for _, mapping := range mappings { + v4 := &v4SerializableBuildMap{ + Offset: mapping.Offset, + Length: mapping.Length, + BuildId: mapping.BuildId, + BuildStorageOffset: mapping.BuildStorageOffset, + } + + if err := binary.Write(&block, binary.LittleEndian, v4); err != nil { + return nil, fmt.Errorf("failed to write block mapping: %w", err) + } + } + + // LZ4-compress the block and assemble: [metadata] [uint8 flags] [uint32 size] [compressed block]. 
+ blockBytes := block.Bytes() + compressed, err := compressLZ4(blockBytes) + if err != nil { + return nil, fmt.Errorf("failed to LZ4-compress v4 header block: %w", err) + } + + var flags uint8 + if incomplete { + flags |= v4FlagIncomplete + } + + result := make([]byte, metadataSize+v4FlagsLen+v4SizePrefixLen+len(compressed)) + copy(result, metaBuf.Bytes()) + result[metadataSize] = flags + binary.LittleEndian.PutUint32(result[metadataSize+v4FlagsLen:], uint32(len(blockBytes))) + copy(result[metadataSize+v4FlagsLen+v4SizePrefixLen:], compressed) + + return result, nil +} + +// deserializeV4 decompresses and reads the V4 block. +func deserializeV4(metadata *Metadata, blockData []byte) (*Header, error) { + if len(blockData) < v4FlagsLen+v4SizePrefixLen { + return nil, fmt.Errorf("v4 header block too short for flags + size prefix: %d bytes", len(blockData)) + } + + flags := blockData[0] + + decompressed, err := decompressLZ4(blockData[v4FlagsLen+v4SizePrefixLen:]) + if err != nil { + return nil, fmt.Errorf("failed to LZ4-decompress v4 header block: %w", err) + } + + reader := bytes.NewReader(decompressed) + + var numBuilds uint32 + if err := binary.Read(reader, binary.LittleEndian, &numBuilds); err != nil { + return nil, fmt.Errorf("failed to read build count: %w", err) + } + + var builds map[uuid.UUID]BuildData + + if numBuilds > 0 { + builds = make(map[uuid.UUID]BuildData, numBuilds) + + for range numBuilds { + var entry v4SerializableBuildInfo + if err := binary.Read(reader, binary.LittleEndian, &entry); err != nil { + return nil, fmt.Errorf("failed to read build info: %w", err) + } + + bd := BuildData{ + Size: entry.FileSize, + Checksum: entry.Checksum, + } + + ft, err := storage.DeserializeFrameTable(reader) + if err != nil { + return nil, fmt.Errorf("failed to read frame table for build %s: %w", entry.BuildId, err) + } + + bd.FrameData = ft + builds[entry.BuildId] = bd + } + } + + var numMappings uint32 + if err := binary.Read(reader, binary.LittleEndian, 
&numMappings); err != nil { + return nil, fmt.Errorf("failed to read mappings count: %w", err) + } + + mappings := make([]BuildMap, 0, numMappings) + for range numMappings { + var v4 v4SerializableBuildMap + if err := binary.Read(reader, binary.LittleEndian, &v4); err != nil { + return nil, fmt.Errorf("failed to read block mapping: %w", err) + } + + m := BuildMap{ + Offset: v4.Offset, + Length: v4.Length, + BuildId: v4.BuildId, + BuildStorageOffset: v4.BuildStorageOffset, + } + + mappings = append(mappings, m) + } + + h, err := NewHeader(metadata, mappings) + if err != nil { + return nil, err + } + h.Builds = builds + h.IncompletePendingUpload = flags&v4FlagIncomplete != 0 + + return h, nil +} + +// compressLZ4 compresses data for V4 header serialization using the LZ4 +// streaming API. Settings are fixed for the V4 wire format. +func compressLZ4(data []byte) ([]byte, error) { + var buf bytes.Buffer + buf.Grow(len(data)) + + w := lz4.NewWriter(&buf) + if err := w.Apply( + lz4.BlockSizeOption(lz4.Block4Mb), + lz4.BlockChecksumOption(true), + lz4.ChecksumOption(true), + lz4.CompressionLevelOption(lz4.Fast), + ); err != nil { + return nil, fmt.Errorf("lz4 options: %w", err) + } + + if _, err := w.Write(data); err != nil { + return nil, fmt.Errorf("lz4 compress: %w", err) + } + + if err := w.Close(); err != nil { + return nil, fmt.Errorf("lz4 compress close: %w", err) + } + + return buf.Bytes(), nil +} + +// extractRelevantRanges groups mappings into per-build U-space [start, end) ranges +// for sparse frame table trimming during serialization. +func extractRelevantRanges(mappings []BuildMap) map[uuid.UUID][]storage.Range { + ranges := make(map[uuid.UUID][]storage.Range) + for _, m := range mappings { + ranges[m.BuildId] = append(ranges[m.BuildId], storage.Range{ + Offset: int64(m.BuildStorageOffset), + Length: int(m.Length), + }) + } + + return ranges +} + +// decompressLZ4 decompresses an LZ4 frame from V4 header data. 
+func decompressLZ4(src []byte) ([]byte, error) { + r := lz4.NewReader(bytes.NewReader(src)) + + data, err := io.ReadAll(r) + if err != nil { + return nil, fmt.Errorf("lz4 decompress: %w", err) + } + + return data, nil +} diff --git a/packages/shared/pkg/storage/mocks/mockobjectprovider.go b/packages/shared/pkg/storage/mock_blob.go similarity index 99% rename from packages/shared/pkg/storage/mocks/mockobjectprovider.go rename to packages/shared/pkg/storage/mock_blob.go index 1f515f815f..e391c4e458 100644 --- a/packages/shared/pkg/storage/mocks/mockobjectprovider.go +++ b/packages/shared/pkg/storage/mock_blob.go @@ -2,7 +2,7 @@ // github.com/vektra/mockery // template: testify -package storagemocks +package storage import ( "context" diff --git a/packages/shared/pkg/storage/mocks/mockfeatureflagsclient.go b/packages/shared/pkg/storage/mock_featureflagsclient.go similarity index 99% rename from packages/shared/pkg/storage/mocks/mockfeatureflagsclient.go rename to packages/shared/pkg/storage/mock_featureflagsclient.go index d83936eddd..53dd4c5b29 100644 --- a/packages/shared/pkg/storage/mocks/mockfeatureflagsclient.go +++ b/packages/shared/pkg/storage/mock_featureflagsclient.go @@ -2,7 +2,7 @@ // github.com/vektra/mockery // template: testify -package storagemocks +package storage import ( "context" diff --git a/packages/shared/pkg/storage/mocks/mockioreader.go b/packages/shared/pkg/storage/mock_ioreader.go similarity index 99% rename from packages/shared/pkg/storage/mocks/mockioreader.go rename to packages/shared/pkg/storage/mock_ioreader.go index 5497bc53c5..9adb02421e 100644 --- a/packages/shared/pkg/storage/mocks/mockioreader.go +++ b/packages/shared/pkg/storage/mock_ioreader.go @@ -2,7 +2,7 @@ // github.com/vektra/mockery // template: testify -package storagemocks +package storage import ( mock "github.com/stretchr/testify/mock" diff --git a/packages/shared/pkg/storage/mocks/mockseekableobjectprovider.go b/packages/shared/pkg/storage/mock_seekable.go similarity 
index 63% rename from packages/shared/pkg/storage/mocks/mockseekableobjectprovider.go rename to packages/shared/pkg/storage/mock_seekable.go index 123d9cbd80..440836bc54 100644 --- a/packages/shared/pkg/storage/mocks/mockseekableobjectprovider.go +++ b/packages/shared/pkg/storage/mock_seekable.go @@ -2,13 +2,12 @@ // github.com/vektra/mockery // template: testify -package storagemocks +package storage import ( "context" "io" - "github.com/e2b-dev/infra/packages/shared/pkg/storage/storageopts" mock "github.com/stretchr/testify/mock" ) @@ -40,8 +39,8 @@ func (_m *MockSeekable) EXPECT() *MockSeekable_Expecter { } // OpenRangeReader provides a mock function for the type MockSeekable -func (_mock *MockSeekable) OpenRangeReader(ctx context.Context, off int64, length int64) (io.ReadCloser, error) { - ret := _mock.Called(ctx, off, length) +func (_mock *MockSeekable) OpenRangeReader(ctx context.Context, offsetU int64, length int64, frameTable *FrameTable) (io.ReadCloser, error) { + ret := _mock.Called(ctx, offsetU, length, frameTable) if len(ret) == 0 { panic("no return value specified for OpenRangeReader") @@ -49,18 +48,18 @@ func (_mock *MockSeekable) OpenRangeReader(ctx context.Context, off int64, lengt var r0 io.ReadCloser var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, int64, int64) (io.ReadCloser, error)); ok { - return returnFunc(ctx, off, length) + if returnFunc, ok := ret.Get(0).(func(context.Context, int64, int64, *FrameTable) (io.ReadCloser, error)); ok { + return returnFunc(ctx, offsetU, length, frameTable) } - if returnFunc, ok := ret.Get(0).(func(context.Context, int64, int64) io.ReadCloser); ok { - r0 = returnFunc(ctx, off, length) + if returnFunc, ok := ret.Get(0).(func(context.Context, int64, int64, *FrameTable) io.ReadCloser); ok { + r0 = returnFunc(ctx, offsetU, length, frameTable) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(io.ReadCloser) } } - if returnFunc, ok := ret.Get(1).(func(context.Context, int64, int64) error); ok { - 
r1 = returnFunc(ctx, off, length) + if returnFunc, ok := ret.Get(1).(func(context.Context, int64, int64, *FrameTable) error); ok { + r1 = returnFunc(ctx, offsetU, length, frameTable) } else { r1 = ret.Error(1) } @@ -74,13 +73,14 @@ type MockSeekable_OpenRangeReader_Call struct { // OpenRangeReader is a helper method to define mock.On call // - ctx context.Context -// - off int64 +// - offsetU int64 // - length int64 -func (_e *MockSeekable_Expecter) OpenRangeReader(ctx interface{}, off interface{}, length interface{}) *MockSeekable_OpenRangeReader_Call { - return &MockSeekable_OpenRangeReader_Call{Call: _e.mock.On("OpenRangeReader", ctx, off, length)} +// - frameTable *FrameTable +func (_e *MockSeekable_Expecter) OpenRangeReader(ctx interface{}, offsetU interface{}, length interface{}, frameTable interface{}) *MockSeekable_OpenRangeReader_Call { + return &MockSeekable_OpenRangeReader_Call{Call: _e.mock.On("OpenRangeReader", ctx, offsetU, length, frameTable)} } -func (_c *MockSeekable_OpenRangeReader_Call) Run(run func(ctx context.Context, off int64, length int64)) *MockSeekable_OpenRangeReader_Call { +func (_c *MockSeekable_OpenRangeReader_Call) Run(run func(ctx context.Context, offsetU int64, length int64, frameTable *FrameTable)) *MockSeekable_OpenRangeReader_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -94,10 +94,15 @@ func (_c *MockSeekable_OpenRangeReader_Call) Run(run func(ctx context.Context, o if args[2] != nil { arg2 = args[2].(int64) } + var arg3 *FrameTable + if args[3] != nil { + arg3 = args[3].(*FrameTable) + } run( arg0, arg1, arg2, + arg3, ) }) return _c @@ -108,79 +113,7 @@ func (_c *MockSeekable_OpenRangeReader_Call) Return(readCloser io.ReadCloser, er return _c } -func (_c *MockSeekable_OpenRangeReader_Call) RunAndReturn(run func(ctx context.Context, off int64, length int64) (io.ReadCloser, error)) *MockSeekable_OpenRangeReader_Call { - _c.Call.Return(run) - return _c -} - -// ReadAt provides a 
mock function for the type MockSeekable -func (_mock *MockSeekable) ReadAt(ctx context.Context, buffer []byte, off int64) (int, error) { - ret := _mock.Called(ctx, buffer, off) - - if len(ret) == 0 { - panic("no return value specified for ReadAt") - } - - var r0 int - var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, []byte, int64) (int, error)); ok { - return returnFunc(ctx, buffer, off) - } - if returnFunc, ok := ret.Get(0).(func(context.Context, []byte, int64) int); ok { - r0 = returnFunc(ctx, buffer, off) - } else { - r0 = ret.Get(0).(int) - } - if returnFunc, ok := ret.Get(1).(func(context.Context, []byte, int64) error); ok { - r1 = returnFunc(ctx, buffer, off) - } else { - r1 = ret.Error(1) - } - return r0, r1 -} - -// MockSeekable_ReadAt_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReadAt' -type MockSeekable_ReadAt_Call struct { - *mock.Call -} - -// ReadAt is a helper method to define mock.On call -// - ctx context.Context -// - buffer []byte -// - off int64 -func (_e *MockSeekable_Expecter) ReadAt(ctx interface{}, buffer interface{}, off interface{}) *MockSeekable_ReadAt_Call { - return &MockSeekable_ReadAt_Call{Call: _e.mock.On("ReadAt", ctx, buffer, off)} -} - -func (_c *MockSeekable_ReadAt_Call) Run(run func(ctx context.Context, buffer []byte, off int64)) *MockSeekable_ReadAt_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 []byte - if args[1] != nil { - arg1 = args[1].([]byte) - } - var arg2 int64 - if args[2] != nil { - arg2 = args[2].(int64) - } - run( - arg0, - arg1, - arg2, - ) - }) - return _c -} - -func (_c *MockSeekable_ReadAt_Call) Return(n int, err error) *MockSeekable_ReadAt_Call { - _c.Call.Return(n, err) - return _c -} - -func (_c *MockSeekable_ReadAt_Call) RunAndReturn(run func(ctx context.Context, buffer []byte, off int64) (int, error)) *MockSeekable_ReadAt_Call { +func (_c 
*MockSeekable_OpenRangeReader_Call) RunAndReturn(run func(ctx context.Context, offsetU int64, length int64, frameTable *FrameTable) (io.ReadCloser, error)) *MockSeekable_OpenRangeReader_Call { _c.Call.Return(run) return _c } @@ -246,7 +179,7 @@ func (_c *MockSeekable_Size_Call) RunAndReturn(run func(ctx context.Context) (in } // StoreFile provides a mock function for the type MockSeekable -func (_mock *MockSeekable) StoreFile(ctx context.Context, path string, opts ...storageopts.PutOption) error { +func (_mock *MockSeekable) StoreFile(ctx context.Context, path string, opts ...PutOption) (*FrameTable, [32]byte, error) { var tmpRet mock.Arguments if len(opts) > 0 { tmpRet = _mock.Called(ctx, path, opts) @@ -259,13 +192,32 @@ func (_mock *MockSeekable) StoreFile(ctx context.Context, path string, opts ...s panic("no return value specified for StoreFile") } - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, ...storageopts.PutOption) error); ok { + var r0 *FrameTable + var r1 [32]byte + var r2 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, ...PutOption) (*FrameTable, [32]byte, error)); ok { + return returnFunc(ctx, path, opts...) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, ...PutOption) *FrameTable); ok { r0 = returnFunc(ctx, path, opts...) } else { - r0 = ret.Error(0) + if ret.Get(0) != nil { + r0 = ret.Get(0).(*FrameTable) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, ...PutOption) [32]byte); ok { + r1 = returnFunc(ctx, path, opts...) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).([32]byte) + } + } + if returnFunc, ok := ret.Get(2).(func(context.Context, string, ...PutOption) error); ok { + r2 = returnFunc(ctx, path, opts...) 
+ } else { + r2 = ret.Error(2) } - return r0 + return r0, r1, r2 } // MockSeekable_StoreFile_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StoreFile' @@ -276,13 +228,13 @@ type MockSeekable_StoreFile_Call struct { // StoreFile is a helper method to define mock.On call // - ctx context.Context // - path string -// - opts ...storageopts.PutOption +// - opts ...PutOption func (_e *MockSeekable_Expecter) StoreFile(ctx interface{}, path interface{}, opts ...interface{}) *MockSeekable_StoreFile_Call { return &MockSeekable_StoreFile_Call{Call: _e.mock.On("StoreFile", append([]interface{}{ctx, path}, opts...)...)} } -func (_c *MockSeekable_StoreFile_Call) Run(run func(ctx context.Context, path string, opts ...storageopts.PutOption)) *MockSeekable_StoreFile_Call { +func (_c *MockSeekable_StoreFile_Call) Run(run func(ctx context.Context, path string, opts ...PutOption)) *MockSeekable_StoreFile_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -292,10 +244,10 @@ func (_c *MockSeekable_StoreFile_Call) Run(run func(ctx context.Context, path st if args[1] != nil { arg1 = args[1].(string) } - var arg2 []storageopts.PutOption - var variadicArgs []storageopts.PutOption + var arg2 []PutOption + var variadicArgs []PutOption if len(args) > 2 { - variadicArgs = args[2].([]storageopts.PutOption) + variadicArgs = args[2].([]PutOption) } arg2 = variadicArgs run( @@ -307,12 +259,12 @@ func (_c *MockSeekable_StoreFile_Call) Run(run func(ctx context.Context, path st return _c } -func (_c *MockSeekable_StoreFile_Call) Return(err error) *MockSeekable_StoreFile_Call { - _c.Call.Return(err) +func (_c *MockSeekable_StoreFile_Call) Return(frameTable *FrameTable, bytes [32]byte, err error) *MockSeekable_StoreFile_Call { + _c.Call.Return(frameTable, bytes, err) return _c } -func (_c *MockSeekable_StoreFile_Call) RunAndReturn(run func(ctx context.Context, path string, opts ...storageopts.PutOption) error) 
*MockSeekable_StoreFile_Call { +func (_c *MockSeekable_StoreFile_Call) RunAndReturn(run func(ctx context.Context, path string, opts ...PutOption) (*FrameTable, [32]byte, error)) *MockSeekable_StoreFile_Call { _c.Call.Return(run) return _c } diff --git a/packages/shared/pkg/storage/mocks/provider/mockstorageprovider.go b/packages/shared/pkg/storage/mock_storageprovider.go similarity index 86% rename from packages/shared/pkg/storage/mocks/provider/mockstorageprovider.go rename to packages/shared/pkg/storage/mock_storageprovider.go index b505eb617f..4657bf0754 100644 --- a/packages/shared/pkg/storage/mocks/provider/mockstorageprovider.go +++ b/packages/shared/pkg/storage/mock_storageprovider.go @@ -2,13 +2,12 @@ // github.com/vektra/mockery // template: testify -package providermocks +package storage import ( "context" "time" - "github.com/e2b-dev/infra/packages/shared/pkg/storage" mock "github.com/stretchr/testify/mock" ) @@ -141,26 +140,26 @@ func (_c *MockStorageProvider_GetDetails_Call) RunAndReturn(run func() string) * } // OpenBlob provides a mock function for the type MockStorageProvider -func (_mock *MockStorageProvider) OpenBlob(ctx context.Context, path string, objectType storage.ObjectType) (storage.Blob, error) { +func (_mock *MockStorageProvider) OpenBlob(ctx context.Context, path string, objectType ObjectType) (Blob, error) { ret := _mock.Called(ctx, path, objectType) if len(ret) == 0 { panic("no return value specified for OpenBlob") } - var r0 storage.Blob + var r0 Blob var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, storage.ObjectType) (storage.Blob, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, ObjectType) (Blob, error)); ok { return returnFunc(ctx, path, objectType) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string, storage.ObjectType) storage.Blob); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, ObjectType) Blob); ok { r0 = returnFunc(ctx, path, 
objectType) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(storage.Blob) + r0 = ret.Get(0).(Blob) } } - if returnFunc, ok := ret.Get(1).(func(context.Context, string, storage.ObjectType) error); ok { + if returnFunc, ok := ret.Get(1).(func(context.Context, string, ObjectType) error); ok { r1 = returnFunc(ctx, path, objectType) } else { r1 = ret.Error(1) @@ -176,12 +175,12 @@ type MockStorageProvider_OpenBlob_Call struct { // OpenBlob is a helper method to define mock.On call // - ctx context.Context // - path string -// - objectType storage.ObjectType +// - objectType ObjectType func (_e *MockStorageProvider_Expecter) OpenBlob(ctx interface{}, path interface{}, objectType interface{}) *MockStorageProvider_OpenBlob_Call { return &MockStorageProvider_OpenBlob_Call{Call: _e.mock.On("OpenBlob", ctx, path, objectType)} } -func (_c *MockStorageProvider_OpenBlob_Call) Run(run func(ctx context.Context, path string, objectType storage.ObjectType)) *MockStorageProvider_OpenBlob_Call { +func (_c *MockStorageProvider_OpenBlob_Call) Run(run func(ctx context.Context, path string, objectType ObjectType)) *MockStorageProvider_OpenBlob_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -191,9 +190,9 @@ func (_c *MockStorageProvider_OpenBlob_Call) Run(run func(ctx context.Context, p if args[1] != nil { arg1 = args[1].(string) } - var arg2 storage.ObjectType + var arg2 ObjectType if args[2] != nil { - arg2 = args[2].(storage.ObjectType) + arg2 = args[2].(ObjectType) } run( arg0, @@ -204,37 +203,37 @@ func (_c *MockStorageProvider_OpenBlob_Call) Run(run func(ctx context.Context, p return _c } -func (_c *MockStorageProvider_OpenBlob_Call) Return(blob storage.Blob, err error) *MockStorageProvider_OpenBlob_Call { +func (_c *MockStorageProvider_OpenBlob_Call) Return(blob Blob, err error) *MockStorageProvider_OpenBlob_Call { _c.Call.Return(blob, err) return _c } -func (_c *MockStorageProvider_OpenBlob_Call) RunAndReturn(run func(ctx 
context.Context, path string, objectType storage.ObjectType) (storage.Blob, error)) *MockStorageProvider_OpenBlob_Call { +func (_c *MockStorageProvider_OpenBlob_Call) RunAndReturn(run func(ctx context.Context, path string, objectType ObjectType) (Blob, error)) *MockStorageProvider_OpenBlob_Call { _c.Call.Return(run) return _c } // OpenSeekable provides a mock function for the type MockStorageProvider -func (_mock *MockStorageProvider) OpenSeekable(ctx context.Context, path string, seekableObjectType storage.SeekableObjectType) (storage.Seekable, error) { +func (_mock *MockStorageProvider) OpenSeekable(ctx context.Context, path string, seekableObjectType SeekableObjectType) (Seekable, error) { ret := _mock.Called(ctx, path, seekableObjectType) if len(ret) == 0 { panic("no return value specified for OpenSeekable") } - var r0 storage.Seekable + var r0 Seekable var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, storage.SeekableObjectType) (storage.Seekable, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, SeekableObjectType) (Seekable, error)); ok { return returnFunc(ctx, path, seekableObjectType) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string, storage.SeekableObjectType) storage.Seekable); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, SeekableObjectType) Seekable); ok { r0 = returnFunc(ctx, path, seekableObjectType) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(storage.Seekable) + r0 = ret.Get(0).(Seekable) } } - if returnFunc, ok := ret.Get(1).(func(context.Context, string, storage.SeekableObjectType) error); ok { + if returnFunc, ok := ret.Get(1).(func(context.Context, string, SeekableObjectType) error); ok { r1 = returnFunc(ctx, path, seekableObjectType) } else { r1 = ret.Error(1) @@ -250,12 +249,12 @@ type MockStorageProvider_OpenSeekable_Call struct { // OpenSeekable is a helper method to define mock.On call // - ctx context.Context // - path string -// - 
seekableObjectType storage.SeekableObjectType +// - seekableObjectType SeekableObjectType func (_e *MockStorageProvider_Expecter) OpenSeekable(ctx interface{}, path interface{}, seekableObjectType interface{}) *MockStorageProvider_OpenSeekable_Call { return &MockStorageProvider_OpenSeekable_Call{Call: _e.mock.On("OpenSeekable", ctx, path, seekableObjectType)} } -func (_c *MockStorageProvider_OpenSeekable_Call) Run(run func(ctx context.Context, path string, seekableObjectType storage.SeekableObjectType)) *MockStorageProvider_OpenSeekable_Call { +func (_c *MockStorageProvider_OpenSeekable_Call) Run(run func(ctx context.Context, path string, seekableObjectType SeekableObjectType)) *MockStorageProvider_OpenSeekable_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -265,9 +264,9 @@ func (_c *MockStorageProvider_OpenSeekable_Call) Run(run func(ctx context.Contex if args[1] != nil { arg1 = args[1].(string) } - var arg2 storage.SeekableObjectType + var arg2 SeekableObjectType if args[2] != nil { - arg2 = args[2].(storage.SeekableObjectType) + arg2 = args[2].(SeekableObjectType) } run( arg0, @@ -278,12 +277,12 @@ func (_c *MockStorageProvider_OpenSeekable_Call) Run(run func(ctx context.Contex return _c } -func (_c *MockStorageProvider_OpenSeekable_Call) Return(seekable storage.Seekable, err error) *MockStorageProvider_OpenSeekable_Call { +func (_c *MockStorageProvider_OpenSeekable_Call) Return(seekable Seekable, err error) *MockStorageProvider_OpenSeekable_Call { _c.Call.Return(seekable, err) return _c } -func (_c *MockStorageProvider_OpenSeekable_Call) RunAndReturn(run func(ctx context.Context, path string, seekableObjectType storage.SeekableObjectType) (storage.Seekable, error)) *MockStorageProvider_OpenSeekable_Call { +func (_c *MockStorageProvider_OpenSeekable_Call) RunAndReturn(run func(ctx context.Context, path string, seekableObjectType SeekableObjectType) (Seekable, error)) *MockStorageProvider_OpenSeekable_Call { 
_c.Call.Return(run) return _c } diff --git a/packages/shared/pkg/storage/paths.go b/packages/shared/pkg/storage/paths.go index 0164d73b69..6a9a9df74d 100644 --- a/packages/shared/pkg/storage/paths.go +++ b/packages/shared/pkg/storage/paths.go @@ -34,7 +34,7 @@ func (p Paths) Memfile() string { } func (p Paths) MemfileHeader() string { - return fmt.Sprintf("%s/%s%s", p.BuildID, MemfileName, HeaderSuffix) + return p.HeaderFile(MemfileName) } func (p Paths) Rootfs() string { @@ -42,7 +42,7 @@ func (p Paths) Rootfs() string { } func (p Paths) RootfsHeader() string { - return fmt.Sprintf("%s/%s%s", p.BuildID, RootfsName, HeaderSuffix) + return p.HeaderFile(RootfsName) } func (p Paths) Snapfile() string { @@ -53,6 +53,30 @@ func (p Paths) Metadata() string { return fmt.Sprintf("%s/%s", p.BuildID, MetadataName) } +func (p Paths) MemfileCompressed(ct CompressionType) string { + return fmt.Sprintf("%s/%s%s", p.BuildID, MemfileName, ct.Suffix()) +} + +func (p Paths) RootfsCompressed(ct CompressionType) string { + return fmt.Sprintf("%s/%s%s", p.BuildID, RootfsName, ct.Suffix()) +} + +// DataFile returns the storage path for a data file (e.g. "memfile", "rootfs.ext4"), +// with compression suffix appended if ct is not CompressionNone. +func (p Paths) DataFile(name string, ct CompressionType) string { + if ct == CompressionNone { + return fmt.Sprintf("%s/%s", p.BuildID, name) + } + + return fmt.Sprintf("%s/%s%s", p.BuildID, name, ct.Suffix()) +} + +// HeaderFile returns the storage path for a header sidecar of a data file +// (e.g. "memfile" → "{buildID}/memfile.header"). +func (p Paths) HeaderFile(name string) string { + return fmt.Sprintf("%s/%s%s", p.BuildID, name, HeaderSuffix) +} + // SplitPath splits a storage path of the form "{buildID}/{fileName}" // back into its components. This is the inverse of the path methods. 
func SplitPath(path string) (buildID, fileName string) { @@ -60,3 +84,26 @@ func SplitPath(path string) (buildID, fileName string) { return buildID, fileName } + +var knownCompressionSuffixes = []string{CompressionLZ4.Suffix(), CompressionZstd.Suffix()} + +// StripCompression removes a known compression suffix from a file name. +// For example: "memfile.zstd" → "memfile". +// If no known suffix is present, the name is returned unchanged. +func StripCompression(name string) string { + for _, suffix := range knownCompressionSuffixes { + if before, ok := strings.CutSuffix(name, suffix); ok { + return before + } + } + + return name +} + +// SizeSidecar returns the sidecar path that stores the original +// uncompressed size for a compressed object (e.g. "/data/memfile.zstd" → +// "/data/memfile.zstd.uncompressed-size"). Used by the FS backend where +// GCS-style object metadata is unavailable. +func SizeSidecar(objectPath string) string { + return objectPath + "." + MetadataKeyUncompressedSize +} diff --git a/packages/shared/pkg/storage/sandbox.go b/packages/shared/pkg/storage/sandbox.go index b0fa9d44da..97111e99a7 100644 --- a/packages/shared/pkg/storage/sandbox.go +++ b/packages/shared/pkg/storage/sandbox.go @@ -18,6 +18,8 @@ type SandboxFiles struct { } type Config struct { + CompressConfig + SandboxCacheDir string `env:"SANDBOX_CACHE_DIR,expand" envDefault:"${ORCHESTRATOR_BASE_PATH}/sandbox"` SnapshotCacheDir string `env:"SNAPSHOT_CACHE_DIR,expand" envDefault:"/mnt/snapshot-cache"` TemplateCacheDir string `env:"TEMPLATE_CACHE_DIR,expand" envDefault:"${ORCHESTRATOR_BASE_PATH}/template"` diff --git a/packages/shared/pkg/storage/storage.go b/packages/shared/pkg/storage/storage.go index f2c6a7ea68..27a77e78ac 100644 --- a/packages/shared/pkg/storage/storage.go +++ b/packages/shared/pkg/storage/storage.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "os" "time" "go.opentelemetry.io/otel" @@ -42,6 +43,10 @@ const ( // MemoryChunkSize must always be bigger or equal to 
the block size. MemoryChunkSize = 4 * 1024 * 1024 // 4 MB + + // MetadataKeyUncompressedSize stores the original size so that Size() + // returns the uncompressed size for compressed objects. + MetadataKeyUncompressedSize = "uncompressed-size" ) // GetProviderType returns the configured storage provider type from the @@ -94,8 +99,23 @@ const ObjectMetadataTeamID = storageopts.ObjectMetadataTeamID func WithMetadata(metadata ObjectMetadata) PutOption { return storageopts.WithMetadata(metadata) } +// WithCompressConfig threads a typed CompressConfig through PutOptions. It is +// stored as `any` in storageopts to avoid importing storage from there; +// backends use CompressConfigFromOpts to pull it back out. +func WithCompressConfig(cfg CompressConfig) PutOption { return storageopts.WithCompression(cfg) } + func ApplyPutOptions(opts []PutOption) PutOptions { return storageopts.Apply(opts) } +// CompressConfigFromOpts returns the typed CompressConfig set by +// WithCompressConfig, or the zero value if absent. +func CompressConfigFromOpts(p PutOptions) CompressConfig { + if c, ok := p.Compression.(CompressConfig); ok { + return c + } + + return CompressConfig{} +} + type Blob interface { WriteTo(ctx context.Context, dst io.Writer) (int64, error) Put(ctx context.Context, data []byte, opts ...storageopts.PutOption) error @@ -104,24 +124,56 @@ type Blob interface { type SeekableReader interface { // Random slice access, off and buffer length must be aligned to block size - ReadAt(ctx context.Context, buffer []byte, off int64) (int, error) + ReadAt(ctx context.Context, buffer []byte, off int64, ft *FrameTable) (int, error) Size(ctx context.Context) (int64, error) } // StreamingReader supports progressive reads via a streaming range reader. 
type StreamingReader interface { - OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) + OpenRangeReader(ctx context.Context, offsetU int64, length int64, frameTable *FrameTable) (io.ReadCloser, error) } type SeekableWriter interface { - // Store entire file - StoreFile(ctx context.Context, path string, opts ...storageopts.PutOption) error + // Store entire file. Compression is opt-in via WithCompressConfig. + StoreFile(ctx context.Context, path string, opts ...PutOption) (*FrameTable, [32]byte, error) } type Seekable interface { - SeekableReader - SeekableWriter StreamingReader + SeekableWriter + Size(ctx context.Context) (int64, error) +} + +func UploadFramed(ctx context.Context, provider StorageProvider, remotePath string, objType SeekableObjectType, localPath string, opts ...PutOption) (*FrameTable, [32]byte, error) { + object, err := provider.OpenSeekable(ctx, remotePath, objType) + if err != nil { + return nil, [32]byte{}, err + } + + return object.StoreFile(ctx, localPath, opts...) +} + +func UploadBlob(ctx context.Context, provider StorageProvider, remotePath string, objType ObjectType, localPath string, opts ...PutOption) error { + blob, err := provider.OpenBlob(ctx, remotePath, objType) + if err != nil { + return err + } + + data, err := os.ReadFile(localPath) + if err != nil { + return fmt.Errorf("failed to read file %s: %w", localPath, err) + } + + return blob.Put(ctx, data, opts...) +} + +// PeerTransitionedError is returned by the peer Seekable when the remote +// storage upload has completed; the caller should re-load the V4 header from +// storage. +type PeerTransitionedError struct{} + +func (e *PeerTransitionedError) Error() string { + return "peer upload completed, reload header from storage" } // StorageConfig holds the configuration for creating a storage provider. 
diff --git a/packages/shared/pkg/storage/storage_aws.go b/packages/shared/pkg/storage/storage_aws.go index 13c649bea7..a6e59c459a 100644 --- a/packages/shared/pkg/storage/storage_aws.go +++ b/packages/shared/pkg/storage/storage_aws.go @@ -162,13 +162,18 @@ func (o *awsObject) WriteTo(ctx context.Context, dst io.Writer) (int64, error) { return io.Copy(dst, resp.Body) } -func (o *awsObject) StoreFile(ctx context.Context, path string, opts ...PutOption) error { +func (o *awsObject) StoreFile(ctx context.Context, path string, opts ...PutOption) (*FrameTable, [32]byte, error) { + p := ApplyPutOptions(opts) + if CompressConfigFromOpts(p).IsCompressionEnabled() { + return nil, [32]byte{}, errors.New("compressed uploads are not supported on AWS (builds target GCP only)") + } + ctx, cancel := context.WithTimeout(ctx, awsWriteTimeout) defer cancel() f, err := os.Open(path) if err != nil { - return fmt.Errorf("failed to open file %s: %w", path, err) + return nil, [32]byte{}, fmt.Errorf("failed to open file %s: %w", path, err) } defer f.Close() @@ -186,11 +191,26 @@ func (o *awsObject) StoreFile(ctx context.Context, path string, opts ...PutOptio Bucket: &o.bucketName, Key: &o.path, Body: f, - Metadata: ApplyPutOptions(opts).Metadata, + Metadata: p.Metadata, }, ) + if err == nil { + fi, _ := f.Stat() + var size int64 + if fi != nil { + size = fi.Size() + } - return err + logger.L().Debug(ctx, "Uploaded file to S3", + zap.String("bucket", o.bucketName), + zap.String("object", o.path), + zap.String("source", path), + zap.Int64("size_uncompressed", size), + zap.String("compression", "none"), + ) + } + + return nil, [32]byte{}, err } func (o *awsObject) Put(ctx context.Context, data []byte, opts ...PutOption) error { @@ -213,7 +233,11 @@ func (o *awsObject) Put(ctx context.Context, data []byte, opts ...PutOption) err return nil } -func (o *awsObject) OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { +func (o *awsObject) OpenRangeReader(ctx 
context.Context, off, length int64, frameTable *FrameTable) (io.ReadCloser, error) { + if frameTable.IsCompressed() { + return nil, errors.New("compressed reads are not supported on AWS") + } + readRange := aws.String(fmt.Sprintf("bytes=%d-%d", off, off+length-1)) resp, err := o.client.GetObject(ctx, &s3.GetObjectInput{ Bucket: aws.String(o.bucketName), @@ -232,37 +256,6 @@ func (o *awsObject) OpenRangeReader(ctx context.Context, off, length int64) (io. return resp.Body, nil } -func (o *awsObject) ReadAt(ctx context.Context, buff []byte, off int64) (n int, err error) { - ctx, cancel := context.WithTimeout(ctx, awsReadTimeout) - defer cancel() - - readRange := aws.String(fmt.Sprintf("bytes=%d-%d", off, off+int64(len(buff))-1)) - resp, err := o.client.GetObject(ctx, &s3.GetObjectInput{ - Bucket: aws.String(o.bucketName), - Key: aws.String(o.path), - Range: readRange, - }) - if err != nil { - var nsk *types.NoSuchKey - if errors.As(err, &nsk) { - return 0, ErrObjectNotExist - } - - return 0, err - } - - defer resp.Body.Close() - - // When the object is smaller than requested range there will be unexpected EOF, - // but backend expects to return EOF in this case. 
- n, err = io.ReadFull(resp.Body, buff) - if errors.Is(err, io.ErrUnexpectedEOF) { - err = io.EOF - } - - return n, err -} - func (o *awsObject) Size(ctx context.Context) (int64, error) { ctx, cancel := context.WithTimeout(ctx, awsOperationTimeout) defer cancel() diff --git a/packages/shared/pkg/storage/storage_cache_blob_test.go b/packages/shared/pkg/storage/storage_cache_blob_test.go index 27c4afe88e..fbeecac4e0 100644 --- a/packages/shared/pkg/storage/storage_cache_blob_test.go +++ b/packages/shared/pkg/storage/storage_cache_blob_test.go @@ -13,8 +13,6 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/trace/noop" - - storagemocks "github.com/e2b-dev/infra/packages/shared/pkg/storage/mocks" ) var noopTracer = noop.TracerProvider{}.Tracer("github.com/e2b-dev/infra/packages/shared/pkg/storage") @@ -32,12 +30,12 @@ func TestCachedObjectProvider_Put(t *testing.T) { err := os.MkdirAll(cacheDir, os.ModePerm) require.NoError(t, err) - inner := storagemocks.NewMockBlob(t) + inner := NewMockBlob(t) inner.EXPECT(). Put(mock.Anything, mock.Anything). Return(nil) - featureFlags := storagemocks.NewMockFeatureFlagsClient(t) + featureFlags := NewMockFeatureFlagsClient(t) featureFlags.EXPECT().BoolFlag(mock.Anything, mock.Anything).Return(true) c := cachedBlob{path: cacheDir, inner: inner, chunkSize: 1024, flags: featureFlags, tracer: noopTracer} @@ -68,7 +66,7 @@ func TestCachedObjectProvider_Put(t *testing.T) { const dataSize = 10 * megabyte actualData := generateData(t, dataSize) - inner := storagemocks.NewMockBlob(t) + inner := NewMockBlob(t) inner.EXPECT(). WriteTo(mock.Anything, mock.Anything). 
RunAndReturn(func(_ context.Context, dst io.Writer) (int64, error) { @@ -101,7 +99,7 @@ func TestCachedObjectProvider_WriteFileToCache(t *testing.T) { tracer: noopTracer, } errTarget := errors.New("find me") - reader := storagemocks.NewMockReader(t) + reader := NewMockReader(t) reader.EXPECT().Read(mock.Anything).Return(4, nil).Once() reader.EXPECT().Read(mock.Anything).Return(0, errTarget).Once() diff --git a/packages/shared/pkg/storage/storage_cache_compressed_test.go b/packages/shared/pkg/storage/storage_cache_compressed_test.go new file mode 100644 index 0000000000..8314abb0a1 --- /dev/null +++ b/packages/shared/pkg/storage/storage_cache_compressed_test.go @@ -0,0 +1,153 @@ +package storage + +import ( + "bytes" + "io" + "os" + "testing" + + lz4 "github.com/pierrec/lz4/v4" + "github.com/stretchr/testify/require" +) + +// lz4Compress is a test helper that LZ4-compresses src. +func lz4Compress(t *testing.T, src []byte) []byte { + t.Helper() + + var buf bytes.Buffer + + w := lz4.NewWriter(&buf) + _, err := w.Write(src) + require.NoError(t, err) + require.NoError(t, w.Close()) + + return buf.Bytes() +} + +// lz4CompressProd matches the production encoder configuration (compress_encode.go): +// BlockSize=4Mb, BlockChecksumOption(true), ChecksumOption(false). Output ends in +// a 4-byte EndMark; with content checksum disabled, the decoder will not pull +// past the last block's data unless the caller reads past EOF. 
+func lz4CompressProd(t *testing.T, src []byte) []byte { + t.Helper() + + var buf bytes.Buffer + + w := lz4.NewWriter(&buf) + require.NoError(t, w.Apply( + lz4.BlockSizeOption(lz4.Block4Mb), + lz4.BlockChecksumOption(true), + lz4.ChecksumOption(false), + )) + _, err := w.Write(src) + require.NoError(t, err) + require.NoError(t, w.Close()) + + return buf.Bytes() +} + +func TestDecompressingCacheReader(t *testing.T) { + t.Parallel() + + newTestCache := func(t *testing.T) cachedSeekable { + t.Helper() + + return cachedSeekable{ + path: t.TempDir(), + chunkSize: 10, + tracer: noopTracer, + } + } + + original := []byte("the quick brown fox jumps over the lazy dog") + compressed := lz4Compress(t, original) + + t.Run("complete read is cached", func(t *testing.T) { + t.Parallel() + + c := newTestCache(t) + framePath := makeFrameFilename(c.path, Range{Offset: 0, Length: len(compressed)}) + + rc, err := newDecompressingCacheReader( + io.NopCloser(bytes.NewReader(compressed)), + CompressionLZ4, + len(compressed), + &c, t.Context(), framePath, 0, + ) + require.NoError(t, err) + + got, err := io.ReadAll(rc) + require.NoError(t, err) + require.Equal(t, original, got) + + require.NoError(t, rc.Close()) + c.wg.Wait() + + cached, err := os.ReadFile(framePath) + require.NoError(t, err) + require.Equal(t, compressed, cached) + }) + + t.Run("io.ReadFull at exact uncompressed size still populates cache (production LZ4 options)", func(t *testing.T) { + t.Parallel() + + // Mirror the chunker's progressiveRead: io.ReadFull with the EXACT + // uncompressed byte count, against an encoder configured the way prod + // configures it. With BlockChecksumOption(true)+ChecksumOption(false), + // the trailing 4-byte EndMark is part of the encoded frame but lz4.Reader + // does not pull it through the tee unless the caller reads past EOF. 
+ // The cache writeback path must tolerate that — failing the read for a + // writeback short would mean every subsequent read for the same block + // repeats the GCS round-trip and re-fails Close, defeating both cache + // tiers (chunker mmap bitmap + NFS .frm). + c := newTestCache(t) + compressedProd := lz4CompressProd(t, original) + framePath := makeFrameFilename(c.path, Range{Offset: 0, Length: len(compressedProd)}) + + rc, err := newDecompressingCacheReader( + io.NopCloser(bytes.NewReader(compressedProd)), + CompressionLZ4, + len(compressedProd), + &c, t.Context(), framePath, 0, + ) + require.NoError(t, err) + + out := make([]byte, len(original)) + n, err := io.ReadFull(rc, out) + require.NoError(t, err) + require.Equal(t, len(original), n) + require.Equal(t, original, out) + + require.NoError(t, rc.Close(), "writeback failure must not surface as a read error") + c.wg.Wait() + + _, err = os.Stat(framePath) + require.NoError(t, err, "frame should be cached after a successful complete read") + }) + + t.Run("size mismatch skips cache writeback but does not fail the read", func(t *testing.T) { + t.Parallel() + + c := newTestCache(t) + framePath := makeFrameFilename(c.path, Range{Offset: 0, Length: len(compressed)}) + + rc, err := newDecompressingCacheReader( + io.NopCloser(bytes.NewReader(compressed)), + CompressionLZ4, + len(compressed)+100, // wrong size + &c, t.Context(), framePath, 0, + ) + require.NoError(t, err) + + got, err := io.ReadAll(rc) + require.NoError(t, err) + require.Equal(t, original, got, "decompressed data should be correct regardless") + + require.NoError(t, rc.Close(), "writeback failure must not surface as a read error") + + c.wg.Wait() + + _, err = os.Stat(framePath) + require.True(t, os.IsNotExist(err), "mismatched frame should not be cached") + }) +} diff --git a/packages/shared/pkg/storage/storage_cache_metrics.go b/packages/shared/pkg/storage/storage_cache_metrics.go index 037bc7ed06..514b34539e 100644 --- 
a/packages/shared/pkg/storage/storage_cache_metrics.go +++ b/packages/shared/pkg/storage/storage_cache_metrics.go @@ -29,12 +29,11 @@ type cacheOp string const ( cacheOpWriteTo cacheOp = "write_to" - cacheOpReadAt cacheOp = "read_at" + cacheOpWrite cacheOp = "write" cacheOpSize cacheOp = "size" cacheOpOpenRangeReader cacheOp = "open_range_reader" - cacheOpWrite cacheOp = "write" cacheOpWriteFromFileSystem cacheOp = "write_from_filesystem" ) diff --git a/packages/shared/pkg/storage/storage_cache_seekable.go b/packages/shared/pkg/storage/storage_cache_seekable.go index 8c5b908280..5ce7529214 100644 --- a/packages/shared/pkg/storage/storage_cache_seekable.go +++ b/packages/shared/pkg/storage/storage_cache_seekable.go @@ -32,7 +32,9 @@ var ( ) const ( - nfsCacheOperationAttr = "operation" + nfsCacheOperationAttr = "operation" + // Value kept as "ReadAt" for dashboard compatibility after the method was + // renamed to OpenRangeReader. nfsCacheOperationAttrReadAt = "ReadAt" nfsCacheOperationAttrSize = "Size" ) @@ -72,114 +74,116 @@ var ( _ StreamingReader = (*cachedSeekable)(nil) ) -func (c *cachedSeekable) ReadAt(ctx context.Context, buff []byte, offset int64) (n int, err error) { - ctx, span := c.tracer.Start(ctx, "read object at offset", trace.WithAttributes( - attribute.Int64("offset", offset), - attribute.Int("buff_len", len(buff)), - )) - defer func() { - recordError(span, err) - span.End() - }() - - if err := c.validateReadAtParams(int64(len(buff)), offset); err != nil { - return 0, err - } - - // try to read from cache first - chunkPath := c.makeChunkFilename(offset) +func (c *cachedSeekable) OpenRangeReader(ctx context.Context, off int64, length int64, frameTable *FrameTable) (io.ReadCloser, error) { + compressed := frameTable.IsCompressed() - readTimer := cacheSlabReadTimerFactory.Begin(attribute.String(nfsCacheOperationAttr, nfsCacheOperationAttrReadAt)) - count, err := c.readAtFromCache(ctx, chunkPath, buff) - if ignoreEOF(err) == nil { - recordCacheRead(ctx, 
true, int64(count), cacheTypeSeekable, cacheOpReadAt) - readTimer.Success(ctx, int64(count)) + ctx, span := c.tracer.Start(ctx, "read", trace.WithAttributes( + attribute.Int64("offset", off), + attribute.Int64("length", length), + attribute.Bool("compressed", compressed), + )) - return count, err // return `err` in case it's io.EOF - } - readTimer.Failure(ctx, int64(count)) + if compressed { + rc, err := c.openReaderCompressed(ctx, off, frameTable) + if err != nil { + recordError(span, err) + span.End() - if !os.IsNotExist(err) { - recordCacheReadError(ctx, cacheTypeSeekable, cacheOpReadAt, err) - } + return nil, err + } - logger.L().Debug(ctx, "failed to read cached chunk, falling back to remote read", - zap.String("chunk_path", chunkPath), - zap.Int64("offset", offset), - zap.Error(err)) + rc = withSpan(rc, span) - // read remote file - readCount, err := c.inner.ReadAt(ctx, buff, offset) - if ignoreEOF(err) != nil { - return readCount, fmt.Errorf("failed to perform uncached read: %w", err) + return rc, nil } - if !skipCacheWriteback(ctx) && isCompleteRead(readCount, len(buff), err) { - shadowBuff := make([]byte, readCount) - copy(shadowBuff, buff[:readCount]) - - c.goCtx(ctx, func(ctx context.Context) { - ctx, span := c.tracer.Start(ctx, "write chunk at offset back to cache") - defer span.End() + if err := c.validateReadParams(length, off); err != nil { + recordError(span, err) + span.End() - if err := c.writeChunkToCache(ctx, offset, chunkPath, shadowBuff); err != nil { - recordError(span, err) - recordCacheWriteError(ctx, cacheTypeSeekable, cacheOpReadAt, err) - } - }) + return nil, err } - recordCacheRead(ctx, false, int64(readCount), cacheTypeSeekable, cacheOpReadAt) - - return readCount, err -} + timer := cacheSlabReadTimerFactory.Begin( + attribute.String(nfsCacheOperationAttr, nfsCacheOperationAttrReadAt), + attribute.Bool("compressed", false), + ) -func (c *cachedSeekable) OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { - 
// Try NFS cache file first chunkPath := c.makeChunkFilename(off) fp, err := os.Open(chunkPath) if err == nil { recordCacheRead(ctx, true, length, cacheTypeSeekable, cacheOpOpenRangeReader) + timer.Success(ctx, length) + + rc := io.ReadCloser(&fsRangeReadCloser{Reader: io.NewSectionReader(fp, 0, length), file: fp}) + rc = withSpan(rc, span) - return &fsRangeReadCloser{ - Reader: io.NewSectionReader(fp, 0, length), - file: fp, - }, nil + return rc, nil } if !os.IsNotExist(err) { recordCacheReadError(ctx, cacheTypeSeekable, cacheOpOpenRangeReader, err) } - // Cache miss: delegate to the inner backend (Seekable embeds StreamingReader). - inner, err := c.inner.OpenRangeReader(ctx, off, length) + timer.Failure(ctx, 0) + + rc, err := c.inner.OpenRangeReader(ctx, off, length, nil) if err != nil { + recordError(span, err) + span.End() + return nil, fmt.Errorf("failed to open inner range reader: %w", err) } recordCacheRead(ctx, false, length, cacheTypeSeekable, cacheOpOpenRangeReader) - // Skip write-through when the caller has opted out of cache writeback. - if skipCacheWriteback(ctx) { - return inner, nil + if !skipCacheWriteback(ctx) { + rc = newCacheWriteThroughReader(rc, c, ctx, off, length, chunkPath) } - // Wrap in a write-through reader that caches data on Close + rc = withSpan(rc, span) + + return rc, nil +} + +// withSpan wraps a reader with an OTEL span that ends on Close. +func withSpan(rc io.ReadCloser, span trace.Span) io.ReadCloser { + return &spanReadCloser{inner: rc, span: span} +} + +type spanReadCloser struct { + inner io.ReadCloser + span trace.Span +} + +func (r *spanReadCloser) Read(p []byte) (int, error) { + return r.inner.Read(p) +} + +func (r *spanReadCloser) Close() error { + err := r.inner.Close() + recordError(r.span, err) + r.span.End() + + return err +} + +// newCacheWriteThroughReader wraps a reader, buffering all data read through it. 
+// On Close, it asynchronously writes the buffered data to the NFS cache only +// if the total bytes read match the expected length (to avoid caching truncated data). +func newCacheWriteThroughReader(inner io.ReadCloser, cache *cachedSeekable, ctx context.Context, off, expectedLen int64, chunkPath string) io.ReadCloser { return &cacheWriteThroughReader{ inner: inner, - buf: bytes.NewBuffer(make([]byte, 0, length)), - cache: c, + buf: bytes.NewBuffer(make([]byte, 0, expectedLen)), + cache: cache, ctx: ctx, off: off, - expectedLen: length, + expectedLen: expectedLen, chunkPath: chunkPath, - }, nil + } } -// cacheWriteThroughReader wraps an inner reader, buffering all data read through it. -// On Close, it asynchronously writes the buffered data to the NFS cache only -// if the total bytes read match the expected length (to avoid caching truncated data). type cacheWriteThroughReader struct { inner io.ReadCloser buf *bytes.Buffer @@ -206,7 +210,7 @@ func (r *cacheWriteThroughReader) Close() error { // Unlike ReadAt where io.EOF can justify a short read (last chunk), // a streaming reader always ends with EOF regardless of whether the // data was truncated, so the byte count is the only reliable check. 
- if r.buf.Len() > 0 && int64(r.buf.Len()) == r.expectedLen { + if isCompleteRead(r.buf.Len(), int(r.expectedLen), nil) { data := make([]byte, r.buf.Len()) copy(data, r.buf.Bytes()) @@ -214,7 +218,7 @@ func (r *cacheWriteThroughReader) Close() error { ctx, span := r.cache.tracer.Start(ctx, "write range reader chunk back to cache") defer span.End() - if err := r.cache.writeChunkToCache(ctx, r.off, r.chunkPath, data); err != nil { + if err := r.cache.writeToCache(ctx, r.off, r.chunkPath, data); err != nil { recordError(span, err) recordCacheWriteError(ctx, cacheTypeSeekable, cacheOpOpenRangeReader, err) } @@ -266,7 +270,7 @@ func (c *cachedSeekable) Size(ctx context.Context) (n int64, e error) { return size, nil } -func (c *cachedSeekable) StoreFile(ctx context.Context, path string, opts ...PutOption) (e error) { +func (c *cachedSeekable) StoreFile(ctx context.Context, path string, opts ...PutOption) (_ *FrameTable, _ [32]byte, e error) { ctx, span := c.tracer.Start(ctx, "write object from file system", trace.WithAttributes(attribute.String("path", path)), ) @@ -275,10 +279,12 @@ func (c *cachedSeekable) StoreFile(ctx context.Context, path string, opts ...Put span.End() }() + cfg := CompressConfigFromOpts(ApplyPutOptions(opts)) + // write the file to the disk and the remote system at the same time. 
// this opens the file twice, but the API makes it difficult to use a MultiWriter - if c.flags.BoolFlag(ctx, featureflags.EnableWriteThroughCacheFlag) { + if !cfg.IsCompressionEnabled() && c.flags.BoolFlag(ctx, featureflags.EnableWriteThroughCacheFlag) { c.goCtx(ctx, func(ctx context.Context) { ctx, span := c.tracer.Start(ctx, "write cache object from file system", trace.WithAttributes(attribute.String("path", path))) @@ -314,36 +320,8 @@ func (c *cachedSeekable) makeChunkFilename(offset int64) string { return fmt.Sprintf("%s/%012d-%d.bin", c.path, offset/c.chunkSize, c.chunkSize) } -func (c *cachedSeekable) makeTempChunkFilename(offset int64) string { - tempFilename := uuid.NewString() - - return fmt.Sprintf("%s/.temp.%012d-%d.bin.%s", c.path, offset/c.chunkSize, c.chunkSize, tempFilename) -} - -func (c *cachedSeekable) readAtFromCache(ctx context.Context, chunkPath string, buff []byte) (n int, e error) { - ctx, span := c.tracer.Start(ctx, "read chunk at offset from cache") - defer func() { - recordError(span, e) - span.End() - }() - - fp, err := os.Open(chunkPath) - if err != nil { - return 0, fmt.Errorf("failed to open file: %w", err) - } - - defer utils.Cleanup(ctx, "failed to close chunk", fp.Close) - - // ReadAt (pread) is used instead of Read so that short reads from cache - // files (e.g. last chunk) return io.EOF per the io.ReaderAt contract. - // Plain Read on Linux returns (n, nil) for short reads and only - // signals EOF on a subsequent call, which would hide truncation. - count, err := fp.ReadAt(buff, 0) - if ignoreEOF(err) != nil { - return 0, fmt.Errorf("failed to read from chunk: %w", err) - } - - return count, err // return `err` in case it's io.EOF +func (c *cachedSeekable) makeTempFilename(path string) string { + return path + ".tmp." 
+ uuid.NewString() } func (c *cachedSeekable) sizeFilename() string { @@ -365,7 +343,7 @@ func (c *cachedSeekable) readLocalSize(context.Context) (int64, error) { return size, nil } -func (c *cachedSeekable) validateReadAtParams(buffSize, offset int64) error { +func (c *cachedSeekable) validateReadParams(buffSize, offset int64) error { if buffSize == 0 { return ErrBufferTooSmall } @@ -382,14 +360,14 @@ func (c *cachedSeekable) validateReadAtParams(buffSize, offset int64) error { return nil } -func (c *cachedSeekable) writeChunkToCache(ctx context.Context, offset int64, chunkPath string, bytes []byte) error { +func (c *cachedSeekable) writeToCache(ctx context.Context, offset int64, finalPath string, bytes []byte) error { writeTimer := cacheSlabWriteTimerFactory.Begin() // Try to acquire lock for this chunk write to NFS cache - lockFile, err := lock.TryAcquireLock(ctx, chunkPath) + lockFile, err := lock.TryAcquireLock(ctx, finalPath) if err != nil { // failed to acquire lock, which is a different category of failure than "write failed" - recordCacheWriteError(ctx, cacheTypeSeekable, cacheOpReadAt, err) + recordCacheWriteError(ctx, cacheTypeSeekable, cacheOpOpenRangeReader, err) writeTimer.Failure(ctx, 0) @@ -400,14 +378,14 @@ func (c *cachedSeekable) writeChunkToCache(ctx context.Context, offset int64, ch defer func() { err := lock.ReleaseLock(ctx, lockFile) if err != nil { - logger.L().Warn(ctx, "failed to release lock after writing chunk to cache", + logger.L().Warn(ctx, "failed to release lock after writing to cache", zap.Int64("offset", offset), - zap.String("path", chunkPath), + zap.String("path", finalPath), zap.Error(err)) } }() - tempPath := c.makeTempChunkFilename(offset) + tempPath := c.makeTempFilename(finalPath) if err := os.WriteFile(tempPath, bytes, cacheFilePermissions); err != nil { go safelyRemoveFile(ctx, tempPath) @@ -417,7 +395,7 @@ func (c *cachedSeekable) writeChunkToCache(ctx context.Context, offset int64, ch return fmt.Errorf("failed to write 
temp cache file: %w", err) } - if err := utils.RenameOrDeleteFile(ctx, tempPath, chunkPath); err != nil { + if err := utils.RenameOrDeleteFile(ctx, tempPath, finalPath); err != nil { writeTimer.Failure(ctx, int64(len(bytes))) return fmt.Errorf("failed to rename temp file: %w", err) diff --git a/packages/shared/pkg/storage/storage_cache_seekable_compressed.go b/packages/shared/pkg/storage/storage_cache_seekable_compressed.go new file mode 100644 index 0000000000..3b1f4af900 --- /dev/null +++ b/packages/shared/pkg/storage/storage_cache_seekable_compressed.go @@ -0,0 +1,173 @@ +package storage + +import ( + "bytes" + "context" + "fmt" + "io" + "os" + + "go.opentelemetry.io/otel/attribute" +) + +// Precomputed OTEL attributes for compressed cache reads (avoids per-read allocation). +var compressedCacheReadAttrs = []attribute.KeyValue{ + attribute.String(nfsCacheOperationAttr, nfsCacheOperationAttrReadAt), + attribute.Bool("compressed", true), +} + +// openReaderCompressed handles the compressed cache path for OpenRangeReader. +// NFS stores compressed frames (.frm); on hit we decompress, on miss we fetch +// raw compressed bytes and tee them to NFS on Close. +func (c *cachedSeekable) openReaderCompressed(ctx context.Context, offsetU int64, frameTable *FrameTable) (io.ReadCloser, error) { + r, err := frameTable.LocateCompressed(offsetU) + if err != nil { + return nil, fmt.Errorf("frame lookup for offset %d: %w", offsetU, err) + } + + path := makeFrameFilename(c.path, r) + + timer := cacheSlabReadTimerFactory.Begin(compressedCacheReadAttrs...) + + // Cache hit: open compressed frame from NFS and wrap with decompressor. 
+ f, err := os.Open(path) + + switch { + case err == nil: + recordCacheRead(ctx, true, int64(r.Length), cacheTypeSeekable, cacheOpOpenRangeReader) + timer.Success(ctx, int64(r.Length)) + + decompressed, err := newDecompressingReadCloser(f, frameTable.CompressionType()) + if err != nil { + f.Close() + + return nil, fmt.Errorf("decompress cached frame: %w", err) + } + + return decompressed, nil + case !os.IsNotExist(err): + recordCacheReadError(ctx, cacheTypeSeekable, cacheOpOpenRangeReader, err) + } + + timer.Failure(ctx, 0) + + // Cache miss: fetch raw compressed bytes via OpenRangeReader(nil frameTable). + raw, err := c.inner.OpenRangeReader(ctx, r.Offset, int64(r.Length), nil) + if err != nil { + return nil, fmt.Errorf("raw fetch at C=%d: %w", r.Offset, err) + } + + recordCacheRead(ctx, false, int64(r.Length), cacheTypeSeekable, cacheOpOpenRangeReader) + + rc, err := newDecompressingCacheReader(raw, frameTable.CompressionType(), r.Length, c, ctx, path, offsetU) + if err != nil { + raw.Close() + + return nil, fmt.Errorf("create decompressor: %w", err) + } + + return rc, nil +} + +// newDecompressingCacheReader creates a reader that decompresses on Read and +// writes the accumulated compressed bytes to the NFS cache on Close. 
+func newDecompressingCacheReader( + raw io.ReadCloser, + ct CompressionType, + expectedSize int, + cache *cachedSeekable, + ctx context.Context, //nolint:revive // ctx after other params for readability at call site + framePath string, + offset int64, +) (io.ReadCloser, error) { + var compressedBuf bytes.Buffer + compressedBuf.Grow(expectedSize) + + tee := io.TeeReader(raw, &compressedBuf) + + dec, err := NewDecompressingReader(tee, ct) + if err != nil { + return nil, err + } + + return &decompressingCacheReader{ + decompressor: dec, + raw: raw, + compressedBuf: &compressedBuf, + expectedSize: expectedSize, + cache: cache, + ctx: ctx, + framePath: framePath, + offset: offset, + }, nil +} + +type decompressingCacheReader struct { + decompressor io.ReadCloser // decompresses on Read + raw io.ReadCloser // underlying compressed stream (must be closed) + compressedBuf *bytes.Buffer + expectedSize int + cache *cachedSeekable + ctx context.Context //nolint:containedctx // needed for async cache write-back in Close + framePath string + offset int64 +} + +func (r *decompressingCacheReader) Read(p []byte) (int, error) { + return r.decompressor.Read(p) +} + +func (r *decompressingCacheReader) Close() error { + // Drive the decompressor to EOF before closing it. With io.ReadFull bounded + // by the uncompressed size, an LZ4 frame written with BlockChecksum=true / + // Checksum=false leaves the 4-byte EndMark unread — the next Read on the + // decoder pulls the EndMark (block-size = 0 → io.EOF) from raw through the + // tee, populating compressedBuf with the full encoded frame for cache writeback. + _, _ = io.Copy(io.Discard, r.decompressor) + + decErr := r.decompressor.Close() + rawErr := r.raw.Close() + + if decErr != nil { + return decErr + } + if rawErr != nil { + return rawErr + } + + got := r.compressedBuf.Len() + if skipCacheWriteback(r.ctx) { + return nil + } + + // Cache writeback is best-effort. 
After draining above, a remaining shortfall + // implies upstream truncation — log/metric and skip writeback rather than + // poison the read (the caller already received valid decompressed bytes). + if !isCompleteRead(got, r.expectedSize, nil) { + recordCacheWriteError(r.ctx, cacheTypeSeekable, cacheOpOpenRangeReader, + fmt.Errorf("compressed frame cache writeback short: got %d bytes, expected %d for %s", got, r.expectedSize, r.framePath)) + + return nil + } + + data := r.compressedBuf.Bytes() + r.compressedBuf = nil + + r.cache.goCtx(r.ctx, func(ctx context.Context) { + ctx, span := r.cache.tracer.Start(ctx, "write compressed frame back to cache") + defer span.End() + + if err := r.cache.writeToCache(ctx, r.offset, r.framePath, data); err != nil { + recordError(span, err) + recordCacheWriteError(ctx, cacheTypeSeekable, cacheOpOpenRangeReader, err) + } + }) + + return nil +} + +// makeFrameFilename returns the NFS cache path for a compressed frame. +// Format: {cacheBasePath}/{016xStart}-{xLength}.frm +func makeFrameFilename(cacheBasePath string, r Range) string { + return fmt.Sprintf("%s/%016x-%x.frm", cacheBasePath, r.Offset, uint32(r.Length)) +} diff --git a/packages/shared/pkg/storage/storage_cache_seekable_test.go b/packages/shared/pkg/storage/storage_cache_seekable_test.go index 40b9ea03d7..e0ea301f70 100644 --- a/packages/shared/pkg/storage/storage_cache_seekable_test.go +++ b/packages/shared/pkg/storage/storage_cache_seekable_test.go @@ -12,10 +12,30 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - - storagemocks "github.com/e2b-dev/infra/packages/shared/pkg/storage/mocks" ) +// testReadAt emulates the removed cachedSeekable.ReadAt via OpenRangeReader. +// This preserves the base test structure after ReadAt was removed from the Seekable interface. 
+func testReadAt(ctx context.Context, c *cachedSeekable, buff []byte, off int64) (int, error) { + rc, err := c.OpenRangeReader(ctx, off, int64(len(buff)), nil) + if err != nil { + return 0, err + } + + n, err := io.ReadFull(rc, buff) + + closeErr := rc.Close() + if errors.Is(err, io.ErrUnexpectedEOF) { + err = io.EOF + } + + if err == nil { + err = closeErr + } + + return n, err +} + func TestCachedFileObjectProvider_MakeChunkFilename(t *testing.T) { t.Parallel() @@ -32,7 +52,7 @@ func TestCachedFileObjectProvider_Size(t *testing.T) { const expectedSize int64 = 1024 - inner := storagemocks.NewMockSeekable(t) + inner := NewMockSeekable(t) inner.EXPECT().Size(mock.Anything).Return(expectedSize, nil) c := cachedSeekable{path: t.TempDir(), inner: inner, tracer: noopTracer} @@ -71,19 +91,19 @@ func TestCachedFileObjectProvider_WriteFromFileSystem(t *testing.T) { err = os.WriteFile(tempFilename, data, 0o644) require.NoError(t, err) - inner := storagemocks.NewMockSeekable(t) + inner := NewMockSeekable(t) inner.EXPECT(). StoreFile(mock.Anything, mock.Anything). 
- Return(nil) + Return(nil, [32]byte{}, nil) - featureFlags := storagemocks.NewMockFeatureFlagsClient(t) + featureFlags := NewMockFeatureFlagsClient(t) featureFlags.EXPECT().BoolFlag(mock.Anything, mock.Anything).Return(true) featureFlags.EXPECT().IntFlag(mock.Anything, mock.Anything).Return(10) c := cachedSeekable{path: cacheDir, inner: inner, chunkSize: 1024, flags: featureFlags, tracer: noopTracer} // write temp file - err = c.StoreFile(t.Context(), tempFilename) + _, _, err = c.StoreFile(t.Context(), tempFilename) require.NoError(t, err) // file is written asynchronously, wait for it to finish @@ -98,7 +118,7 @@ func TestCachedFileObjectProvider_WriteFromFileSystem(t *testing.T) { // verify that the size has been cached buff := make([]byte, len(data)) - bytesRead, err := c.ReadAt(t.Context(), buff, 0) + bytesRead, err := testReadAt(t.Context(), &c, buff, 0) require.NoError(t, err) assert.Equal(t, data, buff) assert.Equal(t, len(data), bytesRead) @@ -125,7 +145,7 @@ func TestCachedFileObjectProvider_WriteTo(t *testing.T) { require.NoError(t, err) buffer := make([]byte, 3) - read, err := c.ReadAt(t.Context(), buffer, 0) + read, err := testReadAt(t.Context(), &c, buffer, 0) require.NoError(t, err) assert.Equal(t, []byte{1, 2, 3}, buffer) assert.Equal(t, 3, read) @@ -147,7 +167,7 @@ func TestCachedFileObjectProvider_WriteTo(t *testing.T) { // per the io.ReaderAt contract. This is a cache hit — the caller // sees the data with EOF indicating end of file. buffer := make([]byte, 10) - read, err := c.ReadAt(t.Context(), buffer, 0) + read, err := testReadAt(t.Context(), &c, buffer, 0) require.ErrorIs(t, err, io.EOF) assert.Equal(t, 3, read) assert.Equal(t, []byte{1, 2, 3}, buffer[:read]) @@ -157,30 +177,27 @@ func TestCachedFileObjectProvider_WriteTo(t *testing.T) { t.Parallel() fakeData := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - fakeStorageObjectProvider := storagemocks.NewMockSeekable(t) + inner := NewMockSeekable(t) - fakeStorageObjectProvider.EXPECT(). 
- ReadAt(mock.Anything, mock.Anything, mock.Anything). - RunAndReturn(func(_ context.Context, buff []byte, off int64) (int, error) { - start := off - end := off + int64(len(buff)) - end = min(end, int64(len(fakeData))) - copy(buff, fakeData[start:end]) - - return int(end - start), nil + inner.EXPECT(). + OpenRangeReader(mock.Anything, mock.Anything, mock.Anything, (*FrameTable)(nil)). + RunAndReturn(func(_ context.Context, off int64, length int64, _ *FrameTable) (io.ReadCloser, error) { + end := min(int(off)+int(length), len(fakeData)) + + return io.NopCloser(bytes.NewReader(fakeData[off:end])), nil }) tempDir := t.TempDir() c := cachedSeekable{ path: tempDir, chunkSize: 3, - inner: fakeStorageObjectProvider, + inner: inner, tracer: noopTracer, } // first read goes to source buffer := make([]byte, 3) - read, err := c.ReadAt(t.Context(), buffer, 3) + read, err := testReadAt(t.Context(), &c, buffer, 3) require.NoError(t, err) assert.Equal(t, []byte{4, 5, 6}, buffer) assert.Equal(t, 3, read) @@ -191,7 +208,7 @@ func TestCachedFileObjectProvider_WriteTo(t *testing.T) { // second read pulls from cache c.inner = nil // prevent remote reads, force cache read buffer = make([]byte, 3) - read, err = c.ReadAt(t.Context(), buffer, 3) + read, err = testReadAt(t.Context(), &c, buffer, 3) require.NoError(t, err) assert.Equal(t, []byte{4, 5, 6}, buffer) assert.Equal(t, 3, read) @@ -202,7 +219,7 @@ func TestCachedFileObjectProvider_WriteTo(t *testing.T) { fakeData := []byte{1, 2, 3} - fakeStorageObjectProvider := storagemocks.NewMockBlob(t) + fakeStorageObjectProvider := NewMockBlob(t) fakeStorageObjectProvider.EXPECT(). WriteTo(mock.Anything, mock.Anything). 
RunAndReturn(func(_ context.Context, dst io.Writer) (int64, error) { @@ -279,7 +296,7 @@ func TestCachedFileObjectProvider_validateReadAtParams(t *testing.T) { chunkSize: tc.chunkSize, tracer: noopTracer, } - err := c.validateReadAtParams(tc.bufferSize, tc.offset) + err := c.validateReadParams(tc.bufferSize, tc.offset) if tc.expected == nil { require.NoError(t, err) } else { @@ -292,39 +309,24 @@ func TestCachedFileObjectProvider_validateReadAtParams(t *testing.T) { func TestCachedSeekableObjectProvider_ReadAt(t *testing.T) { t.Parallel() - t.Run("failed but returns count on short read", func(t *testing.T) { - t.Parallel() - - c := cachedSeekable{chunkSize: 10, tracer: noopTracer} - errTarget := errors.New("find me") - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT().ReadAt(mock.Anything, mock.Anything, mock.Anything).Return(5, errTarget) - c.inner = mockSeeker - - buff := make([]byte, 10) - count, err := c.ReadAt(t.Context(), buff, 0) - require.ErrorIs(t, err, errTarget) - assert.Equal(t, 5, count) - }) - t.Run("zero byte read with EOF is not cached", func(t *testing.T) { t.Parallel() tempDir := t.TempDir() - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - ReadAt(mock.Anything, mock.Anything, mock.Anything). - Return(0, io.EOF) + inner := NewMockSeekable(t) + inner.EXPECT(). + OpenRangeReader(mock.Anything, mock.Anything, mock.Anything, (*FrameTable)(nil)). 
+ Return(io.NopCloser(bytes.NewReader(nil)), nil) c := cachedSeekable{ path: tempDir, chunkSize: 10, - inner: mockSeeker, + inner: inner, tracer: noopTracer, } buff := make([]byte, 10) - count, err := c.ReadAt(t.Context(), buff, 0) + count, err := testReadAt(t.Context(), &c, buff, 0) require.ErrorIs(t, err, io.EOF) assert.Equal(t, 0, count) @@ -335,127 +337,25 @@ func TestCachedSeekableObjectProvider_ReadAt(t *testing.T) { assert.True(t, os.IsNotExist(err), "zero-byte read should not be cached") }) - t.Run("zero byte read without EOF is not cached", func(t *testing.T) { - t.Parallel() - - tempDir := t.TempDir() - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - ReadAt(mock.Anything, mock.Anything, mock.Anything). - Return(0, nil) - - c := cachedSeekable{ - path: tempDir, - chunkSize: 10, - inner: mockSeeker, - tracer: noopTracer, - } - - buff := make([]byte, 10) - count, err := c.ReadAt(t.Context(), buff, 0) - require.NoError(t, err) - assert.Equal(t, 0, count) - - c.wg.Wait() - - chunkPath := c.makeChunkFilename(0) - _, err = os.Stat(chunkPath) - assert.True(t, os.IsNotExist(err), "zero-byte read should not be cached") - }) - - t.Run("short read without EOF is not cached", func(t *testing.T) { - t.Parallel() - - tempDir := t.TempDir() - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - ReadAt(mock.Anything, mock.Anything, mock.Anything). - RunAndReturn(func(_ context.Context, buff []byte, _ int64) (int, error) { - // Simulate a truncated upstream response: return fewer - // bytes than requested with no error and no EOF. - copy(buff[:2], []byte{0xAA, 0xBB}) - - return 2, nil - }) - - c := cachedSeekable{ - path: tempDir, - chunkSize: 10, - inner: mockSeeker, - tracer: noopTracer, - } - - buff := make([]byte, 10) - count, err := c.ReadAt(t.Context(), buff, 0) - require.NoError(t, err) - assert.Equal(t, 2, count) - - c.wg.Wait() - - // Verify no cache file was written. 
- chunkPath := c.makeChunkFilename(0) - _, err = os.Stat(chunkPath) - assert.True(t, os.IsNotExist(err), "truncated data should not be cached") - }) - - t.Run("short read with EOF is cached", func(t *testing.T) { - t.Parallel() - - tempDir := t.TempDir() - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - ReadAt(mock.Anything, mock.Anything, mock.Anything). - RunAndReturn(func(_ context.Context, buff []byte, _ int64) (int, error) { - // Last chunk: fewer bytes than the buffer with EOF. - copy(buff[:3], []byte{1, 2, 3}) - - return 3, io.EOF - }) - - c := cachedSeekable{ - path: tempDir, - chunkSize: 10, - inner: mockSeeker, - tracer: noopTracer, - } - - buff := make([]byte, 10) - count, err := c.ReadAt(t.Context(), buff, 0) - require.ErrorIs(t, err, io.EOF) - assert.Equal(t, 3, count) - - c.wg.Wait() - - // Verify the data was cached. - chunkPath := c.makeChunkFilename(0) - cached, err := os.ReadFile(chunkPath) - require.NoError(t, err) - assert.Equal(t, []byte{1, 2, 3}, cached) - }) - t.Run("full read without EOF is cached", func(t *testing.T) { t.Parallel() tempDir := t.TempDir() data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - ReadAt(mock.Anything, mock.Anything, mock.Anything). - RunAndReturn(func(_ context.Context, buff []byte, _ int64) (int, error) { - copy(buff, data) - - return len(data), nil - }) + inner := NewMockSeekable(t) + inner.EXPECT(). + OpenRangeReader(mock.Anything, mock.Anything, mock.Anything, (*FrameTable)(nil)). 
+ Return(io.NopCloser(bytes.NewReader(data)), nil) c := cachedSeekable{ path: tempDir, chunkSize: 10, - inner: mockSeeker, + inner: inner, tracer: noopTracer, } buff := make([]byte, 10) - count, err := c.ReadAt(t.Context(), buff, 0) + count, err := testReadAt(t.Context(), &c, buff, 0) require.NoError(t, err) assert.Equal(t, 10, count) @@ -504,24 +404,20 @@ func TestCachedSeekable_ReadAt_PreservesEOF(t *testing.T) { t.Parallel() tempDir := t.TempDir() - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - ReadAt(mock.Anything, mock.Anything, mock.Anything). - RunAndReturn(func(_ context.Context, buff []byte, _ int64) (int, error) { - copy(buff[:3], []byte{1, 2, 3}) - - return 3, io.EOF - }) + inner := NewMockSeekable(t) + inner.EXPECT(). + OpenRangeReader(mock.Anything, mock.Anything, mock.Anything, (*FrameTable)(nil)). + Return(io.NopCloser(bytes.NewReader([]byte{1, 2, 3})), nil) c := cachedSeekable{ path: tempDir, chunkSize: 10, - inner: mockSeeker, + inner: inner, tracer: noopTracer, } buff := make([]byte, 10) - n, err := c.ReadAt(t.Context(), buff, 0) + n, err := testReadAt(t.Context(), &c, buff, 0) assert.Equal(t, 3, n) require.ErrorIs(t, err, io.EOF, "cachedSeekable must not swallow io.EOF") @@ -532,24 +428,20 @@ func TestCachedSeekable_ReadAt_PreservesEOF(t *testing.T) { t.Parallel() tempDir := t.TempDir() - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - ReadAt(mock.Anything, mock.Anything, mock.Anything). - RunAndReturn(func(_ context.Context, buff []byte, _ int64) (int, error) { - copy(buff, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) - - return 10, nil - }) + inner := NewMockSeekable(t) + inner.EXPECT(). + OpenRangeReader(mock.Anything, mock.Anything, mock.Anything, (*FrameTable)(nil)). 
+ Return(io.NopCloser(bytes.NewReader([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})), nil) c := cachedSeekable{ path: tempDir, chunkSize: 10, - inner: mockSeeker, + inner: inner, tracer: noopTracer, } buff := make([]byte, 10) - n, err := c.ReadAt(t.Context(), buff, 0) + n, err := testReadAt(t.Context(), &c, buff, 0) assert.Equal(t, 10, n) require.NoError(t, err, "cachedSeekable must not inject errors on full read") @@ -562,25 +454,23 @@ func TestCachedSeekable_ReadAt_SkipCacheWriteback(t *testing.T) { tempDir := t.TempDir() data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - ReadAt(mock.Anything, mock.Anything, mock.Anything). - RunAndReturn(func(_ context.Context, buff []byte, _ int64) (int, error) { - copy(buff, data) - - return len(data), nil + inner := NewMockSeekable(t) + inner.EXPECT(). + OpenRangeReader(mock.Anything, mock.Anything, mock.Anything, (*FrameTable)(nil)). + RunAndReturn(func(_ context.Context, _ int64, _ int64, _ *FrameTable) (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(data)), nil }) c := cachedSeekable{ path: tempDir, chunkSize: 10, - inner: mockSeeker, + inner: inner, tracer: noopTracer, } ctx := WithSkipCacheWriteback(t.Context()) buff := make([]byte, 10) - n, err := c.ReadAt(ctx, buff, 0) + n, err := testReadAt(ctx, &c, buff, 0) require.NoError(t, err) assert.Equal(t, 10, n) @@ -600,21 +490,21 @@ func TestCachedSeekable_OpenRangeReader(t *testing.T) { tempDir := t.TempDir() data := []byte("hello") - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - OpenRangeReader(mock.Anything, int64(0), int64(len(data))). + inner := NewMockSeekable(t) + inner.EXPECT(). + OpenRangeReader(mock.Anything, int64(0), int64(len(data)), (*FrameTable)(nil)). Return(io.NopCloser(bytes.NewReader(data)), nil). Once() c := cachedSeekable{ path: tempDir, chunkSize: 10, - inner: mockSeeker, + inner: inner, tracer: noopTracer, } // First call: cache miss, reads from inner. 
- rc, err := c.OpenRangeReader(t.Context(), 0, int64(len(data))) + rc, err := c.OpenRangeReader(t.Context(), 0, int64(len(data)), nil) require.NoError(t, err) got, err := io.ReadAll(rc) @@ -626,7 +516,7 @@ func TestCachedSeekable_OpenRangeReader(t *testing.T) { // Second call: should serve from NFS cache, inner not called again. c.inner = nil - rc2, err := c.OpenRangeReader(t.Context(), 0, int64(len(data))) + rc2, err := c.OpenRangeReader(t.Context(), 0, int64(len(data)), nil) require.NoError(t, err) got2, err := io.ReadAll(rc2) @@ -641,10 +531,10 @@ func TestCachedSeekable_OpenRangeReader(t *testing.T) { tempDir := t.TempDir() data := []byte("hello") - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - OpenRangeReader(mock.Anything, int64(0), int64(len(data))). - RunAndReturn(func(_ context.Context, _ int64, _ int64) (io.ReadCloser, error) { + inner := NewMockSeekable(t) + inner.EXPECT(). + OpenRangeReader(mock.Anything, int64(0), int64(len(data)), (*FrameTable)(nil)). + RunAndReturn(func(_ context.Context, _ int64, _ int64, _ *FrameTable) (io.ReadCloser, error) { return io.NopCloser(bytes.NewReader(data)), nil }). 
Times(2) @@ -652,13 +542,13 @@ func TestCachedSeekable_OpenRangeReader(t *testing.T) { c := cachedSeekable{ path: tempDir, chunkSize: 10, - inner: mockSeeker, + inner: inner, tracer: noopTracer, } ctx := WithSkipCacheWriteback(t.Context()) - rc, err := c.OpenRangeReader(ctx, 0, int64(len(data))) + rc, err := c.OpenRangeReader(ctx, 0, int64(len(data)), nil) require.NoError(t, err) got, err := io.ReadAll(rc) @@ -673,7 +563,7 @@ func TestCachedSeekable_OpenRangeReader(t *testing.T) { _, err = os.Stat(chunkPath) assert.True(t, os.IsNotExist(err), "skip writeback should not populate cache") - rc2, err := c.OpenRangeReader(ctx, 0, int64(len(data))) + rc2, err := c.OpenRangeReader(ctx, 0, int64(len(data)), nil) require.NoError(t, err) got2, err := io.ReadAll(rc2) @@ -687,19 +577,19 @@ func TestCachedSeekable_OpenRangeReader(t *testing.T) { tempDir := t.TempDir() - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - OpenRangeReader(mock.Anything, int64(0), int64(5)). + inner := NewMockSeekable(t) + inner.EXPECT(). + OpenRangeReader(mock.Anything, int64(0), int64(5), (*FrameTable)(nil)). 
Return(io.NopCloser(bytes.NewReader([]byte{0xAA, 0xBB})), nil) c := cachedSeekable{ path: tempDir, chunkSize: 10, - inner: mockSeeker, + inner: inner, tracer: noopTracer, } - rc, err := c.OpenRangeReader(t.Context(), 0, 5) + rc, err := c.OpenRangeReader(t.Context(), 0, 5, nil) require.NoError(t, err) got, err := io.ReadAll(rc) diff --git a/packages/shared/pkg/storage/storage_fs.go b/packages/shared/pkg/storage/storage_fs.go index 57e1e9c5ee..3285a8cd4a 100644 --- a/packages/shared/pkg/storage/storage_fs.go +++ b/packages/shared/pkg/storage/storage_fs.go @@ -13,7 +13,12 @@ import ( "os" "path/filepath" "strconv" + "strings" "time" + + "go.uber.org/zap" + + "github.com/e2b-dev/infra/packages/shared/pkg/logger" ) type fsStorage struct { @@ -125,28 +130,81 @@ func (o *fsObject) Put(_ context.Context, data []byte, _ ...PutOption) error { return err } -func (o *fsObject) StoreFile(_ context.Context, path string, _ ...PutOption) error { +func (o *fsObject) StoreFile(ctx context.Context, path string, opts ...PutOption) (*FrameTable, [32]byte, error) { + cfg := CompressConfigFromOpts(ApplyPutOptions(opts)) + if cfg.IsCompressionEnabled() { + ft, checksum, err := o.storeFileCompressed(ctx, path, cfg) + if err == nil { + logger.L().Debug(ctx, "Stored file to filesystem", + zap.String("object", o.path), + zap.String("source", path), + zap.Int64("size_uncompressed", ft.UncompressedSize()), + zap.Int64("size_compressed", ft.CompressedSize()), + zap.String("compression", cfg.CompressionType().String()), + zap.Int("frames", ft.NumFrames()), + ) + } + + return ft, checksum, err + } + r, err := os.Open(path) if err != nil { - return fmt.Errorf("failed to open file %s: %w", path, err) + return nil, [32]byte{}, fmt.Errorf("failed to open file %s: %w", path, err) } defer r.Close() handle, err := o.getHandle(false) if err != nil { - return err + return nil, [32]byte{}, err } defer handle.Close() - _, err = io.Copy(handle, r) + n, err := io.Copy(handle, r) + if err == nil { + 
logger.L().Debug(ctx, "Stored file to filesystem", + zap.String("object", o.path), + zap.String("source", path), + zap.Int64("size_uncompressed", n), + zap.String("compression", "none"), + ) + } + + return nil, [32]byte{}, err +} + +func (o *fsObject) storeFileCompressed(ctx context.Context, localPath string, cfg CompressConfig) (*FrameTable, [32]byte, error) { + file, err := os.Open(localPath) if err != nil { - return err + return nil, [32]byte{}, fmt.Errorf("failed to open local file %s: %w", localPath, err) + } + defer file.Close() + + fi, err := file.Stat() + if err != nil { + return nil, [32]byte{}, fmt.Errorf("failed to stat local file %s: %w", localPath, err) + } + + uploader := &fsPartUploader{fullPath: o.path} + + const noConcurrencyForMemUploader = 1 + ft, checksum, err := compressStream(ctx, file, cfg, uploader, noConcurrencyForMemUploader) + if err != nil { + return nil, [32]byte{}, err + } + + // Sidecar is written only after compressStream succeeds so a failure (cancel, + // partial read, compress error) doesn't leave Size() reporting the new size + // against the unchanged data file. 
+ sidecarPath := SizeSidecar(o.path) + if writeErr := os.WriteFile(sidecarPath, []byte(strconv.FormatInt(fi.Size(), 10)), 0o644); writeErr != nil { + return nil, [32]byte{}, fmt.Errorf("failed to write uncompressed-size sidecar for %s: %w", o.path, writeErr) } - return nil + return ft, checksum, nil } -func (o *fsObject) OpenRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { +func (o *fsObject) openRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { f, err := o.getHandle(true) if err != nil { return nil, err @@ -158,16 +216,6 @@ func (o *fsObject) OpenRangeReader(_ context.Context, off, length int64) (io.Rea }, nil } -func (o *fsObject) ReadAt(_ context.Context, buff []byte, off int64) (n int, err error) { - handle, err := o.getHandle(true) - if err != nil { - return 0, err - } - defer handle.Close() - - return handle.ReadAt(buff, off) -} - func (o *fsObject) Exists(_ context.Context) (bool, error) { _, err := os.Stat(o.path) if os.IsNotExist(err) { @@ -189,6 +237,14 @@ func (o *fsObject) Size(_ context.Context) (int64, error) { return 0, err } + // Check for .uncompressed-size sidecar file + sidecarPath := SizeSidecar(o.path) + if sidecarData, sidecarErr := os.ReadFile(sidecarPath); sidecarErr == nil { + if parsed, parseErr := strconv.ParseInt(strings.TrimSpace(string(sidecarData)), 10, 64); parseErr == nil { + return parsed, nil + } + } + return fileInfo.Size(), nil } @@ -240,3 +296,45 @@ func (o *fsObject) getHandle(checkExistence bool) (*os.File, error) { return handle, nil } + +// fsPartUploader implements partUploader for local filesystem. +// Embeds memPartUploader for concurrent-safe part collection, +// then writes atomically on Complete. 
+type fsPartUploader struct { + memPartUploader + + fullPath string +} + +func (u *fsPartUploader) Complete(_ context.Context) error { + if err := os.MkdirAll(filepath.Dir(u.fullPath), 0o755); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + + return os.WriteFile(u.fullPath, u.Assemble(), 0o644) +} + +func (o *fsObject) OpenRangeReader(ctx context.Context, offsetU int64, length int64, frameTable *FrameTable) (io.ReadCloser, error) { + if frameTable.IsCompressed() { + r, err := frameTable.LocateCompressed(offsetU) + if err != nil { + return nil, fmt.Errorf("get frame for offset %d, FS:%s: %w", offsetU, o.path, err) + } + + raw, err := o.openRangeReader(ctx, r.Offset, int64(r.Length)) + if err != nil { + return nil, err + } + + decompressed, err := newDecompressingReadCloser(raw, frameTable.CompressionType()) + if err != nil { + raw.Close() + + return nil, err + } + + return decompressed, nil + } + + return o.openRangeReader(ctx, offsetU, length) +} diff --git a/packages/shared/pkg/storage/storage_google.go b/packages/shared/pkg/storage/storage_google.go index d1d7710863..cab9eb47ab 100644 --- a/packages/shared/pkg/storage/storage_google.go +++ b/packages/shared/pkg/storage/storage_google.go @@ -8,8 +8,10 @@ import ( "errors" "fmt" "io" + "maps" "net/http" "os" + "strconv" "time" "cloud.google.com/go/storage" @@ -44,13 +46,14 @@ const ( defaultGCSEnableDirectPath = false gcloudDefaultUploadConcurrency = 16 - gcsOperationAttr = "operation" - gcsOperationAttrReadAt = "ReadAt" - gcsOperationAttrWrite = "Write" - gcsOperationAttrWriteFromFileSystem = "WriteFromFileSystem" - gcsOperationAttrWriteFromFileSystemOneShot = "WriteFromFileSystemOneShot" - gcsOperationAttrWriteTo = "WriteTo" - gcsOperationAttrSize = "Size" + gcsOperationAttr = "operation" + gcsOperationAttrWrite = "Write" + gcsOperationAttrWriteFromFileSystem = "WriteFromFileSystem" + gcsOperationAttrWriteTo = "WriteTo" + gcsOperationAttrSize = "Size" + // gcsOperationAttrReadAt tags 
GCS read timer metrics for OpenRangeReader + // (the method was renamed from ReadAt; value kept for dashboard compatibility). + gcsOperationAttrReadAt = "ReadAt" ) var ( @@ -245,10 +248,17 @@ func (o *gcpObject) Size(ctx context.Context) (int64, error) { timer.Success(ctx, 0) + if v, ok := attrs.Metadata[MetadataKeyUncompressedSize]; ok { + parsed, parseErr := strconv.ParseInt(v, 10, 64) + if parseErr == nil { + return parsed, nil + } + } + return attrs.Size, nil } -func (o *gcpObject) OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { +func (o *gcpObject) openRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { ctx, cancel := context.WithTimeout(ctx, googleReadTimeout) reader, err := o.handle.NewRangeReader(ctx, off, length) @@ -275,38 +285,6 @@ func (r *cancelOnCloseReader) Close() error { return r.ReadCloser.Close() } -func (o *gcpObject) ReadAt(ctx context.Context, buff []byte, off int64) (n int, err error) { - timer := googleReadTimerFactory.Begin(attribute.String(gcsOperationAttr, gcsOperationAttrReadAt)) - - ctx, cancel := context.WithTimeout(ctx, googleReadTimeout) - defer cancel() - - // The file should not be gzip compressed - reader, err := o.handle.NewRangeReader(ctx, off, int64(len(buff))) - if err != nil { - timer.Failure(ctx, int64(n)) - - return 0, fmt.Errorf("failed to create GCS reader for %q: %w", o.path, err) - } - - defer reader.Close() - - n, err = io.ReadFull(reader, buff) - if errors.Is(err, io.ErrUnexpectedEOF) { - err = io.EOF - } - - if ignoreEOF(err) != nil { - timer.Failure(ctx, int64(n)) - - return n, fmt.Errorf("failed to read %q: %w", o.path, err) - } - - timer.Success(ctx, int64(n)) - - return n, err -} - func (o *gcpObject) Put(ctx context.Context, data []byte, opts ...PutOption) error { timer := googleWriteTimerFactory.Begin(attribute.String(gcsOperationAttr, gcsOperationAttrWrite)) @@ -387,7 +365,7 @@ func (o *gcpObject) WriteTo(ctx context.Context, dst io.Writer) (int64, 
error) { return n, nil } -func (o *gcpObject) StoreFile(ctx context.Context, path string, opts ...PutOption) (e error) { +func (o *gcpObject) StoreFile(ctx context.Context, path string, opts ...PutOption) (_ *FrameTable, _ [32]byte, e error) { ctx, span := tracer.Start(ctx, "write to gcp from file system") defer func() { recordError(span, e) @@ -401,37 +379,17 @@ func (o *gcpObject) StoreFile(ctx context.Context, path string, opts ...PutOptio fileInfo, err := os.Stat(path) if err != nil { - return fmt.Errorf("failed to get file size: %w", err) + return nil, [32]byte{}, fmt.Errorf("failed to get file size: %w", err) } - // If the file is too small, the overhead of writing in parallel isn't worth the effort. - // Write it in one shot instead. - if fileInfo.Size() < gcpMultipartUploadChunkSize { - timer := googleWriteTimerFactory.Begin( - attribute.String(gcsOperationAttr, gcsOperationAttrWriteFromFileSystemOneShot), - ) - - data, err := os.ReadFile(path) - if err != nil { - timer.Failure(ctx, 0) - - return fmt.Errorf("failed to read file: %w", err) - } - - err = o.Put(ctx, data, opts...) - if err != nil { - timer.Failure(ctx, int64(len(data))) - - return fmt.Errorf("failed to write file (%d bytes): %w", len(data), err) - } - - timer.Success(ctx, int64(len(data))) - - return nil - } + cfg := CompressConfigFromOpts(putOpts) + // Tag the upload timer with the compression mode so dashboards can split + // duration/throughput by codec and level. Type is "none" when disabled. 
timer := googleWriteTimerFactory.Begin( attribute.String(gcsOperationAttr, gcsOperationAttrWriteFromFileSystem), + attribute.String("compression.type", cfg.CompressionType().String()), + attribute.Int("compression.level", cfg.Level), ) maxConcurrency := gcloudDefaultUploadConcurrency @@ -442,7 +400,7 @@ func (o *gcpObject) StoreFile(ctx context.Context, path string, opts ...PutOptio if semaphoreErr != nil { timer.Failure(ctx, 0) - return fmt.Errorf("failed to acquire semaphore: %w", semaphoreErr) + return nil, [32]byte{}, fmt.Errorf("failed to acquire semaphore: %w", semaphoreErr) } defer uploadLimiter.Release(1) } @@ -450,6 +408,61 @@ func (o *gcpObject) StoreFile(ctx context.Context, path string, opts ...PutOptio maxConcurrency = o.limiter.GCloudMaxTasks(ctx) } + // Compressed uploads always go through the multipart compressed path, + // regardless of file size. + if cfg.IsCompressionEnabled() { + start := time.Now() + ft, checksum, err := o.storeFileCompressed(ctx, path, cfg, maxConcurrency, putOpts) + if err != nil { + timer.Failure(ctx, fileInfo.Size()) + } else { + timer.Success(ctx, fileInfo.Size()) + + logger.L().Debug(ctx, "Uploaded file to GCS", + zap.String("bucket", bucketName), + zap.String("object", objectName), + zap.String("source", path), + zap.Int64("size_uncompressed", fileInfo.Size()), + zap.Int64("size_compressed", ft.CompressedSize()), + zap.String("compression", cfg.CompressionType().String()), + zap.Int("frames", ft.NumFrames()), + zap.Int64("duration_ms", time.Since(start).Milliseconds()), + ) + } + + return ft, checksum, err + } + + // If the file is too small, the overhead of writing in parallel isn't worth the effort. + // Write it in one shot instead. + if fileInfo.Size() < gcpMultipartUploadChunkSize { + data, err := os.ReadFile(path) + if err != nil { + timer.Failure(ctx, 0) + + return nil, [32]byte{}, fmt.Errorf("failed to read file: %w", err) + } + + err = o.Put(ctx, data, opts...) 
+ if err != nil { + timer.Failure(ctx, int64(len(data))) + + return nil, [32]byte{}, fmt.Errorf("failed to write file (%d bytes): %w", len(data), err) + } + + timer.Success(ctx, int64(len(data))) + + logger.L().Debug(ctx, "Uploaded file to GCS", + zap.String("bucket", bucketName), + zap.String("object", objectName), + zap.String("source", path), + zap.Int64("size_uncompressed", int64(len(data))), + zap.String("compression", "none"), + ) + + return nil, [32]byte{}, e + } + uploader, err := NewMultipartUploaderWithRetryConfig( ctx, bucketName, @@ -460,7 +473,7 @@ func (o *gcpObject) StoreFile(ctx context.Context, path string, opts ...PutOptio if err != nil { timer.Failure(ctx, 0) - return fmt.Errorf("failed to create multipart uploader: %w", err) + return nil, [32]byte{}, fmt.Errorf("failed to create multipart uploader: %w", err) } start := time.Now() @@ -468,21 +481,54 @@ func (o *gcpObject) StoreFile(ctx context.Context, path string, opts ...PutOptio if err != nil { timer.Failure(ctx, count) - return fmt.Errorf("failed to upload file in parallel: %w", err) + return nil, [32]byte{}, fmt.Errorf("failed to upload file in parallel: %w", err) } - logger.L().Debug(ctx, "Uploaded file in parallel", + logger.L().Debug(ctx, "Uploaded file to GCS", zap.String("bucket", bucketName), zap.String("object", objectName), - zap.String("path", path), + zap.String("source", path), + zap.Int64("size_uncompressed", fileInfo.Size()), + zap.String("compression", "none"), zap.Int("max_concurrency", maxConcurrency), - zap.Int64("file_size", fileInfo.Size()), - zap.Int64("duration", time.Since(start).Milliseconds()), + zap.Int64("duration_ms", time.Since(start).Milliseconds()), ) timer.Success(ctx, count) - return nil + return nil, [32]byte{}, e +} + +func (o *gcpObject) storeFileCompressed(ctx context.Context, localPath string, cfg CompressConfig, maxConcurrency int, putOpts PutOptions) (*FrameTable, [32]byte, error) { + file, err := os.Open(localPath) + if err != nil { + return nil, 
[32]byte{}, fmt.Errorf("failed to open local file %s: %w", localPath, err) + } + defer file.Close() + + fi, err := file.Stat() + if err != nil { + return nil, [32]byte{}, fmt.Errorf("failed to stat local file %s: %w", localPath, err) + } + + // Merge caller metadata (e.g. team_id) with our internal uncompressed-size + // bookkeeping. Internal key wins on collision. + metadata := make(map[string]string, len(putOpts.Metadata)+1) + maps.Copy(metadata, putOpts.Metadata) + metadata[MetadataKeyUncompressedSize] = strconv.FormatInt(fi.Size(), 10) + + uploader, err := NewMultipartUploaderWithRetryConfig( + ctx, + o.storage.bucket.BucketName(), + o.path, + DefaultRetryConfig(), + metadata, + ) + if err != nil { + return nil, [32]byte{}, fmt.Errorf("failed to create multipart uploader: %w", err) + } + + return compressStream(ctx, file, cfg, uploader, maxConcurrency) } type gcpServiceToken struct { @@ -504,6 +550,78 @@ func parseServiceAccountBase64(serviceAccount string) (*gcpServiceToken, error) return &sa, nil } +func (o *gcpObject) OpenRangeReader(ctx context.Context, offsetU int64, length int64, frameTable *FrameTable) (io.ReadCloser, error) { + timer := googleReadTimerFactory.Begin(attribute.String(gcsOperationAttr, gcsOperationAttrReadAt)) + + if !frameTable.IsCompressed() { + rc, err := o.openRangeReader(ctx, offsetU, length) + if err != nil { + timer.Failure(ctx, 0) + + return nil, err + } + + return &timedReadCloser{inner: rc, timer: timer, ctx: ctx}, nil + } + + r, err := frameTable.LocateCompressed(offsetU) + if err != nil { + timer.Failure(ctx, 0) + + return nil, fmt.Errorf("get frame for offset %d, GCS:%s: %w", offsetU, o.path, err) + } + + raw, err := o.openRangeReader(ctx, r.Offset, int64(r.Length)) + if err != nil { + timer.Failure(ctx, 0) + + return nil, err + } + + decompressed, err := newDecompressingReadCloser(raw, frameTable.CompressionType()) + if err != nil { + raw.Close() + timer.Failure(ctx, 0) + + return nil, err + } + + return 
&timedReadCloser{inner: decompressed, timer: timer, ctx: ctx}, nil +} + +// timedReadCloser wraps a reader with OTEL timer metrics. +// Close records success (with total bytes read) or failure on the timer. +type timedReadCloser struct { + inner io.ReadCloser + timer *telemetry.Stopwatch + ctx context.Context //nolint:containedctx // needed for timer recording in Close + bytesRead int64 + closeErr error +} + +func (r *timedReadCloser) Read(p []byte) (int, error) { + n, err := r.inner.Read(p) + r.bytesRead += int64(n) + + if err != nil && err != io.EOF { + r.closeErr = err + } + + return n, err +} + +func (r *timedReadCloser) Close() error { + err := r.inner.Close() + + if r.closeErr != nil || err != nil { + r.timer.Failure(r.ctx, r.bytesRead) + } else { + r.timer.Success(r.ctx, r.bytesRead) + } + + return err +} + func isResourceExhausted(err error) bool { type grpcStatusProvider interface { GRPCStatus() *status.Status diff --git a/packages/shared/pkg/storage/storageopts/storageopts.go b/packages/shared/pkg/storage/storageopts/storageopts.go index df51547bbb..7cdb46481a 100644 --- a/packages/shared/pkg/storage/storageopts/storageopts.go +++ b/packages/shared/pkg/storage/storageopts/storageopts.go @@ -8,8 +8,12 @@ type ObjectMetadata map[string]string const ObjectMetadataTeamID = "team_id" +// PutOptions holds parameters for blob/seekable writes. Compression is held +// as `any` so that storage.CompressConfig (which has heavy storage-internal +// dependencies) doesn't have to be moved here. Backends type-assert it back. type PutOptions struct { - Metadata ObjectMetadata + Metadata ObjectMetadata + Compression any } type PutOption func(*PutOptions) @@ -26,6 +30,12 @@ func WithMetadata(metadata ObjectMetadata) PutOption { } } +// WithCompression stashes a compression config (typed in the storage package) +// into PutOptions. The storage package wraps this with a typed helper. 
+func WithCompression(cfg any) PutOption { + return func(o *PutOptions) { o.Compression = cfg } +} + func Apply(opts []PutOption) PutOptions { var p PutOptions for _, opt := range opts { diff --git a/tests/integration/Makefile b/tests/integration/Makefile index 00349fcfd4..13b52698be 100644 --- a/tests/integration/Makefile +++ b/tests/integration/Makefile @@ -40,9 +40,9 @@ test/%: *.go:*) \ BASE=$${TEST_PATH%%:*}; \ TEST_FN=$${TEST_PATH#*:}; \ - go tool gotestsum --rerun-fails=1 --packages="$$BASE" --format standard-verbose --junitfile=test-results.xml -- -count=1 -parallel=4 -run "$${TEST_FN}" ;; \ - *.go) go tool gotestsum --rerun-fails=1 --packages="$$TEST_PATH" --format standard-verbose --junitfile=test-results.xml -- -count=1 -parallel=4 ;; \ - *) go tool gotestsum --rerun-fails=1 --packages="$$TEST_PATH/..." --format standard-verbose --junitfile=test-results.xml -- -count=1 -parallel=4 ;; \ + go tool gotestsum --rerun-fails=1 --packages="$$BASE" --format standard-verbose --junitfile=test-results.xml -- -count=1 -parallel=4 -timeout=20m -run "$${TEST_FN}" ;; \ + *.go) go tool gotestsum --rerun-fails=1 --packages="$$TEST_PATH" --format standard-verbose --junitfile=test-results.xml -- -count=1 -parallel=4 -timeout=20m ;; \ + *) go tool gotestsum --rerun-fails=1 --packages="$$TEST_PATH/..." 
--format standard-verbose --junitfile=test-results.xml -- -count=1 -parallel=4 -timeout=20m ;; \ esac .PHONY: connect-orchestrator diff --git a/tests/integration/go.mod b/tests/integration/go.mod index 3033c74ede..6378c96e53 100644 --- a/tests/integration/go.mod +++ b/tests/integration/go.mod @@ -34,15 +34,48 @@ require ( ) require ( + cel.dev/expr v0.25.1 // indirect + cloud.google.com/go v0.123.0 // indirect + cloud.google.com/go/auth v0.18.1 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect + cloud.google.com/go/compute/metadata v0.9.0 // indirect + cloud.google.com/go/iam v1.5.3 // indirect + cloud.google.com/go/monitoring v1.24.3 // indirect + cloud.google.com/go/storage v1.59.2 // indirect dario.cat/mergo v1.0.2 // indirect github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect github.com/ClickHouse/ch-go v0.67.0 // indirect github.com/ClickHouse/clickhouse-go/v2 v2.40.1 // indirect github.com/DataDog/datadog-go/v5 v5.2.0 // indirect + github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0 // indirect + github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.54.0 // indirect + github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.54.0 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect + github.com/RoaringBitmap/roaring/v2 v2.18.0 // indirect github.com/andybalholm/brotli v1.2.0 // indirect github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.41.6 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.9 // indirect + github.com/aws/aws-sdk-go-v2/config v1.32.6 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.19.6 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 // indirect + github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.20.12 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.22 // indirect + 
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.22 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.23 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.8 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.14 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.22 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.22 // indirect + github.com/aws/aws-sdk-go-v2/service/s3 v1.100.0 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.4 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.8 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 // indirect + github.com/aws/smithy-go v1.25.0 // indirect github.com/bitfield/gotestdox v0.2.2 // indirect + github.com/bits-and-blooms/bitset v1.24.4 // indirect github.com/bsm/redislock v0.9.4 // indirect github.com/bytedance/gopkg v0.1.4 // indirect github.com/bytedance/sonic v1.15.0 // indirect @@ -51,6 +84,7 @@ require ( github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect + github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 // indirect github.com/containerd/errdefs v1.0.0 // indirect github.com/containerd/errdefs/pkg v0.3.0 // indirect github.com/containerd/log v0.1.0 // indirect @@ -65,6 +99,8 @@ require ( github.com/docker/go-units v0.5.0 // indirect github.com/dprotaso/go-yit v0.0.0-20220510233725-9ba8df137936 // indirect github.com/ebitengine/purego v0.10.0 // indirect + github.com/envoyproxy/go-control-plane/envoy v1.36.0 // indirect + github.com/envoyproxy/protoc-gen-validate v1.3.0 // indirect github.com/exaring/otelpgx v0.9.3 // indirect github.com/fatih/color v1.18.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect @@ -76,6 
+112,7 @@ require ( github.com/gin-gonic/gin v1.12.0 // indirect github.com/go-faster/city v1.0.1 // indirect github.com/go-faster/errors v0.7.1 // indirect + github.com/go-jose/go-jose/v4 v4.1.4 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect @@ -89,10 +126,16 @@ require ( github.com/golang-jwt/jwt/v5 v5.3.1 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/nftables v0.3.0 // indirect + github.com/google/s2a-go v0.1.9 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.11 // indirect + github.com/googleapis/gax-go/v2 v2.17.0 // indirect github.com/gorilla/mux v1.8.1 // indirect + github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 // indirect github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect + github.com/hashicorp/go-retryablehttp v0.7.7 // indirect github.com/jackc/pgerrcode v0.0.0-20250907135507-afb5586c32a6 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect @@ -102,6 +145,14 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.18.5 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/launchdarkly/ccache v1.1.0 // indirect + github.com/launchdarkly/eventsource v1.10.0 // indirect + github.com/launchdarkly/go-jsonstream/v3 v3.1.0 // indirect + github.com/launchdarkly/go-sdk-common/v3 v3.3.0 // indirect + github.com/launchdarkly/go-sdk-events/v3 v3.5.0 // indirect + github.com/launchdarkly/go-semver v1.0.3 // indirect + github.com/launchdarkly/go-server-sdk-evaluation/v3 v3.0.1 // indirect + github.com/launchdarkly/go-server-sdk/v7 v7.13.0 // indirect 
github.com/leodido/go-urn v1.4.0 // indirect github.com/lib/pq v1.11.2 // indirect github.com/lufia/plan9stats v0.0.0-20240909124753-873cd0166683 // indirect @@ -124,6 +175,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect + github.com/mschoch/smat v0.2.0 // indirect github.com/ngrok/firewall_toolkit v0.0.18 // indirect github.com/oapi-codegen/gin-middleware v1.0.2 // indirect github.com/oapi-codegen/oapi-codegen/v2 v2.6.0 // indirect @@ -132,10 +184,12 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect github.com/orcaman/concurrent-map/v2 v2.0.1 // indirect + github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/paulmach/orb v0.11.1 // indirect github.com/pelletier/go-toml/v2 v2.3.0 // indirect github.com/perimeterx/marshmallow v1.1.5 // indirect github.com/pierrec/lz4/v4 v4.1.22 // indirect + github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect github.com/pressly/goose/v3 v3.26.0 // indirect @@ -150,6 +204,8 @@ require ( github.com/sirupsen/logrus v1.9.4 // indirect github.com/speakeasy-api/jsonpath v0.6.0 // indirect github.com/speakeasy-api/openapi-overlay v0.10.2 // indirect + github.com/spiffe/go-spiffe/v2 v2.6.0 // indirect + github.com/stretchr/objx v0.5.3 // indirect github.com/testcontainers/testcontainers-go v0.42.0 // indirect github.com/tklauser/go-sysconf v0.3.16 // indirect github.com/tklauser/numcpus v0.11.0 // indirect @@ -161,6 +217,7 @@ require ( go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect 
go.opentelemetry.io/contrib/bridges/otelzap v0.14.0 // indirect + go.opentelemetry.io/contrib/detectors/gcp v1.39.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.65.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 // indirect go.opentelemetry.io/contrib/instrumentation/runtime v0.66.0 // indirect @@ -180,12 +237,17 @@ require ( go.uber.org/zap v1.27.1 // indirect golang.org/x/arch v0.25.0 // indirect golang.org/x/crypto v0.50.0 // indirect + golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a // indirect golang.org/x/mod v0.35.0 // indirect golang.org/x/net v0.53.0 // indirect + golang.org/x/oauth2 v0.36.0 // indirect golang.org/x/sys v0.43.0 // indirect golang.org/x/term v0.42.0 // indirect golang.org/x/text v0.36.0 // indirect + golang.org/x/time v0.14.0 // indirect golang.org/x/tools v0.44.0 // indirect + google.golang.org/api v0.267.0 // indirect + google.golang.org/genproto v0.0.0-20260128011058-8636f8732409 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect diff --git a/tests/integration/go.sum b/tests/integration/go.sum index 3dc530375c..d5f7d7e164 100644 --- a/tests/integration/go.sum +++ b/tests/integration/go.sum @@ -1,3 +1,25 @@ +cel.dev/expr v0.25.1 h1:1KrZg61W6TWSxuNZ37Xy49ps13NUovb66QLprthtwi4= +cel.dev/expr v0.25.1/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4= +cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE= +cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU= +cloud.google.com/go/auth v0.18.1 h1:IwTEx92GFUo2pJ6Qea0EU3zYvKnTAeRCODxfA/G5UWs= +cloud.google.com/go/auth v0.18.1/go.mod h1:GfTYoS9G3CWpRA3Va9doKN9mjPGRS+v41jmZAhBzbrA= +cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= 
+cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= +cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= +cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= +cloud.google.com/go/iam v1.5.3 h1:+vMINPiDF2ognBJ97ABAYYwRgsaqxPbQDlMnbHMjolc= +cloud.google.com/go/iam v1.5.3/go.mod h1:MR3v9oLkZCTlaqljW6Eb2d3HGDGK5/bDv93jhfISFvU= +cloud.google.com/go/logging v1.13.1 h1:O7LvmO0kGLaHY/gq8cV7T0dyp6zJhYAOtZPX4TF3QtY= +cloud.google.com/go/logging v1.13.1/go.mod h1:XAQkfkMBxQRjQek96WLPNze7vsOmay9H5PqfsNYDqvw= +cloud.google.com/go/longrunning v0.8.0 h1:LiKK77J3bx5gDLi4SMViHixjD2ohlkwBi+mKA7EhfW8= +cloud.google.com/go/longrunning v0.8.0/go.mod h1:UmErU2Onzi+fKDg2gR7dusz11Pe26aknR4kHmJJqIfk= +cloud.google.com/go/monitoring v1.24.3 h1:dde+gMNc0UhPZD1Azu6at2e79bfdztVDS5lvhOdsgaE= +cloud.google.com/go/monitoring v1.24.3/go.mod h1:nYP6W0tm3N9H/bOw8am7t62YTzZY+zUeQ+Bi6+2eonI= +cloud.google.com/go/storage v1.59.2 h1:gmOAuG1opU8YvycMNpP+DvHfT9BfzzK5Cy+arP+Nocw= +cloud.google.com/go/storage v1.59.2/go.mod h1:cMWbtM+anpC74gn6qjLh+exqYcfmB9Hqe5z6adx+CLI= +cloud.google.com/go/trace v1.11.7 h1:kDNDX8JkaAG3R2nq1lIdkb7FCSi1rCmsEtKVsty7p+U= +cloud.google.com/go/trace v1.11.7/go.mod h1:TNn9d5V3fQVf6s4SCveVMIBS2LJUqo73GACmq/Tky0s= connectrpc.com/connect v1.18.1 h1:PAg7CjSAGvscaf6YZKUefjoih5Z/qYkyaTrBW8xvYPw= connectrpc.com/connect v1.18.1/go.mod h1:0292hj1rnx8oFrStN7cB4jjVBeqs+Yx5yDIC2prWDO8= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= @@ -12,16 +34,68 @@ github.com/ClickHouse/clickhouse-go/v2 v2.40.1 h1:PbwsHBgqXRydU7jKULD1C8CHmifczf github.com/ClickHouse/clickhouse-go/v2 v2.40.1/go.mod h1:GDzSBLVhladVm8V01aEB36IoBOVLLICfyeuiIp/8Ezc= github.com/DataDog/datadog-go/v5 v5.2.0 h1:kSptqUGSNK67DgA+By3rwtFnAh6pTBxJ7Hn8JCLZcKY= github.com/DataDog/datadog-go/v5 v5.2.0/go.mod h1:XRDJk1pTc00gm+ZDiBKsjh7oOOtJfYfglVCmFb8C2+Q= 
+github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0 h1:DHa2U07rk8syqvCge0QIGMCE1WxGj9njT44GH7zNJLQ= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0/go.mod h1:P4WPRUkOhJC13W//jWpyfJNDAIpvRbAUIYLX/4jtlE0= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.54.0 h1:lhhYARPUu3LmHysQ/igznQphfzynnqI3D75oUyw1HXk= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.54.0/go.mod h1:l9rva3ApbBpEJxSNYnwT9N4CDLrWgtq3u8736C5hyJw= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.54.0 h1:xfK3bbi6F2RDtaZFtUdKO3osOBIhNb+xTs8lFW6yx9o= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.54.0/go.mod h1:vB2GH9GAYYJTO3mEn8oYwzEdhlayZIdQz6zdzgUIRvA= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.54.0 h1:s0WlVbf9qpvkh1c/uDAPElam0WrL7fHRIidgZJ7UqZI= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.54.0/go.mod h1:Mf6O40IAyB9zR/1J8nGDDPirZQQPbYJni8Yisy7NTMc= github.com/Microsoft/go-winio v0.5.0/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk= +github.com/RoaringBitmap/roaring/v2 v2.18.0 h1:h7sS0VqCkfBMGgcHaudJFB4FE6Td71H6svRB2poRnGY= +github.com/RoaringBitmap/roaring/v2 v2.18.0/go.mod h1:eq4wdNXxtJIS/oikeCzdX1rBzek7ANzbth041hrU8Q4= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ= github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod 
h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk= +github.com/aws/aws-sdk-go-v2 v1.41.6 h1:1AX0AthnBQzMx1vbmir3Y4WsnJgiydmnJjiLu+LvXOg= +github.com/aws/aws-sdk-go-v2 v1.41.6/go.mod h1:dy0UzBIfwSeot4grGvY1AqFWN5zgziMmWGzysDnHFcQ= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.9 h1:adBsCIIpLbLmYnkQU+nAChU5yhVTvu5PerROm+/Kq2A= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.9/go.mod h1:uOYhgfgThm/ZyAuJGNQ5YgNyOlYfqnGpTHXvk3cpykg= +github.com/aws/aws-sdk-go-v2/config v1.32.6 h1:hFLBGUKjmLAekvi1evLi5hVvFQtSo3GYwi+Bx4lpJf8= +github.com/aws/aws-sdk-go-v2/config v1.32.6/go.mod h1:lcUL/gcd8WyjCrMnxez5OXkO3/rwcNmvfno62tnXNcI= +github.com/aws/aws-sdk-go-v2/credentials v1.19.6 h1:F9vWao2TwjV2MyiyVS+duza0NIRtAslgLUM0vTA1ZaE= +github.com/aws/aws-sdk-go-v2/credentials v1.19.6/go.mod h1:SgHzKjEVsdQr6Opor0ihgWtkWdfRAIwxYzSJ8O85VHY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 h1:80+uETIWS1BqjnN9uJ0dBUaETh+P1XwFy5vwHwK5r9k= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16/go.mod h1:wOOsYuxYuB/7FlnVtzeBYRcjSRtQpAW0hCP7tIULMwo= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.20.12 h1:Zy6Tme1AA13kX8x3CnkHx5cqdGWGaj/anwOiWGnA0Xo= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.20.12/go.mod h1:ql4uXYKoTM9WUAUSmthY4AtPVrlTBZOvnBJTiCUdPxI= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.22 h1:GmLa5Kw1ESqtFpXsx5MmC84QWa/ZrLZvlJGa2y+4kcQ= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.22/go.mod h1:6sW9iWm9DK9YRpRGga/qzrzNLgKpT2cIxb7Vo2eNOp0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.22 h1:dY4kWZiSaXIzxnKlj17nHnBcXXBfac6UlsAx2qL6XrU= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.22/go.mod h1:KIpEUx0JuRZLO7U6cbV204cWAEco2iC3l061IxlwLtI= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.23 
h1:FPXsW9+gMuIeKmz7j6ENWcWtBGTe1kH8r9thNt5Uxx4= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.23/go.mod h1:7J8iGMdRKk6lw2C+cMIphgAnT8uTwBwNOsGkyOCm80U= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.8 h1:HtOTYcbVcGABLOVuPYaIihj6IlkqubBwFj10K5fxRek= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.8/go.mod h1:VsK9abqQeGlzPgUr+isNWzPlK2vKe9INMLWnY65f5Xs= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.14 h1:xnvDEnw+pnj5mctWiYuFbigrEzSm35x7k4KS/ZkCANg= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.14/go.mod h1:yS5rNogD8e0Wu9+l3MUwr6eENBzEeGejvINpN5PAYfY= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.22 h1:PUmZeJU6Y1Lbvt9WFuJ0ugUK2xn6hIWUBBbKuOWF30s= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.22/go.mod h1:nO6egFBoAaoXze24a2C0NjQCvdpk8OueRoYimvEB9jo= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.22 h1:SE+aQ4DEqG53RRCAIHlCf//B2ycxGH7jFkpnAh/kKPM= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.22/go.mod h1:ES3ynECd7fYeJIL6+oax+uIEljmfps0S70BaQzbMd/o= +github.com/aws/aws-sdk-go-v2/service/s3 v1.100.0 h1:7G26Sae6PMKn4kMcU5JzNfrm1YrKwyOhowXPYR2WiWY= +github.com/aws/aws-sdk-go-v2/service/s3 v1.100.0/go.mod h1:Fw9aqhJicIVee1VytBBjH+l+5ov6/PhbtIK/u3rt/ls= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.4 h1:HpI7aMmJ+mm1wkSHIA2t5EaFFv5EFYXePW30p1EIrbQ= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.4/go.mod h1:C5RdGMYGlfM0gYq/tifqgn4EbyX99V15P2V3R+VHbQU= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.8 h1:aM/Q24rIlS3bRAhTyFurowU8A0SMyGDtEOY/l/s/1Uw= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.8/go.mod h1:+fWt2UHSb4kS7Pu8y+BMBvJF0EWx+4H0hzNwtDNRTrg= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12 h1:AHDr0DaHIAo8c9t1emrzAlVDFp+iMMKnPdYy6XO4MCE= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12/go.mod h1:GQ73XawFFiWxyWXMHWfhiomvP3tXtdNar/fi8z18sx0= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 
h1:SciGFVNZ4mHdm7gpD1dgZYnCuVdX1s+lFTg4+4DOy70= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.5/go.mod h1:iW40X4QBmUxdP+fZNOpfmkdMZqsovezbAeO+Ubiv2pk= +github.com/aws/smithy-go v1.25.0 h1:Sz/XJ64rwuiKtB6j98nDIPyYrV1nVNJ4YU74gttcl5U= +github.com/aws/smithy-go v1.25.0/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/bitfield/gotestdox v0.2.2 h1:x6RcPAbBbErKLnapz1QeAlf3ospg8efBsedU93CDsnE= github.com/bitfield/gotestdox v0.2.2/go.mod h1:D+gwtS0urjBrzguAkTM2wodsTQYFHdpx8eqRJ3N+9pY= +github.com/bits-and-blooms/bitset v1.24.4 h1:95H15Og1clikBrKr/DuzMXkQzECs1M6hhoGXLwLQOZE= +github.com/bits-and-blooms/bitset v1.24.4/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= @@ -46,6 +120,8 @@ github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5P github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 h1:6xNmx7iTtyBRev0+D/Tv1FZd4SCg8axKApyNyRsAt/w= +github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5/go.mod h1:KdCmV+x/BuvyMxRnYBlmVaq4OLiKW6iRQfvC62cvdkI= github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE= @@ -82,6 +158,14 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp github.com/dustin/go-humanize v1.0.1/go.mod 
h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/ebitengine/purego v0.10.0 h1:QIw4xfpWT6GWTzaW5XEKy3HXoqrJGx1ijYHzTF0/ISU= github.com/ebitengine/purego v0.10.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= +github.com/envoyproxy/go-control-plane v0.14.0 h1:hbG2kr4RuFj222B6+7T83thSPqLjwBIfQawTkC++2HA= +github.com/envoyproxy/go-control-plane v0.14.0/go.mod h1:NcS5X47pLl/hfqxU70yPwL9ZMkUlwlKxtAohpi2wBEU= +github.com/envoyproxy/go-control-plane/envoy v1.36.0 h1:yg/JjO5E7ubRyKX3m07GF3reDNEnfOboJ0QySbH736g= +github.com/envoyproxy/go-control-plane/envoy v1.36.0/go.mod h1:ty89S1YCCVruQAm9OtKeEkQLTb+Lkz0k8v9W0Oxsv98= +github.com/envoyproxy/go-control-plane/ratelimit v0.1.0 h1:/G9QYbddjL25KvtKTv3an9lx6VBE2cnb8wp1vEGNYGI= +github.com/envoyproxy/go-control-plane/ratelimit v0.1.0/go.mod h1:Wk+tMFAFbCXaJPzVVHnPgRKdUdwW/KdbRt94AzgRee4= +github.com/envoyproxy/protoc-gen-validate v1.3.0 h1:TvGH1wof4H33rezVKWSpqKz5NXWg5VPuZ0uONDT6eb4= +github.com/envoyproxy/protoc-gen-validate v1.3.0/go.mod h1:HvYl7zwPa5mffgyeTUHA9zHIH36nmrm7oCbo4YKoSWA= github.com/exaring/otelpgx v0.9.3 h1:4yO02tXC7ZJZ+hcqcUkfxblYNCIFGVhpUWI0iw1TzPU= github.com/exaring/otelpgx v0.9.3/go.mod h1:R5/M5LWsPPBZc1SrRE5e0DiU48bI78C1/GPTWs6I66U= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= @@ -106,6 +190,8 @@ github.com/go-faster/city v1.0.1 h1:4WAxSZ3V2Ws4QRDrscLEDcibJY8uf41H6AhXDrNDcGw= github.com/go-faster/city v1.0.1/go.mod h1:jKcUJId49qdW3L1qKHH/3wPeUstCVpVSXTM6vO3VcTw= github.com/go-faster/errors v0.7.1 h1:MkJTnDoEdi9pDabt1dpWf7AA8/BaSYZqibYyhZ20AYg= github.com/go-faster/errors v0.7.1/go.mod h1:5ySTjWFiphBs07IKuiL69nxdfd5+fzh1u7FPGZP2quo= +github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA= +github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 
h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -161,19 +247,35 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/martian/v3 v3.3.3 h1:DIhPTQrbPkgs2yJYdXU/eNACCG5DVQjySNRNlflZ9Fc= +github.com/google/martian/v3 v3.3.3/go.mod h1:iEPrYcgCF7jA9OtScMFQyAlZZ4YXTKEtJ1E6RWzmBA0= github.com/google/nftables v0.3.0 h1:bkyZ0cbpVeMHXOrtlFc8ISmfVqq5gPJukoYieyVmITg= github.com/google/nftables v0.3.0/go.mod h1:BCp9FsrbF1Fn/Yu6CLUc9GGZFw/+hsxfluNXXmxBfRM= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= +github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/enterprise-certificate-proxy v0.3.11 h1:vAe81Msw+8tKUxi2Dqh/NZMz7475yUvmRIkXr4oN2ao= +github.com/googleapis/enterprise-certificate-proxy v0.3.11/go.mod h1:RFV7MUdlb7AgEq2v7FmMCfeSMCllAzWxFgRdusoGks8= +github.com/googleapis/gax-go/v2 v2.17.0 h1:RksgfBpxqff0EZkDWYuz9q/uWsTVz+kf43LsZ1J6SMc= +github.com/googleapis/gax-go/v2 v2.17.0/go.mod h1:mzaqghpQp4JDh3HvADwrat+6M3MOIDp5YKHhb9PAgDY= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= 
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= +github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 h1:+ngKgrYPPJrOjhax5N+uePQ0Fh1Z7PheYoUI/0nzkPA= +github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2 h1:sGm2vDRFUrQJO/Veii4h4zG2vvqG6uWNkBHSTqXOZk0= github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2/go.mod h1:wd1YpapPLivG6nQgbf7ZkG1hhSOXDhhn4MLTknx2aAc= github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs= github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= +github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= +github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISHxT2Q8+VepXU= +github.com/hashicorp/go-retryablehttp v0.7.7/go.mod h1:pkQpWZeYWskR+D1tR2O5OcBFOxfA7DoAO6xtkuQnHTk= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/jackc/pgerrcode v0.0.0-20250907135507-afb5586c32a6 h1:D/V0gu4zQ3cL2WKeVNVM4r2gLxGGf6McLwgXzRTo2RQ= @@ -193,6 +295,8 @@ github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFF github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod 
h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE= +github.com/karlseguin/expect v1.0.2-0.20190806010014-778a5f0c6003 h1:vJ0Snvo+SLMY72r5J4sEfkuE7AFbixEP2qRbEcum/wA= +github.com/karlseguin/expect v1.0.2-0.20190806010014-778a5f0c6003/go.mod h1:zNBxMY8P21owkeogJELCLeHIt+voOSduHYTFUbwRAV8= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= @@ -207,6 +311,24 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/launchdarkly/ccache v1.1.0 h1:voD1M+ZJXR3MREOKtBwgTF9hYHl1jg+vFKS/+VAkR2k= +github.com/launchdarkly/ccache v1.1.0/go.mod h1:TlxzrlnzvYeXiLHmesMuvoZetu4Z97cV1SsdqqBJi1Q= +github.com/launchdarkly/eventsource v1.10.0 h1:H9Tp6AfGu/G2qzBJC26iperrvwhzdbiA/gx7qE2nDFI= +github.com/launchdarkly/eventsource v1.10.0/go.mod h1:J3oa50bPvJesZqNAJtb5btSIo5N6roDWhiAS3IpsKck= +github.com/launchdarkly/go-jsonstream/v3 v3.1.0 h1:U/7/LplZO72XefBQ+FzHf6o4FwLHVqBE+4V58Ornu/E= +github.com/launchdarkly/go-jsonstream/v3 v3.1.0/go.mod h1:2Pt4BR5AwWgsuVTCcIpB6Os04JFIKWfoA+7faKkZB5E= +github.com/launchdarkly/go-sdk-common/v3 v3.3.0 h1:kkf78wcKX+DOXzNjG29i+py/P+XMIw8/mXS7eEWGQwU= +github.com/launchdarkly/go-sdk-common/v3 v3.3.0/go.mod h1:mXFmDGEh4ydK3QilRhrAyKuf9v44VZQWnINyhqbbOd0= +github.com/launchdarkly/go-sdk-events/v3 v3.5.0 h1:Yav8Thm70dZbO8U1foYwZPf3w60n/lNBRaYeeNM/qg4= +github.com/launchdarkly/go-sdk-events/v3 v3.5.0/go.mod h1:oepYWQ2RvvjfL2WxkE1uJJIuRsIMOP4WIVgUpXRPcNI= +github.com/launchdarkly/go-semver v1.0.3 h1:agIy/RN3SqeQDIfKkl+oFslEdeIs7pgsJBs3CdCcGQM= +github.com/launchdarkly/go-semver 
v1.0.3/go.mod h1:xFmMwXba5Mb+3h72Z+VeSs9ahCvKo2QFUTHRNHVqR28= +github.com/launchdarkly/go-server-sdk-evaluation/v3 v3.0.1 h1:rTgcYAFraGFj7sBMB2b7JCYCm0b9kph4FaMX02t4osQ= +github.com/launchdarkly/go-server-sdk-evaluation/v3 v3.0.1/go.mod h1:fPS5d+zOsgFnMunj+Ki6jjlZtFvo4h9iNbtNXxzYn58= +github.com/launchdarkly/go-server-sdk/v7 v7.13.0 h1:ajiZOPBwmWVFFgP+EMdy3oS1Xl9wNDlEd/7Zn/0I2JU= +github.com/launchdarkly/go-server-sdk/v7 v7.13.0/go.mod h1:6krbDWp417H7lIg+3ehh/A/AW5xwHtiUFg06fvNYHAk= +github.com/launchdarkly/go-test-helpers/v3 v3.1.0 h1:E3bxJMzMoA+cJSF3xxtk2/chr1zshl1ZWa0/oR+8bvg= +github.com/launchdarkly/go-test-helpers/v3 v3.1.0/go.mod h1:Ake5+hZFS/DmIGKx/cizhn5W9pGA7pplcR7xCxWiLIo= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/lib/pq v1.11.2 h1:x6gxUeu39V0BHZiugWe8LXZYZ+Utk7hSJGThs8sdzfs= @@ -257,6 +379,8 @@ github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee/go.mod h1:yWu github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= +github.com/mschoch/smat v0.2.0 h1:8imxQsjDm8yFEAVBe7azKmKSgzSkZXDuKkSq9374khM= +github.com/mschoch/smat v0.2.0/go.mod h1:kc9mz7DoBKqDyiRL7VZN8KvXQMWeTaVnttLRXOlotKw= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/ngrok/firewall_toolkit v0.0.18 h1:/+Rx/5qXXO8FpOoKpPnyR2nw8Y3KumuulSNZa3XGZE8= @@ -294,6 +418,8 @@ github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJw github.com/opencontainers/image-spec v1.1.1/go.mod 
h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= github.com/orcaman/concurrent-map/v2 v2.0.1 h1:jOJ5Pg2w1oeB6PeDurIYf6k9PQ+aTITr/6lP/L/zp6c= github.com/orcaman/concurrent-map/v2 v2.0.1/go.mod h1:9Eq3TG2oBe5FirmYWQfYO5iH1q0Jv47PLaNK++uCdOM= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/paulmach/orb v0.11.1 h1:3koVegMC4X/WeiXYz9iswopaTwMem53NzTJuTF20JzU= github.com/paulmach/orb v0.11.1/go.mod h1:5mULz1xQfs3bmQm63QEJA6lNGujuRafwA5S/EnuLaLU= github.com/paulmach/protoscan v0.2.1/go.mod h1:SpcSwydNLrxUGSDvXvO0P7g7AuhJ7lcKfDlhJCDw2gY= @@ -304,6 +430,8 @@ github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0V github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU= github.com/pierrec/lz4/v4 v4.1.22/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -342,6 +470,8 @@ github.com/speakeasy-api/jsonpath v0.6.0 h1:IhtFOV9EbXplhyRqsVhHoBmmYjblIRh5D1/g github.com/speakeasy-api/jsonpath v0.6.0/go.mod h1:ymb2iSkyOycmzKwbEAYPJV/yi2rSmvBCLZJcyD+VVWw= github.com/speakeasy-api/openapi-overlay v0.10.2 h1:VOdQ03eGKeiHnpb1boZCGm7x8Haj6gST0P3SGTX95GU= github.com/speakeasy-api/openapi-overlay v0.10.2/go.mod 
h1:n0iOU7AqKpNFfEt6tq7qYITC4f0yzVVdFw0S7hukemg= +github.com/spiffe/go-spiffe/v2 v2.6.0 h1:l+DolpxNWYgruGQVV0xsfeya3CsC7m8iBzDnMpsbLuo= +github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xIx7lEzqblHEs= github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= @@ -379,6 +509,8 @@ github.com/vmware-labs/yaml-jsonpath v0.3.2 h1:/5QKeCBGdsInyDCyVNLbXyilb61MXGi9N github.com/vmware-labs/yaml-jsonpath v0.3.2/go.mod h1:U6whw1z03QyqgWdgXxvVnQ90zN1BWz5V+51Ewf8k+rQ= github.com/woodsbury/decimal128 v1.4.0 h1:xJATj7lLu4f2oObouMt2tgGiElE5gO6mSWUjQsBgUlc= github.com/woodsbury/decimal128 v1.4.0/go.mod h1:BP46FUrVjVhdTbKT+XuQh2xfQaGki9LMIRJSFuh6THU= +github.com/wsxiaoys/terminal v0.0.0-20160513160801-0940f3fc43a0 h1:3UeQBvD0TFrlVjOeLOBz+CPAI8dnbqNSVwUwRrkp7vQ= +github.com/wsxiaoys/terminal v0.0.0-20160513160801-0940f3fc43a0/go.mod h1:IXCdmsXIht47RaVFLEdVnh1t+pgYtTAhQGj73kz+2DM= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= github.com/xdg-go/scram v1.1.1/go.mod h1:RaEWvsqvNKKvBPvcKeFjrG2cJqOkHTiyTpzz23ni57g= github.com/xdg-go/stringprep v1.0.3/go.mod h1:W3f5j4i+9rC0kuIEJL0ky1VpHXQU3ocBgklLGvcBnW8= @@ -397,6 +529,8 @@ go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/contrib/bridges/otelzap v0.14.0 h1:2nKw2ZXZOC0N8RBsBbYwGwfKR7kJWzzyCZ6QfUGW/es= go.opentelemetry.io/contrib/bridges/otelzap v0.14.0/go.mod h1:kvyVt0WEI5BB6XaIStXPIkCSQ2nSkyd8IZnAHLEXge4= +go.opentelemetry.io/contrib/detectors/gcp v1.39.0 h1:kWRNZMsfBHZ+uHjiH4y7Etn2FK26LAGkNFw7RHv1DhE= +go.opentelemetry.io/contrib/detectors/gcp v1.39.0/go.mod h1:t/OGqzHBa5v6RHZwrDBJ2OirWc+4q/w2fTbLZwAKjTk= 
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.65.0 h1:XmiuHzgJt067+a6kwyAzkhXooYVv3/TOw9cM2VfJgUM= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.65.0/go.mod h1:KDgtbWKTQs4bM+VPUr6WlL9m/WXcmkCcBlIzqxPGzmI= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 h1:ssfIgGNANqpVFCndZvcuyKbl0g+UAVcbBcqGkG28H0Y= @@ -413,6 +547,8 @@ go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 h1:88Y4s2C8oTui1LGM6bT go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0/go.mod h1:Vl1/iaggsuRlrHf/hfPJPvVag77kKyvrLeD10kpMl+A= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0 h1:in9O8ESIOlwJAEGTkkf34DesGRAc/Pn8qJ7k3r/42LM= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0/go.mod h1:Rp0EXBm5tfnv0WL+ARyO/PHBEaEAT8UUHQ6AGJcSq6c= +go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.39.0 h1:5gn2urDL/FBnK8OkCfD1j3/ER79rUuTYmCvlXBKeYL8= +go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.39.0/go.mod h1:0fBG6ZJxhqByfFZDwSwpZGzJU671HkwpWaNe2t4VUPI= go.opentelemetry.io/otel/log v0.19.0 h1:KUZs/GOsw79TBBMfDWsXS+KZ4g2Ckzksd1ymzsIEbo4= go.opentelemetry.io/otel/log v0.19.0/go.mod h1:5DQYeGmxVIr4n0/BcJvF4upsraHjg6vudJJpnkL6Ipk= go.opentelemetry.io/otel/log/logtest v0.15.0 h1:porNFuxAjodl6LhePevOc3n7bo3Wi3JhGXNWe7KP8iU= @@ -468,6 +604,8 @@ golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync 
v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -512,6 +650,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= @@ -526,6 +666,10 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= +google.golang.org/api v0.267.0 h1:w+vfWPMPYeRs8qH1aYYsFX68jMls5acWl/jocfLomwE= +google.golang.org/api v0.267.0/go.mod h1:Jzc0+ZfLnyvXma3UtaTl023TdhZu6OMBP9tJ+0EmFD0= +google.golang.org/genproto v0.0.0-20260128011058-8636f8732409 h1:VQZ/yAbAtjkHgH80teYd2em3xtIkkHd7ZhqfH2N9CsM= +google.golang.org/genproto v0.0.0-20260128011058-8636f8732409/go.mod h1:rxKD3IEILWEu3P44seeNOAwZN4SaoKaQ/2eTg4mM6EM= google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA= google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod 
h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M= google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg= diff --git a/tests/integration/internal/tests/api/sandboxes/sandbox_rapid_pause_resume_test.go b/tests/integration/internal/tests/api/sandboxes/sandbox_rapid_pause_resume_test.go new file mode 100644 index 0000000000..0cdbd75eb9 --- /dev/null +++ b/tests/integration/internal/tests/api/sandboxes/sandbox_rapid_pause_resume_test.go @@ -0,0 +1,191 @@ +package sandboxes + +import ( + "context" + "crypto/sha256" + "io" + "net/http" + "os" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" + "github.com/e2b-dev/infra/tests/integration/internal/api" + "github.com/e2b-dev/infra/tests/integration/internal/setup" + "github.com/e2b-dev/infra/tests/integration/internal/utils" +) + +// TestSandboxRapidSnapshotForkChain builds a tree of snapshots in rapid +// succession, exercising the multi-layer upload coordination: +// +// A +// ā”œā”€ā”€ B ── D +// └── C +// +// Each child snapshot is created (and a sandbox forked from it) before the +// parent's upload has finalized. The verifier reads each build's V4 header +// directly from object storage and checks (a) ancestor lineage in the Builds +// map, and (b) self's data file checksum against BuildData.Checksum. If the +// inter-uploader sync was wrong, ancestors would be missing or self's data +// would not match its recorded checksum. 
+func TestSandboxRapidSnapshotForkChain(t *testing.T) { + t.Parallel() + c := setup.GetAPIClient() + ctx := t.Context() + + rootSbx := utils.SetupSandboxWithCleanup(t, c, utils.WithAutoPause(false)) + + snapA := createSnapshotTemplateWithCleanup(t, c, rootSbx.SandboxID, nil) + buildA := defaultTagBuildID(t, ctx, c, snapA.SnapshotID) + + sbxB := utils.SetupSandboxWithCleanup(t, c, + utils.WithTemplateID(snapA.SnapshotID), + utils.WithAutoPause(false), + ) + snapB := createSnapshotTemplateWithCleanup(t, c, sbxB.SandboxID, nil) + buildB := defaultTagBuildID(t, ctx, c, snapB.SnapshotID) + + sbxC := utils.SetupSandboxWithCleanup(t, c, + utils.WithTemplateID(snapA.SnapshotID), + utils.WithAutoPause(false), + ) + snapC := createSnapshotTemplateWithCleanup(t, c, sbxC.SandboxID, nil) + buildC := defaultTagBuildID(t, ctx, c, snapC.SnapshotID) + + sbxD := utils.SetupSandboxWithCleanup(t, c, + utils.WithTemplateID(snapB.SnapshotID), + utils.WithAutoPause(false), + ) + snapD := createSnapshotTemplateWithCleanup(t, c, sbxD.SandboxID, nil) + buildD := defaultTagBuildID(t, ctx, c, snapD.SnapshotID) + + chain := []chainNode{ + {name: "A", templateID: snapA.SnapshotID, buildID: buildA, parent: ""}, + {name: "B", templateID: snapB.SnapshotID, buildID: buildB, parent: buildA}, + {name: "C", templateID: snapC.SnapshotID, buildID: buildC, parent: buildA}, + {name: "D", templateID: snapD.SnapshotID, buildID: buildD, parent: buildB}, + } + + verifyChainOnStorage(t, ctx, chain) +} + +type chainNode struct { + name string + templateID string + buildID string + parent string // empty for root +} + +// verifyChainOnStorage loads each build's V4 memfile/rootfs headers directly +// from the configured storage backend and asserts (a) ancestor lineage in +// the Builds map and (b) self's data file matches its recorded checksum. +// +// Skipped when TEMPLATE_BUCKET_NAME / STORAGE_PROVIDER aren't set. 
+func verifyChainOnStorage(t *testing.T, ctx context.Context, chain []chainNode) { + t.Helper() + + if os.Getenv("TEMPLATE_BUCKET_NAME") == "" && !storage.IsLocal() { + t.Log("storage env not configured (TEMPLATE_BUCKET_NAME / STORAGE_PROVIDER); skipping direct storage verification") + + return + } + + persistence, err := storage.GetStorageProvider(ctx, storage.TemplateStorageConfig) + require.NoError(t, err, "build storage provider") + + ancestors := make(map[string][]string, len(chain)) + for _, node := range chain { + var chainAncestors []string + if node.parent != "" { + chainAncestors = append(chainAncestors, ancestors[node.parent]...) + chainAncestors = append(chainAncestors, node.parent) + } + ancestors[node.buildID] = chainAncestors + + paths := storage.Paths{BuildID: node.buildID} + verifyHeader(t, ctx, persistence, node, paths, storage.MemfileName, paths.MemfileHeader(), storage.MemfileObjectType, chainAncestors) + verifyHeader(t, ctx, persistence, node, paths, storage.RootfsName, paths.RootfsHeader(), storage.RootFSObjectType, chainAncestors) + } +} + +func verifyHeader(t *testing.T, ctx context.Context, persistence storage.StorageProvider, node chainNode, paths storage.Paths, fileName, headerPath string, objType storage.SeekableObjectType, ancestors []string) { + t.Helper() + + h := loadHeaderWithPolling(t, ctx, persistence, headerPath, node.name, fileName) + require.NotNilf(t, h.Builds, "%s/%s: V4 header should carry Builds map", node.name, fileName) + + selfUUID := uuid.MustParse(node.buildID) + bd, ok := h.Builds[selfUUID] + require.Truef(t, ok, "%s/%s: Builds map missing self entry %s", node.name, fileName, node.buildID) + + for _, ancestor := range ancestors { + ancUUID := uuid.MustParse(ancestor) + _, ok := h.Builds[ancUUID] + assert.Truef(t, ok, "%s/%s: Builds map missing ancestor %s — child finalized before parent's SwapHeader", node.name, fileName, ancestor) + } + + verifyChecksum(t, ctx, persistence, node, paths, fileName, objType, bd) +} + 
+// verifyChecksum streams self's data file through SHA-256 and compares to +// BuildData.Checksum. For unchanged files (empty diff) the entry has zero +// values and this is a no-op. +func verifyChecksum(t *testing.T, ctx context.Context, persistence storage.StorageProvider, node chainNode, paths storage.Paths, fileName string, objType storage.SeekableObjectType, bd header.BuildData) { + t.Helper() + + if bd.Size == 0 { + return // no data uploaded for this file in this layer + } + + dataPath := paths.DataFile(fileName, bd.FrameData.CompressionType()) + + obj, err := persistence.OpenSeekable(ctx, dataPath, objType) + require.NoErrorf(t, err, "%s/%s: open data file %s", node.name, fileName, dataPath) + + rc, err := obj.OpenRangeReader(ctx, 0, bd.Size, bd.FrameData) + require.NoErrorf(t, err, "%s/%s: open range reader", node.name, fileName) + defer rc.Close() + + hasher := sha256.New() + n, err := io.Copy(hasher, rc) + require.NoErrorf(t, err, "%s/%s: stream data through hasher", node.name, fileName) + require.Equalf(t, bd.Size, n, "%s/%s: streamed bytes (%d) differ from BuildData.Size (%d)", node.name, fileName, n, bd.Size) + + var got [32]byte + copy(got[:], hasher.Sum(nil)) + require.Equalf(t, bd.Checksum, got, "%s/%s: data SHA-256 does not match BuildData.Checksum — upload corrupted or checksum stale", node.name, fileName) +} + +// loadHeaderWithPolling waits for the V4 header to appear in object storage — +// snapshot uploads are async, so the header may not be present immediately +// after the snapshot endpoint returns 201. 
+func loadHeaderWithPolling(t *testing.T, ctx context.Context, persistence storage.StorageProvider, path, name, fileLabel string) *header.Header { + t.Helper() + + var h *header.Header + require.Eventually(t, func() bool { + var err error + h, err = header.LoadHeader(ctx, persistence, path) + + return err == nil && h != nil + }, 2*time.Minute, 500*time.Millisecond, "%s/%s: %s never appeared in storage", name, fileLabel, path) + + return h +} + +func defaultTagBuildID(t *testing.T, ctx context.Context, c *api.ClientWithResponses, snapshotID string) string { + t.Helper() + + tagsResp, err := c.GetTemplatesTemplateIDTagsWithResponse(ctx, snapshotID, setup.WithAPIKey()) + require.NoError(t, err) + require.Equal(t, http.StatusOK, tagsResp.StatusCode()) + require.NotNil(t, tagsResp.JSON200) + require.NotEmpty(t, *tagsResp.JSON200) + + return findDefaultTagBuildID(t, *tagsResp.JSON200).String() +}