Skip to content

Commit 664abf8

Browse files
authored
[ET Device Support] DeviceAllocator interface and DeviceAllocatorRegistry (#19496)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* #18474
* #18473
* #18472
* #18375
* #19497
* __->__ #19496

This diff introduces the `DeviceAllocator` abstract interface and `DeviceAllocatorRegistry` for device-specific memory allocation. This is a foundational abstraction that enables the runtime to dispatch memory operations to the appropriate non-CPU device backend (CUDA, etc.).

**DeviceAllocator interface provides:**
- `allocate()` / `deallocate()` - Dynamic device memory allocation
- `copy_host_to_device()` / `copy_device_to_host()` - Data transfer between host and device
- `device_type()` - Returns the device type this allocator handles

**DeviceAllocatorRegistry provides:**
- Singleton registry mapping DeviceType → DeviceAllocator
- `register_allocator()` / `get_allocator()` methods
- Fixed-size array indexed by device type (no dynamic allocation, embedded-friendly)

**Design notes:**
- Registry stores raw pointers (non-owning); allocators are expected to be singletons with static lifetime
- Follows ExecuTorch's embedded-first philosophy (no std::unique_ptr, no heap allocation in registry)
- Convenience free functions `register_device_allocator()` and `get_device_allocator()` for ease of use

Differential Revision: [D93635656](https://our.internmc.facebook.com/intern/diff/D93635656/)
1 parent d306410 commit 664abf8

6 files changed

Lines changed: 521 additions & 1 deletion

File tree

runtime/core/device_allocator.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/runtime/core/device_allocator.h>
10+
11+
#include <executorch/runtime/platform/assert.h>
12+
13+
namespace executorch {
14+
namespace runtime {
15+
16+
// Returns the single process-wide registry. The object is a function-local
// static, so it is constructed the first time this function is called and the
// same instance is handed back on every subsequent call.
DeviceAllocatorRegistry& DeviceAllocatorRegistry::instance() {
  static DeviceAllocatorRegistry singleton;
  return singleton;
}
20+
21+
// Stores `alloc` in the slot that corresponds to alloc->device_type().
//
// Every precondition is enforced with ET_CHECK_MSG: `alloc` must be non-null,
// its device type must map to an index below etensor::kNumDeviceTypes, and the
// slot for that device type must still be empty.
void DeviceAllocatorRegistry::register_allocator(DeviceAllocator* alloc) {
  ET_CHECK_MSG(alloc != nullptr, "Cannot register a null allocator");

  // The allocator itself reports which device type it serves; that value is
  // used directly as the array index.
  const etensor::DeviceType device = alloc->device_type();
  const size_t index = static_cast<size_t>(device);

  ET_CHECK_MSG(
      index < etensor::kNumDeviceTypes,
      "Invalid device type: %d",
      static_cast<int>(device));

  // A slot may be filled only once; a second registration for the same device
  // type is rejected rather than silently replacing the first.
  ET_CHECK_MSG(
      allocators_[index] == nullptr,
      "Allocator already registered for device type: %d",
      static_cast<int>(device));

  allocators_[index] = alloc;
}
35+
36+
// Looks up the allocator stored for `type`.
//
// Returns nullptr when `type` maps to an index outside the table, or when the
// slot for `type` has never been filled by register_allocator().
DeviceAllocator* DeviceAllocatorRegistry::get_allocator(
    etensor::DeviceType type) {
  const size_t slot = static_cast<size_t>(type);
  return slot < etensor::kNumDeviceTypes ? allocators_[slot] : nullptr;
}
44+
45+
// Convenience free functions
46+
47+
// Shorthand for registering `alloc` with the singleton registry.
void register_device_allocator(DeviceAllocator* alloc) {
  DeviceAllocatorRegistry& registry = DeviceAllocatorRegistry::instance();
  registry.register_allocator(alloc);
}
50+
51+
// Shorthand for querying the singleton registry for the allocator of `type`.
DeviceAllocator* get_device_allocator(etensor::DeviceType type) {
  DeviceAllocatorRegistry& registry = DeviceAllocatorRegistry::instance();
  return registry.get_allocator(type);
}
54+
55+
} // namespace runtime
56+
} // namespace executorch

runtime/core/device_allocator.h

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <cstddef>
12+
13+
#include <executorch/runtime/core/error.h>
14+
#include <executorch/runtime/core/memory_allocator.h>
15+
#include <executorch/runtime/core/portable_type/device.h>
16+
#include <executorch/runtime/core/result.h>
17+
18+
namespace executorch {
19+
namespace runtime {
20+
21+
/**
22+
* Abstract interface for device-specific memory allocation.
23+
*
24+
* Each device type (CUDA, etc.) provides a concrete implementation
25+
* that handles memory allocation on that device. Implementations are
26+
* expected to be singletons with static lifetime, registered via
27+
* DeviceAllocatorRegistry.
28+
*/
29+
class DeviceAllocator {
30+
public:
31+
/**
32+
* Default alignment of memory returned by allocate(). Reuses
33+
* MemoryAllocator::kDefaultAlignment so host- and device-side allocations
34+
* share the same baseline contract. Backends whose underlying device APIs
35+
* already provide stronger guarantees (e.g. cudaMalloc returns 256-byte
36+
* aligned pointers) will trivially satisfy this.
37+
*/
38+
static constexpr size_t kDefaultAlignment =
39+
MemoryAllocator::kDefaultAlignment;
40+
41+
virtual ~DeviceAllocator() = default;
42+
/**
43+
* Allocate device memory.
44+
*
45+
* @param nbytes Number of bytes to allocate.
46+
* @param index The device index.
47+
* @param alignment Minimum alignment of the returned pointer in bytes.
48+
* Must be a power of 2. Defaults to kDefaultAlignment.
49+
* @return A Result containing the device pointer on success, or an error.
50+
*/
51+
virtual Result<void*> allocate(
52+
size_t nbytes,
53+
etensor::DeviceIndex index,
54+
size_t alignment = kDefaultAlignment) = 0;
55+
56+
/**
57+
* Deallocate device memory previously allocated via allocate().
58+
*
59+
* @param ptr Pointer to the memory to deallocate.
60+
* @param index The device index.
61+
*/
62+
virtual void deallocate(void* ptr, etensor::DeviceIndex index) = 0;
63+
64+
/**
65+
* Copy data from host memory to device memory.
66+
*
67+
* @param dst Destination pointer (device memory).
68+
* @param src Source pointer (host memory).
69+
* @param nbytes Number of bytes to copy.
70+
* @param index The device index.
71+
* @return Error::Ok on success, or an appropriate error code on failure.
72+
*/
73+
virtual Error copy_host_to_device(
74+
void* dst,
75+
const void* src,
76+
size_t nbytes,
77+
etensor::DeviceIndex index) = 0;
78+
79+
/**
80+
* Copy data from device memory to host memory.
81+
*
82+
* @param dst Destination pointer (host memory).
83+
* @param src Source pointer (device memory).
84+
* @param nbytes Number of bytes to copy.
85+
* @param index The device index.
86+
* @return Error::Ok on success, or an appropriate error code on failure.
87+
*/
88+
virtual Error copy_device_to_host(
89+
void* dst,
90+
const void* src,
91+
size_t nbytes,
92+
etensor::DeviceIndex index) = 0;
93+
94+
/**
95+
* Returns the device type this allocator handles.
96+
*/
97+
virtual etensor::DeviceType device_type() const = 0;
98+
};
99+
100+
/**
101+
* Registry for device allocators.
102+
*
103+
* Provides a global mapping from DeviceType to DeviceAllocator instances.
104+
* Device allocators register themselves at static initialization time,
105+
* and the runtime queries the registry to find the appropriate allocator
106+
* for a given device type.
107+
*
108+
* Threading contract:
109+
* - Registration is expected to happen once per device type during static
110+
* initialization (single-threaded). The registry itself does not perform
111+
* any locking around register_allocator()/get_allocator(), and concurrent
112+
* registration is not supported.
113+
* - After registration, get_allocator() is safe to call concurrently from
114+
* multiple threads because the underlying array is never mutated again.
115+
* - The DeviceAllocator implementation is responsible for its own
116+
* thread-safety. When multiple Programs are loaded concurrently and each
117+
* needs device memory, the allocator must serialize access to any shared
118+
* state internally (similar to how XNNPACK's weight cache guards its
119+
* internal state). The registry does not provide any synchronization on
120+
* behalf of the allocator.
121+
*/
122+
class DeviceAllocatorRegistry {
123+
public:
124+
/**
125+
* Returns the singleton instance of the registry.
126+
*/
127+
static DeviceAllocatorRegistry& instance();
128+
129+
/**
130+
* Register an allocator. The device type is taken from
131+
* alloc->device_type(). Each device type may only be registered once;
132+
* attempting to register a second allocator for the same device type
133+
* will abort.
134+
*
135+
* Not thread-safe. Expected to be called during static initialization.
136+
*
137+
* @param alloc Pointer to the allocator (must have static lifetime).
138+
*/
139+
void register_allocator(DeviceAllocator* alloc);
140+
141+
/**
142+
* Get the allocator for a specific device type.
143+
*
144+
* Safe to call concurrently with other get_allocator() calls.
145+
*
146+
* @param type The device type.
147+
* @return Pointer to the allocator, or nullptr if not registered.
148+
*/
149+
DeviceAllocator* get_allocator(etensor::DeviceType type);
150+
151+
private:
152+
DeviceAllocatorRegistry() = default;
153+
154+
// Singletons must not be copied or moved; instance() returns a reference,
155+
// and silently shallow-copying the registry would lead to confusing bugs
156+
// where modifications to the copy don't affect the real singleton.
157+
DeviceAllocatorRegistry(const DeviceAllocatorRegistry&) = delete;
158+
DeviceAllocatorRegistry& operator=(const DeviceAllocatorRegistry&) = delete;
159+
DeviceAllocatorRegistry(DeviceAllocatorRegistry&&) = delete;
160+
DeviceAllocatorRegistry& operator=(DeviceAllocatorRegistry&&) = delete;
161+
162+
// Fixed-size array indexed by device type. This avoids dynamic allocation
163+
// and is suitable for embedded environments.
164+
DeviceAllocator* allocators_[etensor::kNumDeviceTypes] = {};
165+
};
166+
167+
// Convenience free functions
168+
169+
/**
170+
* Register a device allocator. The device type is taken from
171+
* alloc->device_type(). See DeviceAllocatorRegistry::register_allocator()
172+
* for the threading contract.
173+
*
174+
* @param alloc Pointer to the allocator (must have static lifetime).
175+
*/
176+
void register_device_allocator(DeviceAllocator* alloc);
177+
178+
/**
179+
* Get the device allocator for a specific device type.
180+
*
181+
* @param type The device type.
182+
* @return Pointer to the allocator, or nullptr if not registered.
183+
*/
184+
DeviceAllocator* get_device_allocator(etensor::DeviceType type);
185+
186+
} // namespace runtime
187+
} // namespace executorch
188+
189+
namespace torch {
190+
namespace executor {
191+
// TODO(T197294990): Remove these deprecated aliases once all users have moved
192+
// to the new `::executorch` namespaces.
193+
using ::executorch::runtime::DeviceAllocator;
194+
using ::executorch::runtime::DeviceAllocatorRegistry;
195+
using ::executorch::runtime::get_device_allocator;
196+
using ::executorch::runtime::register_device_allocator;
197+
} // namespace executor
198+
} // namespace torch

runtime/core/portable_type/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def define_common_targets():
2727
"//executorch/backends/...",
2828
"//executorch/extension/fb/dynamic_shim/...",
2929
"//executorch/kernels/portable/cpu/...",
30+
"//executorch/runtime/core:device_allocator",
3031
"//executorch/runtime/core/exec_aten/...",
3132
"//executorch/runtime/core/portable_type/test/...",
3233
],

runtime/core/targets.bzl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,20 @@ def define_common_targets():
8282
visibility = ["PUBLIC"],
8383
)
8484

85+
runtime.cxx_library(
86+
name = "device_allocator",
87+
srcs = ["device_allocator.cpp"],
88+
exported_headers = [
89+
"device_allocator.h",
90+
],
91+
exported_deps = [
92+
":core",
93+
":memory_allocator",
94+
"//executorch/runtime/core/exec_aten:lib",
95+
],
96+
visibility = ["PUBLIC"],
97+
)
98+
8599
for aten_mode in get_aten_mode_options():
86100
aten_suffix = ("_aten" if aten_mode else "")
87101
runtime.cxx_library(

0 commit comments

Comments
 (0)