Skip to content

Commit 4d743fc

Browse files
committed
[ET Device Support] Define et_copy runtime h2d and d2h copy ops
Pull Request resolved: #18729 Implement C++ runtime kernels for device copy ops using DeviceAllocator: - h2d_copy_out: infers device from out tensor, calls DeviceAllocator::copy_host_to_device - d2h_copy_out: infers device from self tensor, calls DeviceAllocator::copy_device_to_host - Registered via EXECUTORCH_LIBRARY macro ghstack-source-id: 364093609 @exported-using-ghexport Differential Revision: [D99636776](https://our.internmc.facebook.com/intern/diff/D99636776/)
1 parent 01aca04 commit 4d743fc

4 files changed

Lines changed: 480 additions & 0 deletions

File tree

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
/**
10+
* Runtime kernels for et_copy._h2d_copy and et_copy._d2h_copy ops.
11+
*
12+
* These ops transfer tensor data between CPU and device memory using
13+
* the DeviceAllocator interface. The device type is inferred from the
14+
* tensor metadata (out.device_type() for H2D, self.device_type() for D2H),
15+
* which was set during AOT serialization by PropagateDevicePass.
16+
*/
17+
18+
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
19+
#include <executorch/runtime/core/device_allocator.h>
20+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
21+
#include <executorch/runtime/kernel/kernel_includes.h>
22+
23+
namespace executorch::runtime::native {
24+
25+
using executorch::aten::Tensor;
26+
using executorch::runtime::KernelRuntimeContext;
27+
28+
/**
29+
* Copies tensor data from host (CPU) memory to device memory.
30+
*
31+
* self: source tensor on CPU
32+
* out: destination tensor on device (memory-planned by runtime)
33+
*
34+
* The device type and index are inferred from out's TensorImpl metadata.
35+
*/
36+
Tensor&
37+
_h2d_copy_out(KernelRuntimeContext& ctx, const Tensor& self, Tensor& out) {
38+
auto device_type = out.unsafeGetTensorImpl()->device_type();
39+
auto device_index = out.unsafeGetTensorImpl()->device_index();
40+
41+
ET_KERNEL_CHECK_MSG(
42+
ctx,
43+
self.unsafeGetTensorImpl()->device_type() == etensor::DeviceType::CPU,
44+
InvalidArgument,
45+
out,
46+
"_h2d_copy: source tensor must be on CPU, got device_type=%d",
47+
static_cast<int>(self.unsafeGetTensorImpl()->device_type()));
48+
49+
ET_KERNEL_CHECK_MSG(
50+
ctx,
51+
device_type != etensor::DeviceType::CPU,
52+
InvalidArgument,
53+
out,
54+
"_h2d_copy: destination tensor must be on a non-CPU device");
55+
56+
auto nbytes = self.nbytes();
57+
ET_KERNEL_CHECK_MSG(
58+
ctx,
59+
nbytes == out.nbytes(),
60+
InvalidArgument,
61+
out,
62+
"_h2d_copy: size mismatch: self.nbytes()=%zu, out.nbytes()=%zu",
63+
nbytes,
64+
out.nbytes());
65+
66+
DeviceAllocator* allocator = get_device_allocator(device_type);
67+
ET_KERNEL_CHECK_MSG(
68+
ctx,
69+
allocator != nullptr,
70+
NotFound,
71+
out,
72+
"_h2d_copy: no device allocator registered for device_type=%d",
73+
static_cast<int>(device_type));
74+
75+
Error err = allocator->copy_host_to_device(
76+
out.mutable_data_ptr(), self.const_data_ptr(), nbytes, device_index);
77+
ET_KERNEL_CHECK_MSG(
78+
ctx,
79+
err == Error::Ok,
80+
Internal,
81+
out,
82+
"_h2d_copy: copy_host_to_device failed");
83+
84+
return out;
85+
}
86+
87+
/**
88+
* Copies tensor data from device memory to host (CPU) memory.
89+
*
90+
* self: source tensor on device
91+
* out: destination tensor on CPU (memory-planned by runtime)
92+
*
93+
* The device type and index are inferred from self's TensorImpl metadata.
94+
*/
95+
Tensor&
96+
_d2h_copy_out(KernelRuntimeContext& ctx, const Tensor& self, Tensor& out) {
97+
auto device_type = self.unsafeGetTensorImpl()->device_type();
98+
auto device_index = self.unsafeGetTensorImpl()->device_index();
99+
100+
ET_KERNEL_CHECK_MSG(
101+
ctx,
102+
device_type != etensor::DeviceType::CPU,
103+
InvalidArgument,
104+
out,
105+
"_d2h_copy: source tensor must be on a non-CPU device");
106+
107+
ET_KERNEL_CHECK_MSG(
108+
ctx,
109+
out.unsafeGetTensorImpl()->device_type() == etensor::DeviceType::CPU,
110+
InvalidArgument,
111+
out,
112+
"_d2h_copy: destination tensor must be on CPU, got device_type=%d",
113+
static_cast<int>(out.unsafeGetTensorImpl()->device_type()));
114+
115+
auto nbytes = self.nbytes();
116+
ET_KERNEL_CHECK_MSG(
117+
ctx,
118+
nbytes == out.nbytes(),
119+
InvalidArgument,
120+
out,
121+
"_d2h_copy: size mismatch: self.nbytes()=%zu, out.nbytes()=%zu",
122+
nbytes,
123+
out.nbytes());
124+
125+
DeviceAllocator* allocator = get_device_allocator(device_type);
126+
ET_KERNEL_CHECK_MSG(
127+
ctx,
128+
allocator != nullptr,
129+
NotFound,
130+
out,
131+
"_d2h_copy: no device allocator registered for device_type=%d",
132+
static_cast<int>(device_type));
133+
134+
Error err = allocator->copy_device_to_host(
135+
out.mutable_data_ptr(), self.const_data_ptr(), nbytes, device_index);
136+
ET_KERNEL_CHECK_MSG(
137+
ctx,
138+
err == Error::Ok,
139+
Internal,
140+
out,
141+
"_d2h_copy: copy_device_to_host failed");
142+
143+
return out;
144+
}
145+
146+
} // namespace executorch::runtime::native
147+
148+
// Register both kernels under the `et_copy` namespace. Each registration is
// kept as its own statement on its own source lines.
EXECUTORCH_LIBRARY(
    et_copy, "_h2d_copy.out", executorch::runtime::native::_h2d_copy_out);

EXECUTORCH_LIBRARY(
    et_copy, "_d2h_copy.out", executorch::runtime::native::_d2h_copy_out);

kernels/portable/cpu/targets.bzl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,24 @@ def define_common_targets():
7575
],
7676
)
7777

78+
# Kernels for the device copy ops (h2d_copy, d2h_copy), which move tensor
# data between CPU and device memory through the DeviceAllocator interface.
runtime.cxx_library(
    name = "op__device_copy",
    srcs = ["op__device_copy.cpp"],
    deps = [
        "//executorch/runtime/core:device_allocator",
        "//executorch/runtime/core/exec_aten:lib",
        "//executorch/runtime/kernel:kernel_includes",
        "//executorch/extension/kernel_util:kernel_util",
    ],
    # Op registration needs a constructor, hence the warning suppression.
    compiler_flags = ["-Wno-global-constructors"],
    # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
    link_whole = True,
    visibility = ["PUBLIC"],
)
95+
7896
# Used for dtype selective build. Collect source and header files.
7997
runtime.filegroup(
8098
name = "portable_source_files",

0 commit comments

Comments
 (0)