Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions packages/api/internal/cache/snapshots/snapshot_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"time"

"github.com/google/uuid"
"github.com/redis/go-redis/v9"
"go.opentelemetry.io/otel"

Expand Down Expand Up @@ -73,6 +74,50 @@ func (c *SnapshotCache) Get(ctx context.Context, sandboxID string) (*SnapshotInf
return info, nil
}

// GetByTeam returns the last snapshot for a sandbox scoped to a specific team.
// It uses the cache for the initial lookup and validates team ownership at the DB level,
// avoiding a separate post-fetch ownership check.
func (c *SnapshotCache) GetByTeam(ctx context.Context, sandboxID string, teamID uuid.UUID) (*SnapshotInfo, error) {
ctx, span := tracer.Start(ctx, "get last snapshot by team")
defer span.End()

// Try cache first; if the cached entry belongs to the right team, return it directly.
info, err := c.cache.GetOrSet(ctx, sandboxID, c.fetchFromDB)
if err != nil {
return nil, err
}

if info.NotFound {
return nil, ErrSnapshotNotFound
}

// Cache hit and team matches – fast path.
if info.Snapshot.TeamID == teamID {
return info, nil
}

// Cache hit but team mismatch: the cached entry may belong to a different team.
// Fall back to a team-scoped DB query to get the authoritative answer.
row, err := c.db.GetLastSnapshotByTeam(ctx, queries.GetLastSnapshotByTeamParams{
SandboxID: sandboxID,
TeamID: teamID,
Comment thread
AdaAibaby marked this conversation as resolved.
})
if err != nil {
if dberrors.IsNotFoundError(err) {
return nil, ErrSnapshotNotFound
}

return nil, fmt.Errorf("fetching last snapshot by team: %w", err)
}

return &SnapshotInfo{
Aliases: row.Aliases,
Names: row.Names,
Snapshot: row.Snapshot,
EnvBuild: row.EnvBuild,
}, nil
}

func (c *SnapshotCache) fetchFromDB(ctx context.Context, sandboxID string) (*SnapshotInfo, error) {
ctx, span := tracer.Start(ctx, "fetch last snapshot from DB")
defer span.End()
Expand Down
157 changes: 157 additions & 0 deletions packages/api/internal/cache/snapshots/snapshot_cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package snapshotcache

import (
"testing"
"time"

"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/e2b-dev/infra/packages/db/pkg/testutils"
"github.com/e2b-dev/infra/packages/db/pkg/types"
"github.com/e2b-dev/infra/packages/db/queries"
redis_utils "github.com/e2b-dev/infra/packages/shared/pkg/redis"
)

// setupCache builds a SnapshotCache from testutils.SetupDatabase and
// redis_utils.SetupInstance, and registers a cleanup that closes the cache when
// the test finishes. It returns the cache together with the database handle so
// tests can seed rows.
func setupCache(t *testing.T) (*SnapshotCache, *testutils.Database) {
	t.Helper()

	database := testutils.SetupDatabase(t)
	snapshotCache := NewSnapshotCache(database.SqlcClient, redis_utils.SetupInstance(t))

	t.Cleanup(func() {
		// Best-effort close: a failure here must not fail the test.
		_ = snapshotCache.Close(t.Context())
	})

	return snapshotCache, database
}

// upsertSnapshot inserts a snapshot row owned by teamID for a freshly generated
// sandbox ID and returns that sandbox ID. The template ID is also randomly
// generated; the remaining fields are fixed values. The test fails immediately
// if the upsert returns an error.
func upsertSnapshot(t *testing.T, db *testutils.Database, teamID uuid.UUID, baseTemplateID string) (sandboxID string) {
	t.Helper()

	sandboxID = "sandbox-" + uuid.New().String()

	// These fields are pointers in the params struct, so they need addressable locals.
	var (
		envd           = "v1.0.0"
		diskSizeMb     = int64(1024)
		internetAccess = true
	)

	params := queries.UpsertSnapshotParams{
		TemplateID:          "tmpl-" + uuid.New().String(),
		TeamID:              teamID,
		SandboxID:           sandboxID,
		BaseTemplateID:      baseTemplateID,
		StartedAt:           pgtype.Timestamptz{Time: time.Now(), Valid: true},
		Vcpu:                2,
		RamMb:               2048,
		TotalDiskSizeMb:     &diskSizeMb,
		Metadata:            types.JSONBStringMap{},
		KernelVersion:       "6.1.0",
		FirecrackerVersion:  "1.4.0",
		EnvdVersion:         &envd,
		Secure:              false,
		AllowInternetAccess: &internetAccess,
		AutoPause:           true,
		OriginNodeID:        "test-node",
		Status:              types.BuildStatusSuccess,
	}

	_, err := db.SqlcClient.UpsertSnapshot(t.Context(), params)
	require.NoError(t, err)

	return sandboxID
}

// TestSnapshotCache_GetByTeam_HitCorrectTeam checks that GetByTeam succeeds for
// the owning team and that the returned snapshot carries the expected sandbox
// and team IDs.
func TestSnapshotCache_GetByTeam_HitCorrectTeam(t *testing.T) {
	t.Parallel()

	snapshotCache, database := setupCache(t)

	owner := testutils.CreateTestTeam(t, database)
	template := testutils.CreateTestTemplate(t, database, owner)
	sandboxID := upsertSnapshot(t, database, owner, template)

	got, err := snapshotCache.GetByTeam(t.Context(), sandboxID, owner)
	require.NoError(t, err)

	assert.Equal(t, sandboxID, got.Snapshot.SandboxID)
	assert.Equal(t, owner, got.Snapshot.TeamID)
}

// TestSnapshotCache_GetByTeam_WrongTeamReturnsNotFound checks that a team which
// does not own the snapshot gets ErrSnapshotNotFound, even after the cache has
// been warmed through Get.
func TestSnapshotCache_GetByTeam_WrongTeamReturnsNotFound(t *testing.T) {
	t.Parallel()

	snapshotCache, database := setupCache(t)

	owner := testutils.CreateTestTeam(t, database)
	stranger := testutils.CreateTestTeam(t, database)
	template := testutils.CreateTestTemplate(t, database, owner)
	sandboxID := upsertSnapshot(t, database, owner, template)

	// Prime the cache through the unscoped lookup first.
	if _, err := snapshotCache.Get(t.Context(), sandboxID); err != nil {
		require.NoError(t, err)
	}

	// A different team must not see the owner's snapshot.
	_, err := snapshotCache.GetByTeam(t.Context(), sandboxID, stranger)
	require.ErrorIs(t, err, ErrSnapshotNotFound)
}

// TestSnapshotCache_GetByTeam_CacheHitFastPath calls GetByTeam twice with the
// owning team and checks that both calls succeed and agree on the sandbox ID
// and env build ID. It does not measure DB round-trips; it only checks that a
// repeated lookup returns a consistent result.
func TestSnapshotCache_GetByTeam_CacheHitFastPath(t *testing.T) {
	t.Parallel()

	snapshotCache, database := setupCache(t)

	owner := testutils.CreateTestTeam(t, database)
	template := testutils.CreateTestTemplate(t, database, owner)
	sandboxID := upsertSnapshot(t, database, owner, template)

	// Initial lookup.
	first, err := snapshotCache.GetByTeam(t.Context(), sandboxID, owner)
	require.NoError(t, err)

	// Repeat lookup with the same team.
	second, err := snapshotCache.GetByTeam(t.Context(), sandboxID, owner)
	require.NoError(t, err)

	assert.Equal(t, first.Snapshot.SandboxID, second.Snapshot.SandboxID)
	assert.Equal(t, first.EnvBuild.ID, second.EnvBuild.ID)
}

// TestSnapshotCache_GetByTeam_UnknownSandboxReturnsNotFound checks that looking
// up a sandbox ID for which no snapshot was inserted yields ErrSnapshotNotFound.
func TestSnapshotCache_GetByTeam_UnknownSandboxReturnsNotFound(t *testing.T) {
	t.Parallel()

	snapshotCache, database := setupCache(t)
	team := testutils.CreateTestTeam(t, database)

	missingID := "nonexistent-sandbox-" + uuid.New().String()

	_, err := snapshotCache.GetByTeam(t.Context(), missingID, team)
	require.ErrorIs(t, err, ErrSnapshotNotFound)
}

// TestSnapshotCache_GetByTeam_InvalidateFlushesCache checks that GetByTeam
// still returns the snapshot after Invalidate has been called for its sandbox
// ID. It asserts only that the lookup succeeds with the right sandbox ID, not
// where the data was served from.
func TestSnapshotCache_GetByTeam_InvalidateFlushesCache(t *testing.T) {
	t.Parallel()

	snapshotCache, database := setupCache(t)

	owner := testutils.CreateTestTeam(t, database)
	template := testutils.CreateTestTemplate(t, database, owner)
	sandboxID := upsertSnapshot(t, database, owner, template)

	// First lookup, before invalidation.
	if _, err := snapshotCache.GetByTeam(t.Context(), sandboxID, owner); err != nil {
		require.NoError(t, err)
	}

	snapshotCache.Invalidate(t.Context(), sandboxID)

	// The lookup must keep working once the entry has been invalidated.
	got, err := snapshotCache.GetByTeam(t.Context(), sandboxID, owner)
	require.NoError(t, err)
	assert.Equal(t, sandboxID, got.Snapshot.SandboxID)
}
10 changes: 1 addition & 9 deletions packages/api/internal/handlers/sandbox_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@ func (a *APIStore) PostSandboxesSandboxIDConnect(c *gin.Context, sandboxID api.S
continue
}

// TODO: ENG-3544 scope GetLastSnapshot query by teamID to avoid post-fetch ownership check.
lastSnapshot, err := a.snapshotCache.Get(ctx, sandboxID)
lastSnapshot, err := a.snapshotCache.GetByTeam(ctx, sandboxID, teamID)
if err != nil {
if errors.Is(err, snapshotcache.ErrSnapshotNotFound) {
logger.L().Debug(ctx, "Snapshot not found", logger.WithSandboxID(sandboxID))
Expand All @@ -131,13 +130,6 @@ func (a *APIStore) PostSandboxesSandboxIDConnect(c *gin.Context, sandboxID api.S
return
}

if lastSnapshot.Snapshot.TeamID != teamID {
telemetry.ReportError(ctx, fmt.Sprintf("snapshot for sandbox '%s' doesn't belong to team '%s'", sandboxID, teamID.String()), nil)
a.sendAPIStoreError(c, http.StatusNotFound, utils.SandboxNotFoundMsg(sandboxID))

return
}

sbxlogger.E(&sbxlogger.SandboxMetadata{
SandboxID: sandboxID,
TemplateID: lastSnapshot.Snapshot.EnvID,
Expand Down
12 changes: 2 additions & 10 deletions packages/api/internal/handlers/sandbox_get.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,8 @@ func (a *APIStore) GetSandboxesSandboxID(c *gin.Context, id string) {
return
}

// If sandbox not found try to get the latest snapshot
// TODO: ENG-3544 scope GetLastSnapshot query by teamID to avoid post-fetch ownership check.
lastSnapshot, err := a.snapshotCache.Get(ctx, sandboxId)
// If sandbox not found try to get the latest snapshot scoped to this team.
lastSnapshot, err := a.snapshotCache.GetByTeam(ctx, sandboxId, team.ID)
if err != nil {
if errors.Is(err, snapshotcache.ErrSnapshotNotFound) {
telemetry.ReportError(ctx, "snapshot not found", err, telemetry.WithSandboxID(sandboxId))
Expand All @@ -179,13 +178,6 @@ func (a *APIStore) GetSandboxesSandboxID(c *gin.Context, id string) {
return
}

if lastSnapshot.Snapshot.TeamID != team.ID {
telemetry.ReportError(ctx, fmt.Sprintf("snapshot for sandbox '%s' doesn't belong to team '%s'", sandboxId, team.ID.String()), nil)
a.sendAPIStoreError(c, http.StatusNotFound, utils.SandboxNotFoundMsg(id))

return
}

memoryMB := int32(lastSnapshot.EnvBuild.RamMb)
cpuCount := int32(lastSnapshot.EnvBuild.Vcpu)

Expand Down
12 changes: 1 addition & 11 deletions packages/api/internal/handlers/sandbox_pause.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,18 +81,8 @@ func (a *APIStore) PostSandboxesSandboxIDPause(c *gin.Context, sandboxID api.San
}

func pauseHandleNotRunningSandbox(ctx context.Context, cache *snapshotcache.SnapshotCache, sandboxID string, teamID uuid.UUID) api.APIError {
// TODO: ENG-3544 scope GetLastSnapshot query by teamID to avoid post-fetch ownership check.
snap, err := cache.Get(ctx, sandboxID)
_, err := cache.GetByTeam(ctx, sandboxID, teamID)
if err == nil {
if snap.Snapshot.TeamID != teamID {
logger.L().Debug(ctx, "Snapshot team mismatch on pause", logger.WithSandboxID(sandboxID), logger.WithTeamID(teamID.String()))

return api.APIError{
Code: http.StatusNotFound,
ClientMsg: utils.SandboxNotFoundMsg(sandboxID),
}
}

logger.L().Warn(ctx, "Sandbox is already paused", logger.WithSandboxID(sandboxID))

return api.APIError{
Expand Down
10 changes: 1 addition & 9 deletions packages/api/internal/handlers/sandbox_resume.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ func (a *APIStore) PostSandboxesSandboxIDResume(c *gin.Context, sandboxID api.Sa
}
}

// TODO: ENG-3544 scope GetLastSnapshot query by teamID to avoid post-fetch ownership check.
lastSnapshot, err := a.snapshotCache.Get(ctx, sandboxID)
lastSnapshot, err := a.snapshotCache.GetByTeam(ctx, sandboxID, teamID)
if err != nil {
if errors.Is(err, snapshotcache.ErrSnapshotNotFound) {
logger.L().Debug(ctx, "Snapshot not found", logger.WithSandboxID(sandboxID))
Expand All @@ -139,13 +138,6 @@ func (a *APIStore) PostSandboxesSandboxIDResume(c *gin.Context, sandboxID api.Sa
return
}

if lastSnapshot.Snapshot.TeamID != teamID {
telemetry.ReportError(ctx, fmt.Sprintf("snapshot for sandbox '%s' doesn't belong to team '%s'", sandboxID, teamID.String()), nil)
a.sendAPIStoreError(c, http.StatusNotFound, utils.SandboxNotFoundMsg(sandboxID))

return
}

sbxlogger.E(&sbxlogger.SandboxMetadata{
SandboxID: sandboxID,
TemplateID: lastSnapshot.Snapshot.EnvID,
Expand Down
Loading