Skip to content

Commit 3d6bc62

Browse files
committed
feat(core): introduce TensorShape metadata struct and factory helpers
- Implement `TensorShape` struct aligned to 16-byte boundaries. - Add host/device `numel()` method to compute total element count. - Add host-side `compute_strides()` and `is_contiguous()` logic for row-major layouts. - Provide initialization shortcuts for 1D, 2D, 3D, and 4D tensor shapes.
1 parent 157b4f0 commit 3d6bc62

1 file changed

Lines changed: 100 additions & 0 deletions

File tree

cuda/includes/tensor.cuh

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#pragma once
2+
#include "common.h"
3+
// TensorShape — dimensions + strides (row-major by default)
4+
struct QX_ALIGN_16 TensorShape
5+
{
6+
int dims[QX_MAX_DIMS];
7+
int strides[QX_MAX_DIMS];
8+
int ndim;
9+
int _pad;
10+
11+
QX_HOST_DEVICE QX_INLINE int64_t numel() const
12+
{
13+
int64_t n = 1;
14+
for (int i = 0; i < ndim; i++)
15+
n *= dims[i];
16+
return n;
17+
}
18+
19+
QX_HOST QX_INLINE void compute_strides()
20+
{
21+
strides[ndim - 1] = 1;
22+
for (int i = ndim - 2; i >= 0; i--)
23+
strides[i] = strides[i + 1] * dims[i + 1];
24+
}
25+
26+
QX_HOST QX_INLINE bool is_contiguous() const
27+
{
28+
int expected = 1;
29+
for (int i = ndim - 1; i >= 0; i--)
30+
{
31+
if (strides[i] != expected)
32+
return false;
33+
expected *= dims[i];
34+
}
35+
return true;
36+
}
37+
};
38+
39+
static inline TensorShape make_shape(const int *d, int ndim)
40+
{
41+
TensorShape s;
42+
s.ndim = ndim;
43+
s._pad = 0;
44+
for (int i = 0; i < ndim; i++)
45+
s.dims[i] = d[i];
46+
for (int i = ndim; i < QX_MAX_DIMS; i++)
47+
{
48+
s.dims[i] = 1;
49+
s.strides[i] = 1;
50+
}
51+
s.compute_strides();
52+
return s;
53+
}
54+
static inline TensorShape make_shape1d(int a)
55+
{
56+
int d[] = {a};
57+
return make_shape(d, 1);
58+
}
59+
static inline TensorShape make_shape2d(int a, int b)
60+
{
61+
int d[] = {a, b};
62+
return make_shape(d, 2);
63+
}
64+
static inline TensorShape make_shape3d(int a, int b, int c)
65+
{
66+
int d[] = {a, b, c};
67+
return make_shape(d, 3);
68+
}
69+
static inline TensorShape make_shape4d(int a, int b, int c, int e)
70+
{
71+
int d[] = {a, b, c, e};
72+
return make_shape(d, 4);
73+
}
74+
// Tensor — primary data carrier (host struct, kernels get raw pointers)
75+
struct Tensor
76+
{
77+
void *data;
78+
TensorShape shape;
79+
DType dtype;
80+
MemLocation mem_loc;
81+
bool owns_data;
82+
int device_id;
83+
char name[64];
84+
85+
template <typename T>
86+
QX_HOST_DEVICE QX_INLINE T *as()
87+
{
88+
return reinterpret_cast<T *>(data);
89+
}
90+
template <typename T>
91+
QX_HOST_DEVICE QX_INLINE const T *as() const
92+
{
93+
return reinterpret_cast<const T *>(data);
94+
}
95+
96+
QX_HOST QX_INLINE size_t nbytes() const { return (size_t)shape.numel() * dtype_size(dtype); }
97+
QX_HOST_DEVICE QX_INLINE int dim(int i) const { return shape.dims[i]; }
98+
QX_HOST_DEVICE QX_INLINE int ndim() const { return shape.ndim; }
99+
QX_HOST_DEVICE QX_INLINE int64_t numel() const { return shape.numel(); }
100+
};

0 commit comments

Comments
 (0)