Skip to content

Commit d1821a5

Browse files
voltjiawooway777
authored andcommitted
加入使用数组作为 shapestrides 创建 ninetoothed::Tensor 的方式
1 parent 586420d commit d1821a5

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

src/infiniop/ninetoothed/utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ class Tensor {
1919

2020
Tensor(const void *data, std::initializer_list<Size> shape, std::initializer_list<Stride> strides) : Tensor{data, decltype(shape_){shape}, decltype(strides_){strides}} {}
2121

22+
Tensor(const void *data, const Size *shape, const Stride *strides, Size ndim) : data_{data}, shape_{shape, shape + ndim}, strides_{strides, strides + ndim}, ndim_{shape_.size()} {}
23+
2224
Tensor(const T value) : value_{value}, data_{&value_}, ndim_{0} {}
2325

2426
operator NineToothedTensor() { return {const_cast<Data>(data_), shape_.data(), strides_.data()}; }

0 commit comments

Comments
 (0)