-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtensor.hpp
More file actions
executable file
·35 lines (30 loc) · 925 Bytes
/
tensor.hpp
File metadata and controls
executable file
·35 lines (30 loc) · 925 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#include <iostream>
#include <vector>
#include <assert.h>
#include <cmath>
#include <memory>
// TODO: turn Tensor into a class template (e.g. over the element type, currently fixed to double).
class Tensor{
    // Element buffer, held through a shared_ptr. The class declares no copy
    // constructor or copy assignment, so the implicit ones copy this pointer:
    // copies of a Tensor share (alias) the same elements, and a write through
    // get() on one copy is visible through the others.
    std::shared_ptr<std::vector<double>> data;
public:
    // Presumably the extent of each dimension -- confirm against the constructor.
    std::vector<int> shape;
    // Presumably the per-dimension step (in elements) used to map an index into
    // `data` -- confirm against get()/transpose()/broadcast().
    std::vector<int> stride;
    // Presumably the total number of elements. Default-initialized to 0 so that
    // a constructor which does not assign it never leaves it indeterminate.
    int size = 0;

    // Builds a tensor with the given shape.
    // NOTE(review): not `explicit`, so a std::vector<int> converts implicitly
    // to a Tensor; kept that way because callers may rely on the conversion.
    Tensor(const std::vector<int>& shape);
    // Default constructor; the state it leaves `data`/`shape`/`stride` in is
    // decided by the out-of-line definition -- TODO confirm.
    Tensor();

    // Element access at `pos` (presumably one index per dimension). The
    // returned reference points into the shared buffer, so writes through it
    // reach every copy of this Tensor.
    double& get(const std::vector<int>& pos);
    const double& get(const std::vector<int>& pos) const;
    // Index-free access; presumably intended for 0-d / single-element tensors
    // -- confirm against the definition.
    double& get();
    const double& get() const;

    // The arithmetic/transform operations below return a Tensor by value;
    // whether the result shares `data` with an operand is up to the definitions.
    // NOTE(review): they are not const-qualified, so they cannot be called on a
    // const Tensor; adding `const` requires updating the out-of-line definitions.
    Tensor operator+(const Tensor& other);
    // Presumably elementwise (the matrix product is mm()) -- confirm.
    Tensor operator*(const Tensor& other);
    // Presumably matrix multiplication -- confirm the shape requirements.
    Tensor mm(const Tensor& other);
    // Presumably the elementwise logistic function and its derivative.
    Tensor sigmoid();
    Tensor sigmoidDeriv();
    // Presumably prints `shape` to standard output.
    void printShape();
    Tensor transpose();
    // Presumably returns this tensor broadcast to `shape`; callable on a const Tensor.
    Tensor broadcast(const std::vector<int>& shape) const;
};
// Visitor over the index space described by a shape/stride pair.
// Only this declaration is visible here; what `visit` receives on each call
// (an index vector, a flat offset, ...) is fixed by the out-of-line
// definition -- TODO confirm.
// NOTE(review): this function template has no body in this header. If the
// body lives in a .cpp file, any other translation unit that calls iterate()
// will fail to link (no explicit instantiation is provided here); consider
// moving the definition into the header.
template <class Visitor>
void iterate(const std::vector<int>& dims,
             const std::vector<int>& steps,
             Visitor&& visit);