Skip to content

Commit dbdf569

Browse files
committed
feat: add scatter operator, distinguish tensor and communication APIs via namespaces, and reorganize functions in misc files
1 parent fc361c7 commit dbdf569

19 files changed

Lines changed: 590 additions & 304 deletions

File tree

infini_train/include/autograd/comm.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class ProcessGroup;
1515
} // namespace nn::parallel
1616
} // namespace infini_train
1717

18-
namespace infini_train::autograd {
18+
namespace infini_train::autograd::comm {
1919
class Scatter : public autograd::Function {
2020
public:
2121
static constexpr char kType[] = "ScatterFunction";
@@ -99,4 +99,4 @@ class ReduceAddCoalesced : public autograd::Function {
9999
std::vector<Device> target_gpus_;
100100
int64_t num_inputs_ = 0;
101101
};
102-
} // namespace infini_train::autograd
102+
} // namespace infini_train::autograd::comm
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include <vector>
5+
6+
#include "infini_train/include/autograd/function.h"
7+
8+
namespace infini_train {
9+
class Tensor;
10+
}
11+
12+
namespace infini_train::autograd {
13+
14+
class Gather : public Function {
15+
public:
16+
static constexpr char kType[] = "GatherFunction";
17+
18+
Gather(int64_t dim = 0) : Function(kType), dim_(dim) {}
19+
20+
std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
21+
void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
22+
const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
23+
std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;
24+
25+
private:
26+
const int64_t dim_ = 0;
27+
std::vector<int64_t> input_dims_;
28+
};
29+
30+
class Slice : public Function {
31+
public:
32+
static constexpr char kType[] = "SliceFunction";
33+
34+
Slice(const std::vector<int64_t> &starts, const std::vector<int64_t> &ends, const std::vector<int64_t> &steps)
35+
: Function(kType), starts_(starts), ends_(ends), steps_(steps) {}
36+
std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
37+
void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
38+
const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
39+
std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;
40+
41+
private:
42+
const std::vector<int64_t> starts_;
43+
const std::vector<int64_t> ends_;
44+
const std::vector<int64_t> steps_;
45+
};
46+
47+
} // namespace infini_train::autograd

infini_train/include/autograd/misc.h

Lines changed: 0 additions & 113 deletions
This file was deleted.
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include <vector>
5+
6+
#include "infini_train/include/autograd/function.h"
7+
8+
namespace infini_train {
9+
class Tensor;
10+
}
11+
12+
namespace infini_train::autograd {
13+
14+
class NoOp : public Function {
15+
public:
16+
static constexpr char kType[] = "NoOpFunction";
17+
18+
explicit NoOp(const std::vector<int64_t> &output_dims) : Function(kType), output_dims_(output_dims) {}
19+
20+
std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
21+
void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
22+
const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
23+
std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;
24+
25+
private:
26+
const std::vector<int64_t> output_dims_;
27+
std::vector<int64_t> input_dims_;
28+
};
29+
30+
} // namespace infini_train::autograd
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include <vector>
5+
6+
#include "infini_train/include/autograd/function.h"
7+
8+
namespace infini_train {
9+
class Tensor;
10+
}
11+
12+
namespace infini_train::autograd {
13+
14+
class Scatter : public Function {
15+
public:
16+
static constexpr char kType[] = "ScatterFunction";
17+
18+
explicit Scatter(const std::vector<int64_t> &output_dims) : Function(kType), output_dims_(output_dims) {}
19+
20+
std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
21+
void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
22+
const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
23+
std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;
24+
25+
private:
26+
std::vector<int64_t> output_dims_;
27+
};
28+
29+
} // namespace infini_train::autograd

infini_train/include/autograd/transform.h

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,53 @@ class RepeatInterleave : public Function {
7878
std::vector<int64_t> input_dims_;
7979
};
8080

81+
class Split : public Function {
82+
public:
83+
static constexpr char kType[] = "SplitFunction";
84+
85+
Split(int64_t split_size, int dim = 0) : Function(kType), split_size_(split_size), dim_(dim) {}
86+
87+
std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
88+
void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
89+
const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
90+
std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;
91+
92+
private:
93+
const int64_t split_size_ = 0;
94+
const int dim_ = 0;
95+
std::vector<int64_t> input_dims_;
96+
};
97+
98+
class Stack : public Function {
99+
public:
100+
static constexpr char kType[] = "StackFunction";
101+
102+
Stack(int64_t dim) : Function(kType), dim_(dim) {}
103+
104+
std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
105+
void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
106+
const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
107+
std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;
108+
109+
private:
110+
int64_t dim_ = 0;
111+
std::vector<int64_t> input_dims_;
112+
};
113+
114+
class Concat : public Function {
115+
public:
116+
static constexpr char kType[] = "ConcatFunction";
117+
118+
Concat(int64_t dim) : Function(kType), dim_(dim) {}
119+
120+
std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
121+
void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
122+
const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
123+
std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;
124+
125+
private:
126+
const int64_t dim_ = 0;
127+
std::vector<std::vector<int64_t>> input_dims_list_;
128+
};
129+
81130
} // namespace infini_train::autograd

infini_train/src/autograd/comm.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include "infini_train/include/nn/parallel/process_group.h"
99
#include "infini_train/include/tensor.h"
1010

11-
namespace infini_train::autograd {
11+
namespace infini_train::autograd::comm {
1212

1313
Scatter::Scatter(const std::vector<Device> &target_gpus, int64_t dim,
1414
const infini_train::nn::parallel::ProcessGroup *pg)
@@ -122,4 +122,4 @@ std::vector<std::shared_ptr<Tensor>>
122122
ReduceAddCoalesced::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
123123
return std::make_shared<Broadcast>(target_gpus_)->Apply(grad_outputs);
124124
}
125-
} // namespace infini_train::autograd
125+
} // namespace infini_train::autograd::comm
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#include "infini_train/include/autograd/indexing.h"
2+
3+
#include "glog/logging.h"
4+
5+
#include "infini_train/include/dispatcher.h"
6+
#include "infini_train/include/tensor.h"
7+
8+
namespace infini_train::autograd {
9+
std::vector<std::shared_ptr<Tensor>> Gather::Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
10+
CHECK_EQ(input_tensors.size(), 2);
11+
const auto &input = input_tensors[0];
12+
const auto &index = input_tensors[1];
13+
14+
auto device = input->GetDevice().type();
15+
auto kernel = Dispatcher::Instance().GetKernel({device, "GatherForward"});
16+
return {kernel.Call<std::shared_ptr<Tensor>>(input, index, dim_)};
17+
}
18+
19+
void Gather::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
20+
const std::vector<std::shared_ptr<Tensor>> &) {
21+
const auto &input = input_tensors[0];
22+
const auto &index = input_tensors[1];
23+
input_dims_ = input->Dims();
24+
saved_tensors_ = {index};
25+
}
26+
27+
std::vector<std::shared_ptr<Tensor>> Gather::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
28+
CHECK_EQ(grad_outputs.size(), 1);
29+
const auto &grad_output = grad_outputs[0];
30+
const auto &index = saved_tensors_[0];
31+
32+
auto device = grad_outputs[0]->GetDevice();
33+
auto kernel = Dispatcher::Instance().GetKernel({device.type(), "GatherBackward"});
34+
return {kernel.Call<std::shared_ptr<Tensor>>(grad_output, index, dim_, input_dims_), nullptr};
35+
}
36+
37+
std::vector<std::shared_ptr<Tensor>> Slice::Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
38+
CHECK_EQ(input_tensors.size(), 1);
39+
const auto &input = input_tensors[0];
40+
41+
auto device = input->GetDevice().type();
42+
return {
43+
Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "SliceForward"}, input, starts_, ends_, steps_)};
44+
}
45+
46+
void Slice::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
47+
const std::vector<std::shared_ptr<Tensor>> &) {
48+
// FIXME(dcj): only input's dim need to be saved
49+
const auto &input = input_tensors[0];
50+
saved_tensors_ = {input};
51+
}
52+
53+
std::vector<std::shared_ptr<Tensor>> Slice::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
54+
CHECK_EQ(saved_tensors_.size(), 1);
55+
const auto &input = saved_tensors_[0];
56+
const auto &grad_output = grad_outputs[0];
57+
58+
auto device = input->GetDevice().type();
59+
return {Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "SliceBackward"}, grad_output, input, starts_,
60+
ends_, steps_)};
61+
}
62+
63+
} // namespace infini_train::autograd

0 commit comments

Comments
 (0)