@@ -41,6 +41,8 @@ using TensorPtr = std::shared_ptr<executorch::aten::Tensor>;
4141 * @param deleter A custom deleter function for managing the lifetime of the
4242 * data buffer. If provided, this deleter will be called when the managed Tensor
4343 * object is destroyed.
44+ * @param device_type The target device type (default CPU, meaning no copy).
45+ * @param device_index The target device index (default 0).
4446 * @return A TensorPtr that manages the newly created Tensor.
4547 */
4648TensorPtr make_tensor_ptr (
@@ -52,7 +54,10 @@ TensorPtr make_tensor_ptr(
5254 executorch::aten::ScalarType::Float,
5355 const executorch::aten::TensorShapeDynamism dynamism =
5456 executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
55- std::function<void (void *)> deleter = nullptr);
57+ std::function<void (void *)> deleter = nullptr,
58+ runtime::etensor::DeviceType device_type =
59+ runtime::etensor::DeviceType::CPU,
60+ runtime::etensor::DeviceIndex device_index = 0);
5661
5762/* *
5863 * Creates a TensorPtr that manages a Tensor with the specified properties.
@@ -64,6 +69,8 @@ TensorPtr make_tensor_ptr(
6469 * @param deleter A custom deleter function for managing the lifetime of the
6570 * data buffer. If provided, this deleter will be called when the managed Tensor
6671 * object is destroyed.
72+ * @param device_type The target device type (default CPU, meaning no copy).
73+ * @param device_index The target device index (default 0).
6774 * @return A TensorPtr that manages the newly created Tensor.
6875 */
6976inline TensorPtr make_tensor_ptr (
@@ -73,9 +80,20 @@ inline TensorPtr make_tensor_ptr(
7380 executorch::aten::ScalarType::Float,
7481 const executorch::aten::TensorShapeDynamism dynamism =
7582 executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
76- std::function<void (void *)> deleter = nullptr) {
83+ std::function<void (void *)> deleter = nullptr,
84+ runtime::etensor::DeviceType device_type =
85+ runtime::etensor::DeviceType::CPU,
86+ runtime::etensor::DeviceIndex device_index = 0) {
7787 return make_tensor_ptr (
78- std::move (sizes), data, {}, {}, type, dynamism, std::move (deleter));
88+ std::move (sizes),
89+ data,
90+ {},
91+ {},
92+ type,
93+ dynamism,
94+ std::move (deleter),
95+ device_type,
96+ device_index);
7997}
8098
8199/* *
@@ -96,6 +114,8 @@ inline TensorPtr make_tensor_ptr(
96114 * @param type The scalar type of the tensor elements. If it differs from the
97115 * deduced type, the data will be cast to this type if allowed.
98116 * @param dynamism Specifies the mutability of the tensor's shape.
117+ * @param device_type The target device type (default CPU, meaning no copy).
118+ * @param device_index The target device index (default 0).
99119 * @return A TensorPtr that manages the newly created TensorImpl.
100120 */
101121template <
@@ -109,7 +129,10 @@ inline TensorPtr make_tensor_ptr(
109129 std::vector<executorch::aten::StridesType> strides = {},
110130 executorch::aten::ScalarType type = deduced_type,
111131 executorch::aten::TensorShapeDynamism dynamism =
112- executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
132+ executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
133+ runtime::etensor::DeviceType device_type =
134+ runtime::etensor::DeviceType::CPU,
135+ runtime::etensor::DeviceIndex device_index = 0 ) {
113136 ET_CHECK_MSG (
114137 data.size () ==
115138 executorch::aten::compute_numel (sizes.data (), sizes.size ()),
@@ -145,7 +168,9 @@ inline TensorPtr make_tensor_ptr(
145168 std::move (strides),
146169 type,
147170 dynamism,
148- [data_ptr = std::move (data_ptr)](void *) {});
171+ [data_ptr = std::move (data_ptr)](void *) {},
172+ device_type,
173+ device_index);
149174 }
150175 const auto raw_data_ptr = data.data ();
151176 auto data_ptr = std::make_shared<std::vector<T>>(std::move (data));
@@ -156,7 +181,9 @@ inline TensorPtr make_tensor_ptr(
156181 std::move (strides),
157182 type,
158183 dynamism,
159- [data_ptr = std::move (data_ptr)](void *) {});
184+ [data_ptr = std::move (data_ptr)](void *) {},
185+ device_type,
186+ device_index);
160187}
161188
162189/* *
@@ -174,6 +201,8 @@ inline TensorPtr make_tensor_ptr(
174201 * @param type The scalar type of the tensor elements. If it differs from the
175202 * deduced type, the data will be cast to this type if allowed.
176203 * @param dynamism Specifies the mutability of the tensor's shape.
204+ * @param device_type The target device type (default CPU, meaning no copy).
205+ * @param device_index The target device index (default 0).
177206 * @return A TensorPtr that manages the newly created TensorImpl.
178207 */
179208template <
@@ -184,11 +213,21 @@ inline TensorPtr make_tensor_ptr(
184213 std::vector<T> data,
185214 executorch::aten::ScalarType type = deduced_type,
186215 executorch::aten::TensorShapeDynamism dynamism =
187- executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
216+ executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
217+ runtime::etensor::DeviceType device_type =
218+ runtime::etensor::DeviceType::CPU,
219+ runtime::etensor::DeviceIndex device_index = 0 ) {
188220 std::vector<executorch::aten::SizesType> sizes{
189221 executorch::aten::SizesType (data.size ())};
190222 return make_tensor_ptr (
191- std::move (sizes), std::move (data), {0 }, {1 }, type, dynamism);
223+ std::move (sizes),
224+ std::move (data),
225+ {0 },
226+ {1 },
227+ type,
228+ dynamism,
229+ device_type,
230+ device_index);
192231}
193232
194233/* *
@@ -211,6 +250,8 @@ inline TensorPtr make_tensor_ptr(
211250 * @param type The scalar type of the tensor elements. If it differs from the
212251 * deduced type, the data will be cast to this type if allowed.
213252 * @param dynamism Specifies the mutability of the tensor's shape.
253+ * @param device_type The target device type (default CPU, meaning no copy).
254+ * @param device_index The target device index (default 0).
214255 * @return A TensorPtr that manages the newly created TensorImpl.
215256 */
216257template <
@@ -224,14 +265,19 @@ inline TensorPtr make_tensor_ptr(
224265 std::vector<executorch::aten::StridesType> strides = {},
225266 executorch::aten::ScalarType type = deduced_type,
226267 executorch::aten::TensorShapeDynamism dynamism =
227- executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
268+ executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
269+ runtime::etensor::DeviceType device_type =
270+ runtime::etensor::DeviceType::CPU,
271+ runtime::etensor::DeviceIndex device_index = 0 ) {
228272 return make_tensor_ptr (
229273 std::move (sizes),
230274 std::vector<T>(std::move (list)),
231275 std::move (dim_order),
232276 std::move (strides),
233277 type,
234- dynamism);
278+ dynamism,
279+ device_type,
280+ device_index);
235281}
236282
237283/* *
@@ -251,6 +297,8 @@ inline TensorPtr make_tensor_ptr(
251297 * @param type The scalar type of the tensor elements. If it differs from the
252298 * deduced type, the data will be cast to this type if allowed.
253299 * @param dynamism Specifies the mutability of the tensor's shape.
300+ * @param device_type The target device type (default CPU, meaning no copy).
301+ * @param device_index The target device index (default 0).
254302 * @return A TensorPtr that manages the newly created TensorImpl.
255303 */
256304template <
@@ -261,11 +309,21 @@ inline TensorPtr make_tensor_ptr(
261309 std::initializer_list<T> list,
262310 executorch::aten::ScalarType type = deduced_type,
263311 executorch::aten::TensorShapeDynamism dynamism =
264- executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
312+ executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
313+ runtime::etensor::DeviceType device_type =
314+ runtime::etensor::DeviceType::CPU,
315+ runtime::etensor::DeviceIndex device_index = 0 ) {
265316 std::vector<executorch::aten::SizesType> sizes{
266317 executorch::aten::SizesType (list.size ())};
267318 return make_tensor_ptr (
268- std::move (sizes), std::move (list), {0 }, {1 }, type, dynamism);
319+ std::move (sizes),
320+ std::move (list),
321+ {0 },
322+ {1 },
323+ type,
324+ dynamism,
325+ device_type,
326+ device_index);
269327}
270328
271329/* *
@@ -294,6 +352,8 @@ inline TensorPtr make_tensor_ptr(T value) {
294352 * @param strides A vector specifying the strides of each dimension.
295353 * @param type The scalar type of the tensor elements.
296354 * @param dynamism Specifies the mutability of the tensor's shape.
355+ * @param device_type The target device type (default CPU, meaning no copy).
356+ * @param device_index The target device index (default 0).
297357 * @return A TensorPtr managing the newly created Tensor.
298358 */
299359TensorPtr make_tensor_ptr (
@@ -303,7 +363,10 @@ TensorPtr make_tensor_ptr(
303363 std::vector<executorch::aten::StridesType> strides,
304364 executorch::aten::ScalarType type = executorch::aten::ScalarType::Float,
305365 executorch::aten::TensorShapeDynamism dynamism =
306- executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND);
366+ executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
367+ runtime::etensor::DeviceType device_type =
368+ runtime::etensor::DeviceType::CPU,
369+ runtime::etensor::DeviceIndex device_index = 0 );
307370
308371/* *
309372 * Creates a TensorPtr that manages a Tensor with the specified properties.
@@ -316,16 +379,28 @@ TensorPtr make_tensor_ptr(
316379 * @param data A vector containing the raw memory for the tensor's data.
317380 * @param type The scalar type of the tensor elements.
318381 * @param dynamism Specifies the mutability of the tensor's shape.
382+ * @param device_type The target device type (default CPU, meaning no copy).
383+ * @param device_index The target device index (default 0).
319384 * @return A TensorPtr managing the newly created Tensor.
320385 */
321386inline TensorPtr make_tensor_ptr (
322387 std::vector<executorch::aten::SizesType> sizes,
323388 std::vector<uint8_t > data,
324389 executorch::aten::ScalarType type = executorch::aten::ScalarType::Float,
325390 executorch::aten::TensorShapeDynamism dynamism =
326- executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
391+ executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
392+ runtime::etensor::DeviceType device_type =
393+ runtime::etensor::DeviceType::CPU,
394+ runtime::etensor::DeviceIndex device_index = 0 ) {
327395 return make_tensor_ptr (
328- std::move (sizes), std::move (data), {}, {}, type, dynamism);
396+ std::move (sizes),
397+ std::move (data),
398+ {},
399+ {},
400+ type,
401+ dynamism,
402+ device_type,
403+ device_index);
329404}
330405
331406/* *
0 commit comments