This repository was archived by the owner on May 13, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 41
Expand file tree
/
Copy pathfused_linear.cpp
More file actions
102 lines (93 loc) · 3.34 KB
/
Copy pathfused_linear.cpp
File metadata and controls
102 lines (93 loc) · 3.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#include "fused_linear.h"

#include <glog/logging.h>
#include <torch/torch.h>

#include <cstddef>
#include <cstdint>
#include <numeric>
#include <string>
#include <vector>

#include "linear.h"
#include "model_loader/state_dict.h"
#include "model_parallel/parallel_args.h"
#include "quantization/quant_args.h"
namespace llm {
// Builds a column-parallel linear module that yields one output tensor per
// entry of `out_features_vec`.
//
// If `quant_args.can_be_fused()` is true, a single ColumnParallelLinear whose
// output width is the sum of `out_features_vec` is created, and the per-output
// split sizes (each out_features / world_size) are recorded for forward().
// Otherwise one ColumnParallelLinear is created per entry.
//
// Params:
//   in_features       input feature size forwarded to every linear layer.
//   out_features_vec  output feature size of each logical projection.
//   prefixes          state-dict prefixes, stored and later used by load().
//   bias              forwarded unchanged to the underlying layer(s).
//   gather_output     forwarded unchanged to the underlying layer(s).
//   quant_args        decides fused vs. non-fused; forwarded to the layer(s).
//   parallel_args     supplies world_size; forwarded to the layer(s).
//   options           tensor options forwarded to the layer(s).
//
// Aborts (glog CHECK) in the fused case when world_size is not positive or
// when any entry of `out_features_vec` is not divisible by world_size.
FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl(
    int64_t in_features,
    const std::vector<int64_t>& out_features_vec,
    const std::vector<std::string>& prefixes,
    bool bias,
    bool gather_output,
    const QuantArgs& quant_args,
    const ParallelArgs& parallel_args,
    const torch::TensorOptions& options) {
  prefixes_ = prefixes;
  // check if the linear layers can be fused
  fused_ = quant_args.can_be_fused();
  if (fused_) {
    // Validate and compute the split sizes *before* building the fused layer
    // so that an invalid configuration fails before anything is constructed.
    const auto world_size = parallel_args.world_size();
    // guard the modulo/division below against a zero or negative divisor
    CHECK_GT(world_size, 0)
        << "world_size must be positive, got " << world_size;
    split_sizes_.reserve(out_features_vec.size());
    for (const int64_t out_features : out_features_vec) {
      CHECK(out_features % world_size == 0)
          << "out_features " << out_features << " not divisible by world_size "
          << world_size;
      split_sizes_.push_back(out_features / world_size);
    }

    // fused linear layer: one projection covering all outputs
    const int64_t total_out_features = std::accumulate(
        out_features_vec.begin(), out_features_vec.end(), int64_t(0));
    fused_linear_ = ColumnParallelLinear(in_features,
                                         total_out_features,
                                         bias,
                                         gather_output,
                                         quant_args,
                                         parallel_args,
                                         options);
  } else {
    // non-fused linear layers: one independent projection per output
    parallel_linears_.reserve(out_features_vec.size());
    for (const int64_t out_features : out_features_vec) {
      parallel_linears_.emplace_back(in_features,
                                     out_features,
                                     bias,
                                     gather_output,
                                     quant_args,
                                     parallel_args,
                                     options);
    }
  }
}
// Runs the projection(s) on `input` and returns the resulting tensors.
//
// Fused mode: a single forward pass through fused_linear_, whose result is
// split along dim 1 using split_sizes_.
// Non-fused mode: one forward pass per layer in parallel_linears_, with the
// results returned in layer order.
std::vector<torch::Tensor> FusedColumnParallelLinearImpl::forward(
    torch::Tensor input) {
  if (!fused_) {
    // every layer consumes the same input and contributes one result
    std::vector<torch::Tensor> results;
    results.reserve(parallel_linears_.size());
    for (auto& layer : parallel_linears_) {
      results.push_back(layer->forward(input));
    }
    return results;
  }

  // single fused projection, then carve it into the per-output pieces
  auto combined = fused_linear_->forward(input);
  return combined.split(split_sizes_, /*dim=*/1);
}
// Loads weights from `state_dict` into the underlying layer(s).
//
// Fused mode: the fused layer receives the whole state dict together with all
// stored prefixes. Non-fused mode: layer i receives the sub-dict selected by
// prefixes_[i]; aborts (CHECK_EQ) if the layer and prefix counts differ.
//
// The second (unnamed) parameter is ignored. Always returns 0.
size_t FusedColumnParallelLinearImpl::load(const StateDict& state_dict,
                                           const std::string&) {
  if (!fused_) {
    CHECK_EQ(parallel_linears_.size(), prefixes_.size());
    size_t idx = 0;
    for (auto& layer : parallel_linears_) {
      layer->load_state_dict(state_dict.select(prefixes_[idx]));
      ++idx;
    }
    return 0;
  }

  fused_linear_->load_state_dict(state_dict, prefixes_);
  return 0;
}
bool FusedColumnParallelLinearImpl::verify(
const std::string& name_prefix) const {
if (fused_) {
fused_linear_->verify_loaded_weights(name_prefix);
} else {
for (const auto& parallel_linear : parallel_linears_) {
parallel_linear->verify_loaded_weights(name_prefix);
}
}
return true;
}
} // namespace llm