diff --git a/packages/orchestrator/internal/sandbox/uffd/mapping/firecracker.go b/packages/orchestrator/internal/sandbox/uffd/mapping/firecracker.go deleted file mode 100644 index dbff8bbd7a..0000000000 --- a/packages/orchestrator/internal/sandbox/uffd/mapping/firecracker.go +++ /dev/null @@ -1,32 +0,0 @@ -package mapping - -import "fmt" - -type GuestRegionUffdMapping struct { - BaseHostVirtAddr uintptr `json:"base_host_virt_addr"` - Size uintptr `json:"size"` - Offset uintptr `json:"offset"` - // This is actually in bytes. - // This field is deprecated in the newer version of the Firecracer with a new field `page_size`. - PageSize uintptr `json:"page_size_kib"` -} - -func (m *GuestRegionUffdMapping) relativeOffset(addr uintptr) int64 { - return int64(m.Offset + addr - m.BaseHostVirtAddr) -} - -type FcMappings []GuestRegionUffdMapping - -// Returns the relative offset and the page size of the mapped range for a given address -func (m FcMappings) GetRange(addr uintptr) (int64, int64, error) { - for _, m := range m { - if addr < m.BaseHostVirtAddr || m.BaseHostVirtAddr+m.Size <= addr { - // Outside of this mapping - continue - } - - return m.relativeOffset(addr), int64(m.PageSize), nil - } - - return 0, 0, fmt.Errorf("address %d not found in any mapping", addr) -} diff --git a/packages/orchestrator/internal/sandbox/uffd/mapping/mapping.go b/packages/orchestrator/internal/sandbox/uffd/mapping/mapping.go deleted file mode 100644 index ffa8c4a6fa..0000000000 --- a/packages/orchestrator/internal/sandbox/uffd/mapping/mapping.go +++ /dev/null @@ -1,5 +0,0 @@ -package mapping - -type Mappings interface { - GetRange(addr uintptr) (offset int64, pagesize int64, err error) -} diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go b/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go new file mode 100644 index 0000000000..824a1e4adb --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go @@ -0,0 +1,32 @@ +package memory + +import ( + "fmt" +) + +type AddressNotFoundError struct { + hostVirtAddr uintptr +} + +func (e AddressNotFoundError) Error() string { + return fmt.Sprintf("address %d not found in any mapping", e.hostVirtAddr) +} + +type Mapping struct { + Regions []Region +} + +func NewMapping(regions []Region) *Mapping { + return &Mapping{Regions: regions} +} + +// GetOffset returns the relative offset and the page size of the mapped range for a given address. +func (m *Mapping) GetOffset(hostVirtAddr uintptr) (int64, uint64, error) { + for _, r := range m.Regions { + if hostVirtAddr >= r.BaseHostVirtAddr && hostVirtAddr < r.endHostVirtAddr() { + return r.shiftedOffset(hostVirtAddr), uint64(r.PageSize), nil + } + } + + return 0, 0, AddressNotFoundError{hostVirtAddr: hostVirtAddr} +} diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/mapping_test.go b/packages/orchestrator/internal/sandbox/uffd/memory/mapping_test.go new file mode 100644 index 0000000000..7b7d87e06f --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/memory/mapping_test.go @@ -0,0 +1,247 @@ +package memory + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +func TestMapping_GetOffset(t *testing.T) { + regions := []Region{ + { + BaseHostVirtAddr: 0x1000, + Size: 0x2000, + Offset: 0x5000, + PageSize: header.PageSize, + }, + { + BaseHostVirtAddr: 0x5000, + Size: 0x1000, + Offset: 0x8000, + PageSize: header.PageSize, + }, + } + mapping := NewMapping(regions) + + tests := []struct { + name string + hostVirtAddr uintptr + expectedOffset int64 + expectedSize uint64 + expectError error + }{ + { + name: "valid address in first region", + hostVirtAddr: 0x1500, + expectedOffset: 0x5500, // 0x5000 + (0x1500 - 0x1000) + expectedSize: 0x1000, + }, + { + name: "valid address at start of first region", + hostVirtAddr: 0x1000, + expectedOffset: 0x5000, + expectedSize: 0x1000, + }, + { + name: "valid address at end-1 of first region", + hostVirtAddr: 0x2FFF, // 0x1000 + 0x2000 - 1 + expectedOffset: 0x6FFF, // 0x5000 + (0x2FFF - 0x1000) + expectedSize: 0x1000, + }, + { + name: "valid address in second region", + hostVirtAddr: 0x5500, + expectedOffset: 0x8500, // 0x8000 + (0x5500 - 0x5000) + expectedSize: 0x1000, + }, + { + name: "valid address at start of second region", + hostVirtAddr: 0x5000, + expectedOffset: 0x8000, + expectedSize: 0x1000, + }, + { + name: "valid address at end-1 of second region", + hostVirtAddr: 0x5FFF, + expectedOffset: 0x8FFF, // 0x8000 + (0x5FFF - 0x5000) + expectedSize: 0x1000, + }, + { + name: "address before first region", + hostVirtAddr: 0x500, + expectError: AddressNotFoundError{hostVirtAddr: 0x500}, + }, + { + name: "address after last region", + hostVirtAddr: 0x7000, + expectError: AddressNotFoundError{hostVirtAddr: 0x7000}, + }, + { + name: "address in gap between regions", + hostVirtAddr: 0x4000, + expectError: AddressNotFoundError{hostVirtAddr: 0x4000}, + }, + { + name: "address at exact end of first region (exclusive)", + hostVirtAddr: 0x3000, // 0x1000 + 0x2000 + expectError: AddressNotFoundError{hostVirtAddr: 0x3000}, + }, + { + name: "address at exact end of second region (exclusive)", + hostVirtAddr: 0x6000, // 0x5000 + 0x1000 + expectError: AddressNotFoundError{hostVirtAddr: 0x6000}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + offset, size, err := mapping.GetOffset(tt.hostVirtAddr) + if tt.expectError != nil { + require.ErrorIs(t, err, tt.expectError) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedOffset, offset) + assert.Equal(t, tt.expectedSize, size) + } + }) + } +} + +func TestMapping_EmptyRegions(t *testing.T) { + mapping := NewMapping([]Region{}) + + // Test GetOffset with empty regions + _, _, err := mapping.GetOffset(0x1000) + require.Error(t, err) +} + +func TestMapping_OverlappingRegions(t *testing.T) { + // Test with overlapping regions (edge case) + regions := []Region{ + { + BaseHostVirtAddr: 0x1000, + Size: 0x2000, + Offset: 0x5000, + PageSize: header.PageSize, + }, + { + BaseHostVirtAddr: 0x2000, // Overlaps with first region + Size: 0x1000, + Offset: 0x8000, + PageSize: header.PageSize, + }, + } + mapping := NewMapping(regions) + + // The first matching region should be returned + offset, size, err := mapping.GetOffset(0x2500) // In overlap area + require.NoError(t, err) + + // Should get result from first region + require.Equal(t, int64(0x5000+(0x2500-0x1000)), offset) // 0x6500 + require.Equal(t, uint64(header.PageSize), size) + + // Also test that the underlying implementation prefers the first region if both regions contain the address + offset2, size2, err2 := mapping.GetOffset(0x2000) + require.NoError(t, err2) + require.Equal(t, int64(0x5000+(0x2000-0x1000)), offset2) // 0x6000 from first region + require.Equal(t, uint64(header.PageSize), size2) +} + +func TestMapping_BoundaryConditions(t *testing.T) { + regions := []Region{ + { + BaseHostVirtAddr: 0x1000, + Size: 0x2000, + Offset: 0x5000, + PageSize: header.PageSize, + }, + } + mapping := NewMapping(regions) + + // Test exact start boundary + offset, _, err := mapping.GetOffset(0x1000) + require.NoError(t, err) + require.Equal(t, int64(0x5000), offset) // 0x5000 + (0x1000 - 0x1000) + + // Test just before end boundary (exclusive) + offset, _, err = mapping.GetOffset(0x2FFF) // 0x1000 + 0x2000 - 1 + require.NoError(t, err) + require.Equal(t, int64(0x5000+(0x2FFF-0x1000)), offset) // 0x6FFF + + // Test exact end boundary (should fail - exclusive) + _, _, err = mapping.GetOffset(0x3000) // 0x1000 + 0x2000 + require.Error(t, err) + + // Test below start boundary (should fail) + _, _, err = mapping.GetOffset(0x0FFF) + require.Error(t, err) +} + +func TestMapping_SingleLargeRegion(t *testing.T) { + // Entire 64-bit address space region + regions := []Region{ + { + BaseHostVirtAddr: 0x0, + Size: ^uintptr(0), // Max uintptr + Offset: 0x100, + PageSize: header.PageSize, + }, + } + mapping := NewMapping(regions) + + offset, size, err := mapping.GetOffset(0xABCDEF) + require.NoError(t, err) + require.Equal(t, int64(0x100+0xABCDEF), offset) + require.Equal(t, uint64(header.PageSize), size) +} + +func TestMapping_ZeroSizeRegion(t *testing.T) { + regions := []Region{ + { + BaseHostVirtAddr: 0x2000, + Size: 0, + Offset: 0x1000, + PageSize: header.PageSize, + }, + } + mapping := NewMapping(regions) + _, _, err := mapping.GetOffset(0x2000) + require.Error(t, err) +} + +func TestMapping_MultipleRegionsSparse(t *testing.T) { + regions := []Region{ + { + BaseHostVirtAddr: 0x100, + Size: 0x100, + Offset: 0x1000, + PageSize: header.PageSize, + }, + { + BaseHostVirtAddr: 0x10000, + Size: 0x100, + Offset: 0x2000, + PageSize: header.PageSize, + }, + } + mapping := NewMapping(regions) + // Should succeed for start of first region + offset, size, err := mapping.GetOffset(0x100) + require.NoError(t, err) + require.Equal(t, int64(0x1000), offset) + require.Equal(t, uint64(header.PageSize), size) + + // Should succeed for start of second region + offset, size, err = mapping.GetOffset(0x10000) + require.NoError(t, err) + require.Equal(t, int64(0x2000), offset) + require.Equal(t, uint64(header.PageSize), size) + + // In gap + _, _, err = mapping.GetOffset(0x5000) + require.Error(t, err) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/region.go b/packages/orchestrator/internal/sandbox/uffd/memory/region.go new file mode 100644 index 0000000000..b3deab2006 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/memory/region.go @@ -0,0 +1,23 @@ +package memory + +// Region is a mapping of a region of memory of the guest to a region of memory on the host. +// The serialization is based on the Firecracker UFFD protocol communication. +// https://github.com/firecracker-microvm/firecracker/blob/ceeca6a14284537ae0b2a192cd2ffef10d3a81e2/src/vmm/src/persist.rs#L96 +type Region struct { + BaseHostVirtAddr uintptr `json:"base_host_virt_addr"` + Size uintptr `json:"size"` + Offset uintptr `json:"offset"` + // This field is deprecated in the newer version of the Firecracker with a new field `page_size`. + PageSize uintptr `json:"page_size_kib"` +} + +// endHostVirtAddr returns the end address of the region in host virtual address. +// The end address is exclusive. +func (r *Region) endHostVirtAddr() uintptr { + return r.BaseHostVirtAddr + r.Size +} + +// shiftedOffset returns the offset of the given address in the region. +func (r *Region) shiftedOffset(addr uintptr) int64 { + return int64(addr - r.BaseHostVirtAddr + r.Offset) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/serve.go b/packages/orchestrator/internal/sandbox/uffd/serve.go deleted file mode 100644 index e83734eca5..0000000000 --- a/packages/orchestrator/internal/sandbox/uffd/serve.go +++ /dev/null @@ -1,205 +0,0 @@ -package uffd - -import ( - "context" - "errors" - "fmt" - "sync" - "syscall" - "unsafe" - - "go.uber.org/zap" - "golang.org/x/sync/errgroup" - "golang.org/x/sys/unix" - - "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" - "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/fdexit" - "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/mapping" - "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/userfaultfd" -) - -var ErrUnexpectedEventType = errors.New("unexpected event type") - -type GuestRegionUffdMapping struct { - BaseHostVirtAddr uintptr `json:"base_host_virt_addr"` - Size uintptr `json:"size"` - Offset uintptr `json:"offset"` - PageSize uintptr `json:"page_size_kib"` -} - -func Serve( - ctx context.Context, - uffd int, - mappings mapping.Mappings, - src block.Slicer, - fdExit *fdexit.FdExit, - missingRequests *sync.Map, - logger *zap.Logger, -) error { - pollFds := []unix.PollFd{ - {Fd: int32(uffd), Events: unix.POLLIN}, - {Fd: fdExit.Reader(), Events: unix.POLLIN}, - } - - var eg errgroup.Group - - eagainCounter := newEagainCounter(logger, "uffd: eagain during fd read (accumulated)") - defer eagainCounter.Close() - -outerLoop: - for { - if _, err := unix.Poll( - pollFds, - -1, - ); err != nil { - if err == unix.EINTR { - logger.Debug("uffd: interrupted polling, going back to polling") - - continue - } - - if err == unix.EAGAIN { - logger.Debug("uffd: eagain during fd polling, going back to polling") - - continue - } - - logger.Error("UFFD serve polling error", zap.Error(err)) - - return fmt.Errorf("failed polling: %w", err) - } - - exitFd := pollFds[1] - if exitFd.Revents&unix.POLLIN != 0 { - errMsg := eg.Wait() - if errMsg != nil { - logger.Warn("UFFD fd exit error while waiting for goroutines to finish", zap.Error(errMsg)) - - return fmt.Errorf("failed to handle uffd: %w", errMsg) - } - - return nil - } - - uffdFd := pollFds[0] - if uffdFd.Revents&unix.POLLIN == 0 { - // Uffd is not ready for reading as there is nothing to read on the fd. - // https://github.com/firecracker-microvm/firecracker/issues/5056 - // https://elixir.bootlin.com/linux/v6.8.12/source/fs/userfaultfd.c#L1149 - // TODO: Check for all the errors - // - https://docs.kernel.org/admin-guide/mm/userfaultfd.html - // - https://elixir.bootlin.com/linux/v6.8.12/source/fs/userfaultfd.c - // - https://man7.org/linux/man-pages/man2/userfaultfd.2.html - // It might be possible to just check for data != 0 in the syscall.Read loop - // but I don't feel confident about doing that. - logger.Debug("uffd: no data in fd, going back to polling") - - continue - } - - buf := make([]byte, unsafe.Sizeof(userfaultfd.UffdMsg{})) - - for { - _, err := syscall.Read(uffd, buf) - if err == syscall.EINTR { - logger.Debug("uffd: interrupted read, reading again") - - continue - } - - if err == nil { - // There is no error so we can proceed. - break - } - - if err == syscall.EAGAIN { - eagainCounter.Increase() - - // Continue polling the fd. - continue outerLoop - } - - logger.Error("uffd: read error", zap.Error(err)) - - return fmt.Errorf("failed to read: %w", err) - } - - eagainCounter.Log() - - msg := *(*userfaultfd.UffdMsg)(unsafe.Pointer(&buf[0])) - if userfaultfd.GetMsgEvent(&msg) != userfaultfd.UFFD_EVENT_PAGEFAULT { - logger.Error("UFFD serve unexpected event type", zap.Any("event_type", userfaultfd.GetMsgEvent(&msg))) - - return ErrUnexpectedEventType - } - - arg := userfaultfd.GetMsgArg(&msg) - pagefault := (*(*userfaultfd.UffdPagefault)(unsafe.Pointer(&arg[0]))) - - addr := userfaultfd.GetPagefaultAddress(&pagefault) - - offset, pagesize, err := mappings.GetRange(uintptr(addr)) - if err != nil { - logger.Error("UFFD serve get mapping error", zap.Error(err)) - - return fmt.Errorf("failed to map: %w", err) - } - - if _, ok := missingRequests.Load(offset); ok { - continue - } - - missingRequests.Store(offset, struct{}{}) - - eg.Go(func() error { - defer func() { - if r := recover(); r != nil { - logger.Error("UFFD serve panic", zap.Any("offset", offset), zap.Any("pagesize", pagesize), zap.Any("panic", r)) - } - }() - - b, err := src.Slice(ctx, offset, pagesize) - if err != nil { - signalErr := fdExit.SignalExit() - - joinedErr := errors.Join(err, signalErr) - - logger.Error("UFFD serve slice error", zap.Error(joinedErr)) - - return fmt.Errorf("failed to read from source: %w", joinedErr) - } - - cpy := userfaultfd.NewUffdioCopy( - b, - addr&^userfaultfd.CULong(pagesize-1), - userfaultfd.CULong(pagesize), - 0, - 0, - ) - - if _, _, errno := syscall.Syscall( - syscall.SYS_IOCTL, - uintptr(uffd), - userfaultfd.UFFDIO_COPY, - uintptr(unsafe.Pointer(&cpy)), - ); errno != 0 { - if errno == unix.EEXIST { - logger.Debug("UFFD serve page already mapped", zap.Any("offset", offset), zap.Any("pagesize", pagesize)) - - // Page is already mapped - return nil - } - - signalErr := fdExit.SignalExit() - - joinedErr := errors.Join(errno, signalErr) - - logger.Error("UFFD serve uffdio copy error", zap.Error(joinedErr)) - - return fmt.Errorf("failed uffdio copy %w", joinedErr) - } - - return nil - }) - } -} diff --git a/packages/orchestrator/internal/sandbox/uffd/handler.go b/packages/orchestrator/internal/sandbox/uffd/uffd.go similarity index 82% rename from packages/orchestrator/internal/sandbox/uffd/handler.go rename to packages/orchestrator/internal/sandbox/uffd/uffd.go index 0f9c98f995..8f2c0954d9 100644 --- a/packages/orchestrator/internal/sandbox/uffd/handler.go +++ b/packages/orchestrator/internal/sandbox/uffd/uffd.go @@ -7,7 +7,6 @@ import ( "fmt" "net" "os" - "sync" "syscall" "time" @@ -17,7 +16,8 @@ import ( "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/fdexit" - "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/mapping" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/userfaultfd" "github.com/e2b-dev/infra/packages/shared/pkg/logger" "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) @@ -27,19 +27,17 @@ var tracer = otel.Tracer("github.com/e2b-dev/infra/packages/orchestrator/interna const ( uffdMsgListenerTimeout = 10 * time.Second fdSize = 4 - mappingsSize = 1024 + regionMappingsSize = 1024 ) type Uffd struct { - exit *utils.ErrorOnce - readyCh chan struct{} - - fdExit *fdexit.FdExit - - lis *net.UnixListener - - memfile *block.TrackedSliceDevice + exit *utils.ErrorOnce + readyCh chan struct{} + fdExit *fdexit.FdExit + lis *net.UnixListener socketPath string + memfile *block.TrackedSliceDevice + handler utils.SetOnce[*userfaultfd.Userfaultfd] } var _ MemoryBackend = (*Uffd)(nil) @@ -59,8 +57,9 @@ func New(memfile block.ReadonlyDevice, socketPath string, blockSize int64) (*Uff exit: utils.NewErrorOnce(), readyCh: make(chan struct{}, 1), fdExit: fdExit, - memfile: trackedMemfile, socketPath: socketPath, + memfile: trackedMemfile, + handler: *utils.NewSetOnce[*userfaultfd.Userfaultfd](), }, nil } @@ -107,19 +106,19 @@ func (u *Uffd) handle(ctx context.Context, sandboxId string) error { unixConn := conn.(*net.UnixConn) - mappingsBuf := make([]byte, mappingsSize) + regionMappingsBuf := make([]byte, regionMappingsSize) uffdBuf := make([]byte, syscall.CmsgSpace(fdSize)) - numBytesMappings, numBytesFd, _, _, err := unixConn.ReadMsgUnix(mappingsBuf, uffdBuf) + numBytesMappings, numBytesFd, _, _, err := unixConn.ReadMsgUnix(regionMappingsBuf, uffdBuf) if err != nil { return fmt.Errorf("failed to read unix msg from connection: %w", err) } - mappingsBuf = mappingsBuf[:numBytesMappings] + regionMappingsBuf = regionMappingsBuf[:numBytesMappings] - var m mapping.FcMappings + var regions []memory.Region - err = json.Unmarshal(mappingsBuf, &m) + err = json.Unmarshal(regionMappingsBuf, ®ions) if err != nil { return fmt.Errorf("failed parsing memory mapping data: %w", err) } @@ -142,10 +141,22 @@ func (u *Uffd) handle(ctx context.Context, sandboxId string) error { return fmt.Errorf("expected 1 fd: found %d", len(fds)) } - uffd := fds[0] + m := memory.NewMapping(regions) + + uffd, err := userfaultfd.NewUserfaultfdFromFd( + uintptr(fds[0]), + u.memfile, + m, + zap.L().With(logger.WithSandboxID(sandboxId)), + ) + if err != nil { + return fmt.Errorf("failed to create uffd: %w", err) + } + + u.handler.SetValue(uffd) defer func() { - closeErr := syscall.Close(uffd) + closeErr := uffd.Close() if closeErr != nil { zap.L().Error("failed to close uffd", logger.WithSandboxID(sandboxId), zap.String("socket_path", u.socketPath), zap.Error(closeErr)) } @@ -153,16 +164,9 @@ func (u *Uffd) handle(ctx context.Context, sandboxId string) error { u.readyCh <- struct{}{} - missingRequests := &sync.Map{} - - err = Serve( + err = uffd.Serve( ctx, - uffd, - m, - u.memfile, u.fdExit, - missingRequests, - zap.L().With(logger.WithSandboxID(sandboxId)), ) if err != nil { return fmt.Errorf("failed handling uffd: %w", err) @@ -183,10 +187,6 @@ func (u *Uffd) Exit() *utils.ErrorOnce { return u.exit } -func (u *Uffd) TrackAndReturnNil() error { - return u.lis.Close() -} - func (u *Uffd) Disable() error { return u.memfile.Disable() } diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/constants.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/constants.go deleted file mode 100644 index 63a3b3e71e..0000000000 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/constants.go +++ /dev/null @@ -1,117 +0,0 @@ -package userfaultfd - -// https://docs.kernel.org/admin-guide/mm/userfaultfd.html -// https://man7.org/linux/man-pages/man2/userfaultfd.2.html -// https://github.com/torvalds/linux/blob/master/fs/userfaultfd.c -// https://github.com/loopholelabs/userfaultfd-go/blob/main/pkg/constants/cgo.go - -/* -#include -#include -#include -#include - -struct uffd_pagefault { - __u64 flags; - __u64 address; - __u32 ptid; -}; -*/ -import "C" -import "unsafe" - -const ( - NR_userfaultfd = C.__NR_userfaultfd - - UFFD_API = C.UFFD_API - UFFD_EVENT_PAGEFAULT = C.UFFD_EVENT_PAGEFAULT - - UFFDIO_REGISTER_MODE_MISSING = C.UFFDIO_REGISTER_MODE_MISSING - UFFDIO_REGISTER_MODE_WP = C.UFFDIO_REGISTER_MODE_WP - - UFFDIO_WRITEPROTECT_MODE_WP = C.UFFDIO_WRITEPROTECT_MODE_WP - UFFDIO_COPY_MODE_WP = C.UFFDIO_COPY_MODE_WP - - UFFDIO_API = C.UFFDIO_API - UFFDIO_REGISTER = C.UFFDIO_REGISTER - UFFDIO_WRITEPROTECT = C.UFFDIO_WRITEPROTECT - UFFDIO_COPY = C.UFFDIO_COPY - - UFFD_PAGEFAULT_FLAG_WP = C.UFFD_PAGEFAULT_FLAG_WP - UFFD_PAGEFAULT_FLAG_WRITE = C.UFFD_PAGEFAULT_FLAG_WRITE - - UFFD_FEATURE_MISSING_HUGETLBFS = C.UFFD_FEATURE_MISSING_HUGETLBFS - UFFD_FEATURE_WP_HUGETLBFS_SHMEM = C.UFFD_FEATURE_WP_HUGETLBFS_SHMEM -) - -type ( - CULong = C.ulonglong - CUChar = C.uchar - CLong = C.longlong - - UffdMsg = C.struct_uffd_msg - UffdPagefault = C.struct_uffd_pagefault - - UffdioAPI = C.struct_uffdio_api - UffdioRegister = C.struct_uffdio_register - UffdioRange = C.struct_uffdio_range - UffdioCopy = C.struct_uffdio_copy - UffdioWriteProtect = C.struct_uffdio_writeprotect -) - -func NewUffdioAPI(api, features CULong) UffdioAPI { - return UffdioAPI{ - api: api, - features: features, - } -} - -func NewUffdioRegister(start, length, mode CULong) UffdioRegister { - return UffdioRegister{ - _range: UffdioRange{ - start: start, - len: length, - }, - mode: mode, - } -} - -func NewUffdioCopy(b []byte, address CULong, pagesize CULong, mode CULong, bytesCopied CLong) UffdioCopy { - return UffdioCopy{ - src: CULong(uintptr(unsafe.Pointer(&b[0]))), - dst: address &^ (pagesize - 1), - len: pagesize, - mode: mode, - copy: bytesCopied, - } -} - -func NewUffdioWriteProtect(start, length, mode CULong) UffdioWriteProtect { - return UffdioWriteProtect{ - _range: UffdioRange{ - start: start, - len: length, - }, - mode: mode, - } -} - -func GetMsgEvent(msg *UffdMsg) CUChar { - return msg.event -} - -func GetMsgArg(msg *UffdMsg) [24]byte { - return msg.arg -} - -func GetPagefaultAddress(pagefault *UffdPagefault) CULong { - return pagefault.address -} - -func IsWritePageFault(pagefault *UffdPagefault) bool { - return pagefault.flags&UFFD_PAGEFAULT_FLAG_WRITE != 0 -} - -func IsWriteProtectPageFault(pagefault *UffdPagefault) bool { - return pagefault.flags&UFFD_PAGEFAULT_FLAG_WP != 0 -} diff --git a/packages/orchestrator/internal/sandbox/uffd/cross_process_helper_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go similarity index 90% rename from packages/orchestrator/internal/sandbox/uffd/cross_process_helper_test.go rename to packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go index 683714f595..03556d51c5 100644 --- a/packages/orchestrator/internal/sandbox/uffd/cross_process_helper_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go @@ -1,4 +1,4 @@ -package uffd +package userfaultfd // This tests is creating uffd in the main process and handling the page faults in another process. // It prevents problems with Go mmap during testing (https://pojntfx.github.io/networked-linux-memsync/main.html#limitations) and also more accurately simulates what we do with Firecracker. @@ -27,9 +27,8 @@ import ( "golang.org/x/sys/unix" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/fdexit" - "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/mapping" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/testutils" - "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/userfaultfd" ) // Main process, FC in our case @@ -44,17 +43,18 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error memoryArea, memoryStart, err := testutils.NewPageMmap(t, uint64(size), tt.pagesize) require.NoError(t, err) - uffd, err := userfaultfd.NewUserfaultfd(syscall.O_CLOEXEC | syscall.O_NONBLOCK) + // We can pass mapping nil as the serve is used only in the helper process. + uffdFd, err := newFd(syscall.O_CLOEXEC | syscall.O_NONBLOCK) require.NoError(t, err) t.Cleanup(func() { - userfaultfd.Close(uffd) + uffdFd.close() }) - err = userfaultfd.ConfigureApi(uffd, tt.pagesize) + err = uffdFd.configureApi(tt.pagesize) require.NoError(t, err) - err = userfaultfd.Register(uffd, memoryStart, uint64(size), userfaultfd.UFFDIO_REGISTER_MODE_MISSING) + err = uffdFd.register(memoryStart, uint64(size), UFFDIO_REGISTER_MODE_MISSING) require.NoError(t, err) cmd := exec.CommandContext(t.Context(), os.Args[0], "-test.run=TestHelperServingProcess") @@ -62,7 +62,7 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error cmd.Env = append(cmd.Env, fmt.Sprintf("GO_MMAP_START=%d", memoryStart)) cmd.Env = append(cmd.Env, fmt.Sprintf("GO_MMAP_PAGE_SIZE=%d", tt.pagesize)) - dup, err := syscall.Dup(int(uffd)) + dup, err := syscall.Dup(int(uffdFd)) require.NoError(t, err) // clear FD_CLOEXEC on the dup we pass across exec @@ -174,7 +174,6 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error memoryArea: &memoryArea, pagesize: tt.pagesize, data: data, - uffd: uffd, offsetsOnce: offsetsOnce, }, nil } @@ -208,16 +207,46 @@ func crossProcessServe() error { uffdFile := os.NewFile(uintptr(3), os.Getenv("GO_UFFD_FILE")) defer uffdFile.Close() - uffd := uffdFile.Fd() + uffdFd := uffdFile.Fd() contentFile := os.NewFile(uintptr(4), "content") defer contentFile.Close() - offsetsFile := os.NewFile(uintptr(5), "offsets") + content, err := io.ReadAll(contentFile) + if err != nil { + return fmt.Errorf("exit reading content: %w", err) + } - readyFile := os.NewFile(uintptr(6), "ready") + pageSize, err := strconv.ParseInt(os.Getenv("GO_MMAP_PAGE_SIZE"), 10, 64) + if err != nil { + return fmt.Errorf("exit parsing page size: %w", err) + } + + data := testutils.NewMemorySlicer(content, pageSize) + + m := memory.NewMapping([]memory.Region{ + { + BaseHostVirtAddr: memoryStart, + Size: uintptr(len(content)), + Offset: 0, + PageSize: uintptr(pageSize), + }, + }) + + exitUffd := make(chan struct{}, 1) + defer close(exitUffd) + + logger, err := zap.NewDevelopment() + if err != nil { + return fmt.Errorf("exit creating logger: %w", err) + } + + uffd, err := NewUserfaultfdFromFd(uffdFd, data, m, logger) + if err != nil { + return fmt.Errorf("exit creating uffd: %w", err) + } - missingRequests := &sync.Map{} + offsetsFile := os.NewFile(uintptr(5), "offsets") offsetsSignal := make(chan os.Signal, 1) signal.Notify(offsetsSignal, syscall.SIGUSR2) @@ -231,7 +260,7 @@ func crossProcessServe() error { case <-ctx.Done(): return case <-offsetsSignal: - offsets, err := getAccessedOffsets(missingRequests) + offsets, err := getAccessedOffsets(&uffd.missingRequests) if err != nil { msg := fmt.Errorf("error getting accessed offsets from cross process: %w", err) @@ -260,34 +289,6 @@ func crossProcessServe() error { } }() - content, err := io.ReadAll(contentFile) - if err != nil { - return fmt.Errorf("exit reading content: %w", err) - } - - pageSize, err := strconv.Atoi(os.Getenv("GO_MMAP_PAGE_SIZE")) - if err != nil { - return fmt.Errorf("exit parsing page size: %w", err) - } - - data := testutils.NewMemorySlicer(content, int64(pageSize)) - - m := mapping.FcMappings([]mapping.GuestRegionUffdMapping{ - { - BaseHostVirtAddr: memoryStart, - Size: uintptr(len(content)), - Offset: 0, - PageSize: uintptr(pageSize), - }, - }) - - exitUffd := make(chan struct{}, 1) - - logger, err := zap.NewDevelopment() - if err != nil { - return fmt.Errorf("exit creating logger: %w", err) - } - fdExit, err := fdexit.New() if err != nil { return fmt.Errorf("exit creating fd exit: %w", err) @@ -299,7 +300,7 @@ func crossProcessServe() error { exitUffd <- struct{}{} }() - serverErr := Serve(ctx, int(uffd), m, data, fdExit, missingRequests, logger) + serverErr := uffd.Serve(ctx, fdExit) if serverErr != nil { msg := fmt.Errorf("error serving: %w", serverErr) @@ -332,6 +333,8 @@ func crossProcessServe() error { signal.Notify(exitSignal, syscall.SIGUSR1) defer signal.Stop(exitSignal) + readyFile := os.NewFile(uintptr(6), "ready") + closeErr := readyFile.Close() if closeErr != nil { return fmt.Errorf("error closing ready file: %w", closeErr) diff --git a/packages/orchestrator/internal/sandbox/uffd/eagain.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/eagain.go similarity index 97% rename from packages/orchestrator/internal/sandbox/uffd/eagain.go rename to packages/orchestrator/internal/sandbox/uffd/userfaultfd/eagain.go index f9575f21ea..29242c68c9 100644 --- a/packages/orchestrator/internal/sandbox/uffd/eagain.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/eagain.go @@ -1,4 +1,4 @@ -package uffd +package userfaultfd import ( "time" diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go new file mode 100644 index 0000000000..01a745c9eb --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go @@ -0,0 +1,166 @@ +package userfaultfd + +// https://docs.kernel.org/admin-guide/mm/userfaultfd.html +// https://man7.org/linux/man-pages/man2/userfaultfd.2.html +// https://github.com/torvalds/linux/blob/master/fs/userfaultfd.c +// https://github.com/loopholelabs/userfaultfd-go/blob/main/pkg/constants/cgo.go + +/* +#include +#include +#include +#include + +struct uffd_pagefault { + __u64 flags; + __u64 address; + __u32 ptid; +}; +*/ +import "C" + +import ( + "fmt" + "syscall" + "unsafe" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +const ( + NR_userfaultfd = C.__NR_userfaultfd + + UFFD_API = C.UFFD_API + UFFD_EVENT_PAGEFAULT = C.UFFD_EVENT_PAGEFAULT + + UFFDIO_REGISTER_MODE_MISSING = C.UFFDIO_REGISTER_MODE_MISSING + + UFFDIO_API = C.UFFDIO_API + UFFDIO_REGISTER = C.UFFDIO_REGISTER + UFFDIO_COPY = C.UFFDIO_COPY + + UFFD_PAGEFAULT_FLAG_WRITE = C.UFFD_PAGEFAULT_FLAG_WRITE + + UFFD_FEATURE_MISSING_HUGETLBFS = C.UFFD_FEATURE_MISSING_HUGETLBFS +) + +type ( + CULong = C.ulonglong + CUChar = C.uchar + CLong = C.longlong + + UffdMsg = C.struct_uffd_msg + UffdPagefault = C.struct_uffd_pagefault + + UffdioAPI = C.struct_uffdio_api + UffdioRegister = C.struct_uffdio_register + UffdioRange = C.struct_uffdio_range + UffdioCopy = C.struct_uffdio_copy + UffdioWriteProtect = C.struct_uffdio_writeprotect +) + +func newUffdioAPI(api, features CULong) UffdioAPI { + return UffdioAPI{ + api: api, + features: features, + } +} + +func newUffdioRegister(start, length, mode CULong) UffdioRegister { + return UffdioRegister{ + _range: UffdioRange{ + start: start, + len: length, + }, + mode: mode, + } +} + +func newUffdioCopy(b []byte, address CULong, pagesize CULong, mode CULong, bytesCopied CLong) UffdioCopy { + return UffdioCopy{ + src: CULong(uintptr(unsafe.Pointer(&b[0]))), + dst: address, + len: pagesize, + mode: mode, + copy: bytesCopied, + } +} + +func getMsgEvent(msg *UffdMsg) CUChar { + return msg.event +} + +func getMsgArg(msg *UffdMsg) [24]byte { + return msg.arg +} + +func getPagefaultAddress(pagefault *UffdPagefault) uintptr { + return uintptr(pagefault.address) +} + +// uffdFd is a helper type that wraps uffd fd. +type uffdFd uintptr + +// flags: syscall.O_CLOEXEC|syscall.O_NONBLOCK +func newFd(flags uintptr) (uffdFd, error) { + uffd, _, errno := syscall.Syscall(NR_userfaultfd, flags, 0, 0) + if errno != 0 { + return 0, fmt.Errorf("userfaultfd syscall failed: %w", errno) + } + + return uffdFd(uffd), nil +} + +// features: UFFD_FEATURE_MISSING_HUGETLBFS +// This is already called by the FC +func (u uffdFd) configureApi(pagesize uint64) error { + var features CULong + + // Only set the hugepage feature if we're using hugepages + if pagesize == header.HugepageSize { + features |= UFFD_FEATURE_MISSING_HUGETLBFS + } + + api := newUffdioAPI(UFFD_API, features) + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(u), UFFDIO_API, uintptr(unsafe.Pointer(&api))) + if errno != 0 { + return fmt.Errorf("UFFDIO_API ioctl failed: %w (ret=%d)", errno, ret) + } + + return nil +} + +// mode: UFFDIO_REGISTER_MODE_WP|UFFDIO_REGISTER_MODE_MISSING +// This is already called by the FC, but only with the UFFDIO_REGISTER_MODE_MISSING +// We need to call it with UFFDIO_REGISTER_MODE_WP when we use both missing and wp +func (u uffdFd) register(addr uintptr, size uint64, mode CULong) error { + register := newUffdioRegister(CULong(addr), CULong(size), mode) + + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(u), UFFDIO_REGISTER, uintptr(unsafe.Pointer(®ister))) + if errno != 0 { + return fmt.Errorf("UFFDIO_REGISTER ioctl failed: %w (ret=%d)", errno, ret) + } + + return nil +} + +// mode: UFFDIO_COPY_MODE_WP +// When we use both missing and wp, we need to use UFFDIO_COPY_MODE_WP, otherwise copying would unprotect the page +func (u uffdFd) copy(addr uintptr, data []byte, pagesize uint64, mode CULong) error { + cpy := newUffdioCopy(data, CULong(addr)&^CULong(pagesize-1), CULong(pagesize), mode, 0) + + if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(u), UFFDIO_COPY, uintptr(unsafe.Pointer(&cpy))); errno != 0 { + return errno + } + + // Check if the copied size matches the requested pagesize + if uint64(cpy.copy) != pagesize { + return fmt.Errorf("UFFDIO_COPY copied %d bytes, expected %d", cpy.copy, pagesize) + } + + return nil +} + +func (u uffdFd) close() error { + return syscall.Close(int(u)) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/helpers_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/helpers_test.go similarity index 98% rename from packages/orchestrator/internal/sandbox/uffd/helpers_test.go rename to packages/orchestrator/internal/sandbox/uffd/userfaultfd/helpers_test.go index 5ce63752a2..1a65664b86 100644 --- a/packages/orchestrator/internal/sandbox/uffd/helpers_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/helpers_test.go @@ -1,4 +1,4 @@ -package uffd +package userfaultfd import ( "bytes" @@ -39,7 +39,6 @@ type testHandler struct { memoryArea *[]byte pagesize uint64 data *testutils.MemorySlicer - uffd uintptr // Returns offsets of the pages that were faulted. // It can only be called once. offsetsOnce func() ([]uint, error) diff --git a/packages/orchestrator/internal/sandbox/uffd/missing_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_test.go similarity index 99% rename from packages/orchestrator/internal/sandbox/uffd/missing_test.go rename to packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_test.go index 54794cd6a5..20c8ddeeb3 100644 --- a/packages/orchestrator/internal/sandbox/uffd/missing_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_test.go @@ -1,4 +1,4 @@ -package uffd +package userfaultfd import ( "testing" diff --git a/packages/orchestrator/internal/sandbox/uffd/missing_write_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_write_test.go similarity index 99% rename from packages/orchestrator/internal/sandbox/uffd/missing_write_test.go rename to packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_write_test.go index 65cde3a19f..52a73e5fa2 100644 --- a/packages/orchestrator/internal/sandbox/uffd/missing_write_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_write_test.go @@ -1,4 +1,4 @@ -package uffd +package userfaultfd import ( "testing" diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/syscalls.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/syscalls.go deleted file mode 100644 index 8359acefc3..0000000000 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/syscalls.go +++ /dev/null @@ -1,73 +0,0 @@ -package userfaultfd - -import ( - "fmt" - "syscall" - "unsafe" - - "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" -) - -// flags: syscall.O_CLOEXEC|syscall.O_NONBLOCK -func NewUserfaultfd(flags uintptr) (uintptr, error) { - uffd, _, errno := syscall.Syscall(NR_userfaultfd, flags, 0, 0) - if errno != 0 { - return 0, fmt.Errorf("userfaultfd syscall failed: %w", errno) - } - - return uffd, nil -} - -// features: UFFD_FEATURE_MISSING_HUGETLBFS -// This is already called by the FC -func ConfigureApi(fd uintptr, pagesize uint64) error { - var features CULong - - // Only set the hugepage feature if we're using hugepages - if pagesize == header.HugepageSize { - features |= UFFD_FEATURE_MISSING_HUGETLBFS - } - - api := NewUffdioAPI(UFFD_API, features) - ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, fd, UFFDIO_API, uintptr(unsafe.Pointer(&api))) - if errno != 0 { - return fmt.Errorf("UFFDIO_API ioctl failed: %w (ret=%d)", errno, ret) - } - - return nil -} - -// mode: UFFDIO_REGISTER_MODE_WP|UFFDIO_REGISTER_MODE_MISSING -// This is already called by the FC, but only with the UFFDIO_REGISTER_MODE_MISSING -// We need to call it with UFFDIO_REGISTER_MODE_WP when we use both missing and wp -func Register(fd uintptr, addr uintptr, size uint64, mode CULong) error { - register := NewUffdioRegister(CULong(addr), CULong(size), mode) - - ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, fd, UFFDIO_REGISTER, uintptr(unsafe.Pointer(®ister))) - if errno != 0 { - return fmt.Errorf("UFFDIO_REGISTER ioctl failed: %w (ret=%d)", errno, ret) - } - - return nil -} - -// mode: UFFDIO_COPY_MODE_WP -// When we use both missing and wp, we need to use UFFDIO_COPY_MODE_WP, otherwise copying would unprotect the page -func Copy(fd uintptr, addr uintptr, data []byte, pagesize uint64, mode CULong) error { - cpy := NewUffdioCopy(data, CULong(addr)&^CULong(pagesize-1), CULong(pagesize), mode, 0) - - if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, fd, UFFDIO_COPY, uintptr(unsafe.Pointer(&cpy))); errno != 0 { - return errno - } - - // Check if the copied size matches the requested pagesize - if uint64(cpy.copy) != pagesize { - return fmt.Errorf("UFFDIO_COPY copied %d bytes, expected %d", cpy.copy, pagesize) - } - - return nil -} - -func Close(fd uintptr) error { - return syscall.Close(int(fd)) -} diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go new file mode 100644 index 0000000000..6a2655c5fe --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go @@ -0,0 +1,246 @@ +package userfaultfd + +import ( + "context" + "errors" + "fmt" + "sync" + "syscall" + "unsafe" + + "go.uber.org/zap" + "golang.org/x/sync/errgroup" + "golang.org/x/sys/unix" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/fdexit" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" +) + +var ErrUnexpectedEventType = errors.New("unexpected event type") + +type Userfaultfd struct { + fd uffdFd + + src block.Slicer + ma *memory.Mapping + + missingRequests sync.Map + + wg errgroup.Group + + logger *zap.Logger +} + +// NewUserfaultfdFromFd creates a new userfaultfd instance with optional configuration. +func NewUserfaultfdFromFd(fd uintptr, src block.Slicer, m *memory.Mapping, logger *zap.Logger) (*Userfaultfd, error) { + return &Userfaultfd{ + fd: uffdFd(fd), + src: src, + missingRequests: sync.Map{}, + ma: m, + logger: logger, + }, nil +} + +func (u *Userfaultfd) Close() error { + return u.fd.close() +} + +func (u *Userfaultfd) Serve( + ctx context.Context, + fdExit *fdexit.FdExit, +) error { + pollFds := []unix.PollFd{ + {Fd: int32(u.fd), Events: unix.POLLIN}, + {Fd: fdExit.Reader(), Events: unix.POLLIN}, + } + + eagainCounter := newEagainCounter(u.logger, "uffd: eagain during fd read (accumulated)") + defer eagainCounter.Close() + +outerLoop: + for { + if _, err := unix.Poll( + pollFds, + -1, + ); err != nil { + if err == unix.EINTR { + u.logger.Debug("uffd: interrupted polling, going back to polling") + + continue + } + + if err == unix.EAGAIN { + u.logger.Debug("uffd: eagain during polling, going back to polling") + + continue + } + + u.logger.Error("UFFD serve polling error", zap.Error(err)) + + return fmt.Errorf("failed polling: %w", err) + } + + exitFd := pollFds[1] + if exitFd.Revents&unix.POLLIN != 0 { + errMsg := u.wg.Wait() + if errMsg != nil { + u.logger.Warn("UFFD fd exit error while waiting for goroutines to finish", zap.Error(errMsg)) + + return fmt.Errorf("failed to handle uffd: %w", errMsg) + } + + return nil + } + + uffdFd := pollFds[0] + if uffdFd.Revents&unix.POLLIN == 0 { + // Uffd is not ready for reading as there is nothing to read on the fd. + // https://github.com/firecracker-microvm/firecracker/issues/5056 + // https://elixir.bootlin.com/linux/v6.8.12/source/fs/userfaultfd.c#L1149 + // TODO: Check for all the errors + // - https://docs.kernel.org/admin-guide/mm/userfaultfd.html + // - https://elixir.bootlin.com/linux/v6.8.12/source/fs/userfaultfd.c + // - https://man7.org/linux/man-pages/man2/userfaultfd.2.html + // It might be possible to just check for data != 0 in the syscall.Read loop + // but I don't feel confident about doing that. + u.logger.Debug("uffd: no data in fd, going back to polling") + + continue + } + + buf := make([]byte, unsafe.Sizeof(UffdMsg{})) + + for { + _, err := syscall.Read(int(u.fd), buf) + if err == syscall.EINTR { + u.logger.Debug("uffd: interrupted read, reading again") + + continue + } + + if err == nil { + // There is no error so we can proceed. + + eagainCounter.Log() + + break + } + + if err == syscall.EAGAIN { + eagainCounter.Increase() + + // Continue polling the fd. + continue outerLoop + } + + u.logger.Error("uffd: read error", zap.Error(err)) + + return fmt.Errorf("failed to read: %w", err) + } + + msg := *(*UffdMsg)(unsafe.Pointer(&buf[0])) + + if msgEvent := getMsgEvent(&msg); msgEvent != UFFD_EVENT_PAGEFAULT { + u.logger.Error("UFFD serve unexpected event type", zap.Any("event_type", msgEvent)) + + return ErrUnexpectedEventType + } + + arg := getMsgArg(&msg) + pagefault := (*(*UffdPagefault)(unsafe.Pointer(&arg[0]))) + flags := pagefault.flags + + addr := getPagefaultAddress(&pagefault) + + offset, pagesize, err := u.ma.GetOffset(addr) + if err != nil { + u.logger.Error("UFFD serve get mapping error", zap.Error(err)) + + return fmt.Errorf("failed to map: %w", err) + } + + // Handle write to missing page (WRITE flag) + // If the event has WRITE flag, it was a write to a missing page. + // For the write to be executed, we first need to copy the page from the source to the guest memory. + if flags&UFFD_PAGEFAULT_FLAG_WRITE != 0 { + err := u.handleMissing(ctx, fdExit.SignalExit, addr, offset, pagesize) + if err != nil { + return fmt.Errorf("failed to handle missing write: %w", err) + } + + continue + } + + // Handle read to missing page ("MISSING" flag) + // If the event has no flags, it was a read to a missing page and we need to copy the page from the source to the guest memory. + if flags == 0 { + err := u.handleMissing(ctx, fdExit.SignalExit, addr, offset, pagesize) + if err != nil { + return fmt.Errorf("failed to handle missing: %w", err) + } + + continue + } + + // MINOR and WP flags are not expected as we don't register the uffd with these flags. + return fmt.Errorf("unexpected event type: %d, closing uffd", flags) + } +} + +func (u *Userfaultfd) handleMissing( + ctx context.Context, + onFailure func() error, + addr uintptr, + offset int64, + pagesize uint64, +) error { + if _, ok := u.missingRequests.Load(offset); ok { + return nil + } + + u.missingRequests.Store(offset, struct{}{}) + + u.wg.Go(func() error { + defer func() { + if r := recover(); r != nil { + u.logger.Error("UFFD serve panic", zap.Any("pagesize", pagesize), zap.Any("panic", r)) + } + }() + + b, sliceErr := u.src.Slice(ctx, offset, int64(pagesize)) + if sliceErr != nil { + signalErr := onFailure() + + joinedErr := errors.Join(sliceErr, signalErr) + + u.logger.Error("UFFD serve slice error", zap.Error(joinedErr)) + + return fmt.Errorf("failed to read from source: %w", joinedErr) + } + + var copyMode CULong + + copyErr := u.fd.copy(addr, b, pagesize, copyMode) + if errors.Is(copyErr, unix.EEXIST) { + // Page is already mapped + + return nil + } + + if copyErr != nil { + signalErr := onFailure() + + joinedErr := errors.Join(copyErr, signalErr) + + u.logger.Error("UFFD serve uffdio copy error", zap.Error(joinedErr)) + + return fmt.Errorf("failed uffdio copy %w", joinedErr) + } + + return nil + }) + + return nil +}