Skip to content

Commit 47a8b76

Browse files
committed
[ET Device Support] DeviceMemoryBuffer RAII class for device memory lifetime management
Introduces DeviceMemoryBuffer, an RAII wrapper that owns a single device memory allocation. On destruction, it automatically calls DeviceAllocator::deallocate() to free the memory. This mirrors the role of std::vector<uint8_t> for CPU planned buffers, but for non-cpu device memory (CUDA, etc.). Key features: - Static factory create(size, type, index) looks up DeviceAllocator from registry - Move-only semantics (no copy) to enforce single ownership - as_span() accessor wraps device pointer for use with HierarchicalAllocator - Destructor is no-op for default-constructed or moved-from instances Differential Revision: [D97850709](https://our.internmc.facebook.com/intern/diff/D97850709/) ghstack-source-id: 357060894 Pull Request resolved: #18473
1 parent a807e51 commit 47a8b76

6 files changed

Lines changed: 365 additions & 0 deletions

File tree

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/runtime/core/device_memory_buffer.h>
10+
11+
namespace executorch::runtime {
12+
13+
Result<DeviceMemoryBuffer> DeviceMemoryBuffer::create(
14+
size_t size,
15+
etensor::DeviceType type,
16+
etensor::DeviceIndex index) {
17+
DeviceAllocator* allocator = get_device_allocator(type);
18+
if (allocator == nullptr) {
19+
ET_LOG(
20+
Error,
21+
"No device allocator registered for device type %d",
22+
static_cast<int>(type));
23+
return Error::NotFound;
24+
}
25+
26+
auto result = allocator->allocate(size, index);
27+
if (!result.ok()) {
28+
return result.error();
29+
}
30+
31+
return DeviceMemoryBuffer(result.get(), size, allocator, index);
32+
}
33+
34+
} // namespace executorch::runtime
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <cstddef>
12+
#include <cstdint>
13+
14+
#include <executorch/runtime/core/device_allocator.h>
15+
#include <executorch/runtime/core/result.h>
16+
#include <executorch/runtime/core/span.h>
17+
18+
namespace executorch::runtime {
19+
20+
/**
21+
* RAII wrapper that owns a single device memory allocation.
22+
*
23+
* On destruction, calls DeviceAllocator::deallocate() to free the memory.
24+
* This mirrors the role of std::vector<uint8_t> for CPU planned buffers,
25+
* but for device memory (CUDA, etc.).
26+
*
27+
* Move-only: cannot be copied, but can be moved to transfer ownership.
28+
*/
29+
class DeviceMemoryBuffer final {
30+
public:
31+
/**
32+
* Creates a DeviceMemoryBuffer by allocating device memory.
33+
*
34+
* Looks up the DeviceAllocator for the given device type via the
35+
* DeviceAllocatorRegistry. If no allocator is registered for the type,
36+
* returns Error::NotFound.
37+
*
38+
* @param size Number of bytes to allocate.
39+
* @param type The device type (e.g., CUDA).
40+
* @param index The device index (e.g., 0 for cuda:0).
41+
* @return A Result containing the DeviceMemoryBuffer on success, or an error.
42+
*/
43+
static Result<DeviceMemoryBuffer> create(
44+
size_t size,
45+
etensor::DeviceType type,
46+
etensor::DeviceIndex index = 0);
47+
48+
DeviceMemoryBuffer() = default;
49+
50+
~DeviceMemoryBuffer() {
51+
if (ptr_ != nullptr && allocator_ != nullptr) {
52+
allocator_->deallocate(ptr_, device_index_);
53+
}
54+
}
55+
56+
// Move constructor: transfer ownership.
57+
DeviceMemoryBuffer(DeviceMemoryBuffer&& other) noexcept
58+
: ptr_(other.ptr_),
59+
size_(other.size_),
60+
allocator_(other.allocator_),
61+
device_index_(other.device_index_) {
62+
other.ptr_ = nullptr;
63+
other.size_ = 0;
64+
other.allocator_ = nullptr;
65+
}
66+
67+
// Move assignment: release current, take ownership.
68+
DeviceMemoryBuffer& operator=(DeviceMemoryBuffer&& other) noexcept {
69+
if (this != &other) {
70+
if (ptr_ != nullptr && allocator_ != nullptr) {
71+
allocator_->deallocate(ptr_, device_index_);
72+
}
73+
ptr_ = other.ptr_;
74+
size_ = other.size_;
75+
allocator_ = other.allocator_;
76+
device_index_ = other.device_index_;
77+
other.ptr_ = nullptr;
78+
other.size_ = 0;
79+
other.allocator_ = nullptr;
80+
}
81+
return *this;
82+
}
83+
84+
// Non-copyable.
85+
DeviceMemoryBuffer(const DeviceMemoryBuffer&) = delete;
86+
DeviceMemoryBuffer& operator=(const DeviceMemoryBuffer&) = delete;
87+
88+
/// Returns the device pointer, or nullptr if empty/moved-from.
89+
void* data() const {
90+
return ptr_;
91+
}
92+
93+
/// Returns the size in bytes of the allocation.
94+
size_t size() const {
95+
return size_;
96+
}
97+
98+
/**
99+
* Returns a Span<uint8_t> wrapping the device pointer.
100+
*
101+
* This is intended for use with HierarchicalAllocator, which only performs
102+
* pointer arithmetic on the span data and never dereferences it. Device
103+
* pointers are valid for pointer arithmetic from the CPU side.
104+
*/
105+
Span<uint8_t> as_span() const {
106+
return {static_cast<uint8_t*>(ptr_), size_};
107+
}
108+
109+
private:
110+
DeviceMemoryBuffer(
111+
void* ptr,
112+
size_t size,
113+
DeviceAllocator* allocator,
114+
etensor::DeviceIndex device_index)
115+
: ptr_(ptr),
116+
size_(size),
117+
allocator_(allocator),
118+
device_index_(device_index) {}
119+
120+
void* ptr_ = nullptr;
121+
size_t size_ = 0;
122+
DeviceAllocator* allocator_ = nullptr;
123+
etensor::DeviceIndex device_index_ = 0;
124+
};
125+
126+
} // namespace executorch::runtime

runtime/core/portable_type/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def define_common_targets():
2727
"//executorch/backends/...",
2828
"//executorch/extension/fb/dynamic_shim/...",
2929
"//executorch/kernels/portable/cpu/...",
30+
"//executorch/runtime/core/...",
3031
"//executorch/runtime/core/exec_aten/...",
3132
"//executorch/runtime/core/portable_type/test/...",
3233
],

runtime/core/targets.bzl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,33 @@ def define_common_targets():
141141
visibility = ["//executorch/..."],
142142
)
143143

144+
runtime.cxx_library(
145+
name = "device_allocator",
146+
srcs = ["device_allocator.cpp"],
147+
exported_headers = [
148+
"device_allocator.h",
149+
],
150+
exported_deps = [
151+
":core",
152+
"//executorch/runtime/core/portable_type:portable_type",
153+
],
154+
deps = [
155+
"//executorch/runtime/platform:platform",
156+
],
157+
visibility = ["PUBLIC"],
158+
)
159+
160+
runtime.cxx_library(
161+
name = "device_memory_buffer",
162+
srcs = ["device_memory_buffer.cpp"],
163+
exported_headers = ["device_memory_buffer.h"],
164+
exported_deps = [
165+
":core",
166+
":device_allocator",
167+
],
168+
visibility = ["PUBLIC"],
169+
)
170+
144171
runtime.cxx_library(
145172
name = "tag",
146173
srcs = ["tag.cpp"],
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/runtime/core/device_memory_buffer.h>
10+
11+
#include <gtest/gtest.h>
12+
13+
#include <executorch/runtime/platform/runtime.h>
14+
15+
using executorch::runtime::DeviceAllocator;
16+
using executorch::runtime::DeviceMemoryBuffer;
17+
using executorch::runtime::Error;
18+
using executorch::runtime::Result;
19+
using executorch::runtime::get_device_allocator;
20+
using executorch::runtime::register_device_allocator;
21+
using executorch::runtime::etensor::DeviceIndex;
22+
using executorch::runtime::etensor::DeviceType;
23+
24+
/**
25+
* A mock DeviceAllocator for testing DeviceMemoryBuffer.
26+
* Returns pointers into a local buffer and tracks call counts.
27+
*/
28+
class MockAllocator : public DeviceAllocator {
29+
public:
30+
explicit MockAllocator(DeviceType type) : type_(type) {}
31+
32+
Result<void*> allocate(size_t nbytes, DeviceIndex index) override {
33+
allocate_count_++;
34+
last_allocate_size_ = nbytes;
35+
return static_cast<void*>(buffer_);
36+
}
37+
38+
void deallocate(void* ptr, DeviceIndex index) override {
39+
deallocate_count_++;
40+
last_deallocate_ptr_ = ptr;
41+
}
42+
43+
Error copy_host_to_device(
44+
void* dst,
45+
const void* src,
46+
size_t nbytes,
47+
DeviceIndex index) override {
48+
return Error::Ok;
49+
}
50+
51+
Error copy_device_to_host(
52+
void* dst,
53+
const void* src,
54+
size_t nbytes,
55+
DeviceIndex index) override {
56+
return Error::Ok;
57+
}
58+
59+
DeviceType device_type() const override {
60+
return type_;
61+
}
62+
63+
int allocate_count_ = 0;
64+
int deallocate_count_ = 0;
65+
size_t last_allocate_size_ = 0;
66+
void* last_deallocate_ptr_ = nullptr;
67+
uint8_t buffer_[256] = {};
68+
69+
private:
70+
DeviceType type_;
71+
};
72+
73+
// Global mock registered once before all tests run.
74+
static MockAllocator g_mock_cuda(DeviceType::CUDA);
75+
76+
class DeviceMemoryBufferTest : public ::testing::Test {
77+
protected:
78+
static void SetUpTestSuite() {
79+
executorch::runtime::runtime_init();
80+
register_device_allocator(DeviceType::CUDA, &g_mock_cuda);
81+
}
82+
83+
void SetUp() override {
84+
// Reset counters before each test.
85+
g_mock_cuda.allocate_count_ = 0;
86+
g_mock_cuda.deallocate_count_ = 0;
87+
g_mock_cuda.last_allocate_size_ = 0;
88+
g_mock_cuda.last_deallocate_ptr_ = nullptr;
89+
}
90+
};
91+
92+
TEST_F(DeviceMemoryBufferTest, DefaultConstructedIsEmpty) {
93+
DeviceMemoryBuffer buf;
94+
EXPECT_EQ(buf.data(), nullptr);
95+
EXPECT_EQ(buf.size(), 0);
96+
97+
auto span = buf.as_span();
98+
EXPECT_EQ(span.data(), nullptr);
99+
EXPECT_EQ(span.size(), 0);
100+
}
101+
102+
TEST_F(DeviceMemoryBufferTest, CreateAllocatesAndDestructorDeallocates) {
103+
{
104+
auto result = DeviceMemoryBuffer::create(1024, DeviceType::CUDA, 0);
105+
ASSERT_TRUE(result.ok());
106+
107+
auto buf = std::move(result.get());
108+
EXPECT_NE(buf.data(), nullptr);
109+
EXPECT_EQ(buf.size(), 1024);
110+
EXPECT_EQ(g_mock_cuda.allocate_count_, 1);
111+
EXPECT_EQ(g_mock_cuda.last_allocate_size_, 1024);
112+
EXPECT_EQ(g_mock_cuda.deallocate_count_, 0);
113+
}
114+
EXPECT_EQ(g_mock_cuda.deallocate_count_, 1);
115+
EXPECT_EQ(g_mock_cuda.last_deallocate_ptr_, g_mock_cuda.buffer_);
116+
}
117+
118+
TEST_F(DeviceMemoryBufferTest, CreateFailsWithNoRegisteredAllocator) {
119+
auto result = DeviceMemoryBuffer::create(512, DeviceType::CPU, 0);
120+
EXPECT_FALSE(result.ok());
121+
EXPECT_EQ(result.error(), Error::NotFound);
122+
}
123+
124+
TEST_F(DeviceMemoryBufferTest, MoveConstructorTransfersOwnership) {
125+
auto result = DeviceMemoryBuffer::create(256, DeviceType::CUDA, 0);
126+
ASSERT_TRUE(result.ok());
127+
auto original = std::move(result.get());
128+
void* original_ptr = original.data();
129+
130+
DeviceMemoryBuffer moved(std::move(original));
131+
132+
EXPECT_EQ(original.data(), nullptr);
133+
EXPECT_EQ(original.size(), 0);
134+
EXPECT_EQ(moved.data(), original_ptr);
135+
EXPECT_EQ(moved.size(), 256);
136+
EXPECT_EQ(g_mock_cuda.deallocate_count_, 0);
137+
}
138+
139+
TEST_F(DeviceMemoryBufferTest, MoveAssignmentTransfersOwnership) {
140+
auto result = DeviceMemoryBuffer::create(128, DeviceType::CUDA, 0);
141+
ASSERT_TRUE(result.ok());
142+
auto original = std::move(result.get());
143+
void* original_ptr = original.data();
144+
145+
DeviceMemoryBuffer target;
146+
target = std::move(original);
147+
148+
EXPECT_EQ(original.data(), nullptr);
149+
EXPECT_EQ(target.data(), original_ptr);
150+
EXPECT_EQ(target.size(), 128);
151+
EXPECT_EQ(g_mock_cuda.deallocate_count_, 0);
152+
}
153+
154+
TEST_F(DeviceMemoryBufferTest, DestructorNoOpForDefaultConstructed) {
155+
{
156+
DeviceMemoryBuffer buf;
157+
}
158+
EXPECT_EQ(g_mock_cuda.deallocate_count_, 0);
159+
}
160+
161+
TEST_F(DeviceMemoryBufferTest, AsSpanWrapsDevicePointer) {
162+
auto result = DeviceMemoryBuffer::create(2048, DeviceType::CUDA, 0);
163+
ASSERT_TRUE(result.ok());
164+
auto buf = std::move(result.get());
165+
166+
auto span = buf.as_span();
167+
EXPECT_EQ(span.data(), static_cast<uint8_t*>(buf.data()));
168+
EXPECT_EQ(span.size(), 2048);
169+
}

runtime/core/test/targets.bzl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@ def define_common_targets():
77
TARGETS and BUCK files that call this function.
88
"""
99

10+
runtime.cxx_test(
11+
name = "device_memory_buffer_test",
12+
srcs = ["device_memory_buffer_test.cpp"],
13+
deps = [
14+
"//executorch/runtime/core:device_memory_buffer",
15+
],
16+
)
17+
1018
runtime.cxx_test(
1119
name = "span_test",
1220
srcs = ["span_test.cpp"],

0 commit comments

Comments
 (0)