Skip to content

Commit 5f3bb5f

Browse files
committed
utility function to flatten tensor
1 parent b77d969 commit 5f3bb5f

3 files changed

Lines changed: 21 additions & 0 deletions

File tree

src/utilities/tensor.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,14 @@ TensorShard shard_view(const Tensor& src, int idx, int num) {
119119
return TensorShard{shard, idx, num, src.Sizes};
120120
}
121121

122+
Tensor flat_view(const Tensor& src) {
123+
Tensor dst{src};
124+
dst.Sizes.fill(0);
125+
dst.Sizes[0] = src.nelem();
126+
dst.Rank = 1;
127+
return dst;
128+
}
129+
122130
void visit(const std::function<void(Tensor&)>& func, SimpleTensorContainer& container) {
123131
auto cs = container.num_tensors();
124132
for(std::size_t i = 0; i < cs; ++i) {
@@ -168,6 +176,14 @@ GenericTensorContainer shard_empty_container(GenericTensorContainer&& c, int wor
168176
return std::move(c);
169177
}
170178

179+
GenericTensorContainer flattened_view(const GenericTensorContainer& c) {
180+
std::vector<Tensor> flats(c.num_tensors());
181+
for (std::size_t i = 0; i < c.num_tensors(); ++i) {
182+
flats.at(i) = flat_view(c.get_tensor(i));
183+
}
184+
return GenericTensorContainer{flats};
185+
}
186+
171187
GenericTensorContainer shard_view(const GenericTensorContainer& c, int rank, int world) {
172188
std::vector<Tensor> shards(c.num_tensors());
173189
for (std::size_t i = 0; i < c.num_tensors(); ++i) {

src/utilities/tensor.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,6 @@ class TensorShard : public Tensor {
160160
};
161161

162162
TensorShard shard_view(const Tensor& src, int idx, int num);
163+
Tensor flat_view(const Tensor& src);
164+
163165
#endif //LLMQ_SRC_UTILS_TENSOR_H

src/utilities/tensor_container.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ class GenericTensorContainer final : public SimpleTensorContainer {
6161
//! are `nullptr`, but sizes have been set up.
6262
GenericTensorContainer shard_empty_container(GenericTensorContainer&& c, int world);
6363

64+
//! Flattens all tensors is the container.
65+
GenericTensorContainer flattened_view(const GenericTensorContainer& c);
66+
6467
//! Shards a non-empty tensor container. The returned container's tensors are _views_ into
6568
//! the original container's tensors.
6669
GenericTensorContainer shard_view(const GenericTensorContainer& c, int rank, int world);

0 commit comments

Comments
 (0)