|
| 1 | +#pragma once |
| 2 | +#include "common.h" |
| 3 | +// TensorShape — dimensions + strides (row-major by default) |
| 4 | +struct QX_ALIGN_16 TensorShape |
| 5 | +{ |
| 6 | + int dims[QX_MAX_DIMS]; |
| 7 | + int strides[QX_MAX_DIMS]; |
| 8 | + int ndim; |
| 9 | + int _pad; |
| 10 | + |
| 11 | + QX_HOST_DEVICE QX_INLINE int64_t numel() const |
| 12 | + { |
| 13 | + int64_t n = 1; |
| 14 | + for (int i = 0; i < ndim; i++) |
| 15 | + n *= dims[i]; |
| 16 | + return n; |
| 17 | + } |
| 18 | + |
| 19 | + QX_HOST QX_INLINE void compute_strides() |
| 20 | + { |
| 21 | + strides[ndim - 1] = 1; |
| 22 | + for (int i = ndim - 2; i >= 0; i--) |
| 23 | + strides[i] = strides[i + 1] * dims[i + 1]; |
| 24 | + } |
| 25 | + |
| 26 | + QX_HOST QX_INLINE bool is_contiguous() const |
| 27 | + { |
| 28 | + int expected = 1; |
| 29 | + for (int i = ndim - 1; i >= 0; i--) |
| 30 | + { |
| 31 | + if (strides[i] != expected) |
| 32 | + return false; |
| 33 | + expected *= dims[i]; |
| 34 | + } |
| 35 | + return true; |
| 36 | + } |
| 37 | +}; |
| 38 | + |
| 39 | +static inline TensorShape make_shape(const int *d, int ndim) |
| 40 | +{ |
| 41 | + TensorShape s; |
| 42 | + s.ndim = ndim; |
| 43 | + s._pad = 0; |
| 44 | + for (int i = 0; i < ndim; i++) |
| 45 | + s.dims[i] = d[i]; |
| 46 | + for (int i = ndim; i < QX_MAX_DIMS; i++) |
| 47 | + { |
| 48 | + s.dims[i] = 1; |
| 49 | + s.strides[i] = 1; |
| 50 | + } |
| 51 | + s.compute_strides(); |
| 52 | + return s; |
| 53 | +} |
| 54 | +static inline TensorShape make_shape1d(int a) |
| 55 | +{ |
| 56 | + int d[] = {a}; |
| 57 | + return make_shape(d, 1); |
| 58 | +} |
| 59 | +static inline TensorShape make_shape2d(int a, int b) |
| 60 | +{ |
| 61 | + int d[] = {a, b}; |
| 62 | + return make_shape(d, 2); |
| 63 | +} |
| 64 | +static inline TensorShape make_shape3d(int a, int b, int c) |
| 65 | +{ |
| 66 | + int d[] = {a, b, c}; |
| 67 | + return make_shape(d, 3); |
| 68 | +} |
| 69 | +static inline TensorShape make_shape4d(int a, int b, int c, int e) |
| 70 | +{ |
| 71 | + int d[] = {a, b, c, e}; |
| 72 | + return make_shape(d, 4); |
| 73 | +} |
| 74 | +// Tensor — primary data carrier (host struct, kernels get raw pointers) |
| 75 | +struct Tensor |
| 76 | +{ |
| 77 | + void *data; |
| 78 | + TensorShape shape; |
| 79 | + DType dtype; |
| 80 | + MemLocation mem_loc; |
| 81 | + bool owns_data; |
| 82 | + int device_id; |
| 83 | + char name[64]; |
| 84 | + |
| 85 | + template <typename T> |
| 86 | + QX_HOST_DEVICE QX_INLINE T *as() |
| 87 | + { |
| 88 | + return reinterpret_cast<T *>(data); |
| 89 | + } |
| 90 | + template <typename T> |
| 91 | + QX_HOST_DEVICE QX_INLINE const T *as() const |
| 92 | + { |
| 93 | + return reinterpret_cast<const T *>(data); |
| 94 | + } |
| 95 | + |
| 96 | + QX_HOST QX_INLINE size_t nbytes() const { return (size_t)shape.numel() * dtype_size(dtype); } |
| 97 | + QX_HOST_DEVICE QX_INLINE int dim(int i) const { return shape.dims[i]; } |
| 98 | + QX_HOST_DEVICE QX_INLINE int ndim() const { return shape.ndim; } |
| 99 | + QX_HOST_DEVICE QX_INLINE int64_t numel() const { return shape.numel(); } |
| 100 | +}; |
0 commit comments