Skip to content

Commit e84440b

Browse files
[slimtensor] Add permute() and reshape() view operations (#16839)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #16444 by @Gasoonjia ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/86/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/86/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/85/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/86/orig Differential Revision: [D89952083](https://our.internmc.facebook.com/intern/diff/D89952083/) @diff-train-skip-merge --------- Co-authored-by: gasoonjia <gasoonjia@icloud.com> Co-authored-by: Gasoonjia <gasoonjia@meta.com>
1 parent 3dcd17b commit e84440b

5 files changed

Lines changed: 666 additions & 10 deletions

File tree

backends/aoti/slim/core/SlimTensor.h

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,33 @@ class SlimTensor {
312312
set_sizes_and_strides(sizes, makeArrayRef(contig_strides));
313313
}
314314

315+
/**
316+
* Returns a copy of this tensor.
317+
*
318+
* @return A new SlimTensor with same content.
319+
*/
320+
SlimTensor clone() const {
321+
return _clone_impl(
322+
this->sizes(), this->strides(), this->dtype(), this->device());
323+
}
324+
325+
/**
326+
* Returns a contiguous copy of this tensor.
327+
* If the tensor is already contiguous, returns a copy with independent
328+
* storage.
329+
*
330+
* @return A new contiguous SlimTensor.
331+
*/
332+
SlimTensor clone_contiguous() const {
333+
std::vector<int64_t> contig_strides =
334+
compute_contiguous_strides(this->sizes());
335+
return _clone_impl(
336+
this->sizes(),
337+
makeArrayRef(contig_strides),
338+
this->dtype(),
339+
this->device());
340+
}
341+
315342
// =========================================================================
316343
// View Operations
317344
// =========================================================================
@@ -364,6 +391,39 @@ class SlimTensor {
364391
makeArrayRef(sizes), makeArrayRef(strides), storage_offset);
365392
}
366393

394+
/**
395+
* Returns a new tensor with dimensions permuted according to dims.
396+
* The returned tensor shares the same underlying storage.
397+
*
398+
* @param dims The permutation of dimensions.
399+
* @return A new SlimTensor with permuted dimensions.
400+
*/
401+
inline SlimTensor permute(IntArrayRef dims) const;
402+
403+
/**
404+
* Overload for initializer lists.
405+
*/
406+
inline SlimTensor permute(std::initializer_list<int64_t> dims) const {
407+
return permute(makeArrayRef(dims));
408+
}
409+
410+
/**
411+
* Returns a tensor with the same data and number of elements as this tensor,
412+
* but with the specified shape. If possible, returns a view; otherwise
413+
* creates a contiguous copy.
414+
*
415+
* @param shape The target shape (may contain one -1 for inference).
416+
* @return A new SlimTensor with the specified shape.
417+
*/
418+
inline SlimTensor reshape(IntArrayRef shape) const;
419+
420+
/**
421+
* Overload for initializer lists.
422+
*/
423+
inline SlimTensor reshape(std::initializer_list<int64_t> shape) const {
424+
return reshape(makeArrayRef(shape));
425+
}
426+
367427
// =========================================================================
368428
// Copy Operation
369429
// =========================================================================
@@ -445,6 +505,18 @@ class SlimTensor {
445505
}
446506

447507
private:
508+
/**
 * Helper used by the clone methods: builds a tensor with the given layout on
 * new storage and copies this tensor's data into it via copy_().
 *
 * @param sizes Sizes of the resulting tensor.
 * @param strides Strides of the resulting tensor.
 * @param dtype Element type of the resulting tensor.
 * @param device Device passed to new_storage() for the result.
 * @return The newly created tensor holding a copy of this tensor's data.
 */
SlimTensor _clone_impl(
    c10::IntArrayRef sizes,
    c10::IntArrayRef strides,
    c10::ScalarType dtype,
    const c10::Device& device) const {
  Storage fresh_storage = new_storage(sizes, strides, dtype, device);
  // The clone starts at storage offset 0 of its own storage.
  SlimTensor duplicate =
      SlimTensor(std::move(fresh_storage), sizes, strides, dtype, 0);
  duplicate.copy_(*this);
  return duplicate;
}
519+
448520
void refresh_numel() {
449521
numel_ = compute_numel(sizes_and_strides_.sizes_arrayref());
450522
}

backends/aoti/slim/core/SlimTensorView-incl.h

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,62 @@ inline SlimTensor& SlimTensor::as_strided_(
4949
return *this;
5050
}
5151

52+
// Builds a view of this tensor whose i-th dimension is dimension dims[i] of
// the original. Each entry is normalized with c10::maybe_wrap_dim, and the
// call aborts via ET_CHECK_MSG if dims has the wrong length or repeats a
// dimension.
inline SlimTensor SlimTensor::permute(IntArrayRef dims) const {
  const size_t rank = this->dim();
  ET_CHECK_MSG(
      rank == dims.size(),
      "permute: dims length (%zu) must equal tensor.dim() (%zu)",
      dims.size(),
      rank);

  const IntArrayRef src_sizes = this->sizes();
  const IntArrayRef src_strides = this->strides();
  std::vector<int64_t> permuted_sizes(rank);
  std::vector<int64_t> permuted_strides(rank);
  // Tracks which source dimensions have already been used.
  std::vector<bool> used(rank, false);

  size_t out_axis = 0;
  for (const int64_t requested : dims) {
    const int64_t src_axis = c10::maybe_wrap_dim(requested, rank);
    ET_CHECK_MSG(!used[src_axis], "permute: duplicate dims are not allowed");
    used[src_axis] = true;
    permuted_sizes[out_axis] = src_sizes[src_axis];
    permuted_strides[out_axis] = src_strides[src_axis];
    ++out_axis;
  }

  // Copy this tensor, then only rewrite the copy's sizes and strides.
  SlimTensor view = *this;
  view.as_strided_(
      makeArrayRef(permuted_sizes),
      makeArrayRef(permuted_strides),
      this->storage_offset());
  return view;
}
81+
82+
// Reshapes this tensor to proposed_shape (resolved through infer_size).
// Returns a view when compute_stride can produce strides for the new shape;
// otherwise returns a contiguous clone carrying the new shape.
inline SlimTensor SlimTensor::reshape(IntArrayRef proposed_shape) const {
  const int64_t element_count = static_cast<int64_t>(this->numel());
  std::vector<int64_t> resolved_shape =
      infer_size(proposed_shape, element_count);

  // An engaged optional holds the strides that make the reshape a pure view.
  std::optional<std::vector<int64_t>> view_strides = compute_stride(
      this->sizes(), this->strides(), makeArrayRef(resolved_shape));

  if (!view_strides.has_value()) {
    // No view is possible: clone into contiguous memory first. The clone is
    // contiguous, so only its size metadata needs updating.
    SlimTensor dense_copy = this->clone_contiguous();
    dense_copy.set_sizes_contiguous(makeArrayRef(resolved_shape));
    return dense_copy;
  }

  // View path: copy this tensor and rewrite the copy's sizes and strides.
  SlimTensor view = *this;
  view.as_strided_(
      makeArrayRef(resolved_shape),
      makeArrayRef(*view_strides),
      this->storage_offset());
  return view;
}
109+
52110
} // namespace executorch::backends::aoti::slim

backends/aoti/slim/core/test/targets.bzl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,6 @@ def get_backend_mode():
77

88
def define_common_targets():
99
"""Define test targets for SlimTensor core module."""
10-
runtime.cxx_test(
11-
name = "test_slimtensor_dtypes",
12-
srcs = [
13-
"test_slimtensor_dtypes.cpp",
14-
],
15-
deps = [
16-
"//executorch/backends/aoti/slim/factory:empty",
17-
],
18-
)
19-
2010
# Backend mode specific tests
2111
for backend_mode in get_backend_mode():
2212
backend_suffix = "_" + backend_mode if backend_mode == "cuda" else ""
@@ -77,3 +67,16 @@ def define_common_targets():
7767
],
7868
**backend_kwargs
7969
)
70+
71+
72+
runtime.cxx_test(
73+
name = "test_permute_reshape" + backend_suffix,
74+
srcs = [
75+
"test_permute_reshape.cpp",
76+
],
77+
deps = [
78+
"//executorch/backends/aoti/slim/core:slimtensor",
79+
"//executorch/backends/aoti/slim/factory:empty",
80+
],
81+
**backend_kwargs
82+
)

0 commit comments

Comments
 (0)