Skip to content

Commit bba6a25

Browse files
authored
Merge pull request #436 from qinyiqun/marlin
Marlin
2 parents bb2c1a5 + 9f697b0 commit bba6a25

30 files changed

Lines changed: 1052 additions & 53 deletions

csrc/engine/compiler/paged_compiler.cpp

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "../../utils.hpp"
44

55
namespace infinilm::engine {
6+
67
PagedCompiler::PagedCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier)
78
: GraphCompiler(model, barrier) {
89
for (size_t b = 1; b < 64; ++b) {
@@ -27,7 +28,8 @@ void PagedCompiler::compile() {
2728
block_tables_holder_ = infinicore::Tensor::empty(
2829
{nblocks * max_batch_size}, infinicore::DataType::I32, infinicore::context::getDevice());
2930
set_zeros(block_tables_holder_);
30-
for (size_t b : decode_batch_sizes_) {
31+
32+
auto make_decode_input = [&](size_t b) {
3133
InfinilmModel::Input input;
3234
input.input_ids = infinicore::Tensor::empty({1, b}, infinicore::DataType::I64, infinicore::context::getDevice());
3335
input.position_ids = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
@@ -59,8 +61,31 @@ void PagedCompiler::compile() {
5961
input.block_tables,
6062
input.slot_mapping,
6163
};
64+
return input;
65+
};
66+
67+
{
68+
const size_t warmup_batch_size = std::min(max_batch_size, static_cast<size_t>(64));
69+
auto input = make_decode_input(warmup_batch_size);
70+
model_->forward(input);
71+
infinicore::context::syncStream();
72+
// Warmup runs the eager Marlin path and may leave per-layer lock
73+
// workspaces dirty. Reset before CUDA graph capture so capture
74+
// starts from the same all-zero lock state as normal execution.
75+
model_->reset_runtime_state();
76+
infinicore::context::syncStream();
77+
}
78+
79+
for (size_t b : decode_batch_sizes_) {
80+
auto input = make_decode_input(b);
6281

6382
barrier_->wait();
83+
// Capture must not start with stale Marlin locks from previous
84+
// warmup/capture attempts. This reset is intentionally outside
85+
// graph capture; the current implementation still pays a memset
86+
// before every graph replay in get_compiled().
87+
model_->reset_runtime_state();
88+
infinicore::context::syncStream();
6489
infinicore::context::startGraphRecording();
6590
auto output = model_->forward(input);
6691
auto graph = infinicore::context::stopGraphRecording();
@@ -101,12 +126,19 @@ PagedCompiler::Compiled PagedCompiler::get_compiled(const InfinilmModel::Input &
101126
return {nullptr, nullptr};
102127
}
103128

104-
// Initialize full padding to -1, then overwrite the narrowed logical region.
105-
// This matches scheduler padding semantics without risking -1 access during graph recording.
129+
// Initialize only the active graph rows to -1, then overwrite the
130+
// runtime logical region. Avoid clearing the full preallocated
131+
// holder on every decode token.
106132
auto &graph_block_tables = graph_input.block_tables.value();
107-
set_minus_one(graph_block_tables);
108-
graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value());
133+
set_minus_one_device_async(graph_block_tables);
134+
graph_block_tables->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value());
109135
graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value());
136+
// CUDA graph replay reuses the same per-layer Marlin workspaces.
137+
// The graph itself does not contain a workspace reset, so enqueue
138+
// one on the same stream before launch. This is correct but costs
139+
// decode latency; the intended follow-up is a reusable global
140+
// zero workspace/lock buffer shared by all Marlin layers.
141+
model_->reset_runtime_state();
110142

111143
auto graph = std::get<0>(result->second.compiled);
112144
auto shared_output = std::shared_ptr<InfinilmModel::Output>(new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()});

csrc/engine/infer_engine.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,10 @@ InferEngine::InferEngine(
4242
enable_graph_compiling,
4343
attention_backend_));
4444
}
45-
// Compile the model on all workers
46-
this->compile();
45+
// Graphs must be compiled after weights are loaded and post-processed.
46+
// Quantized models may replace their linear implementations during
47+
// process_weights_after_loading(), so compiling here would capture stale
48+
// fallback operators.
4749
}
4850

4951
//------------------------------------------------------
@@ -77,6 +79,8 @@ void InferEngine::process_weights_after_loading() {
7779
for (auto &worker : workers_) {
7880
worker->process_weights_after_loading();
7981
}
82+
weights_finalized_ = true;
83+
this->compile();
8084
}
8185

8286
//------------------------------------------------------
@@ -94,6 +98,13 @@ std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEng
9498
return results;
9599
}
96100

101+
std::vector<std::string> InferEngine::state_dict_keys() {
102+
if (0 == workers_.size()) {
103+
throw std::runtime_error(" Model object not found. ");
104+
}
105+
return workers_.front()->state_dict_keys();
106+
}
107+
97108
//------------------------------------------------------
98109
// forward
99110
//------------------------------------------------------
@@ -159,6 +170,9 @@ InferEngine::Output InferEngine::forward(const InferEngine::Input &input) {
159170
}
160171

161172
void InferEngine::compile() {
173+
if (!weights_finalized_) {
174+
return;
175+
}
162176
for (auto &worker : workers_) {
163177
worker->compile();
164178
}

csrc/engine/infer_engine.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ class InferEngine {
4242
// return the parameters (i.e. weights and biases).
4343
std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> state_dict();
4444

45+
std::vector<std::string> state_dict_keys();
46+
4547
// Run a single forward pass on all workers and return the outputs from all ranks
4648
Output forward(const Input &input);
4749

@@ -65,6 +67,7 @@ class InferEngine {
6567
std::unique_ptr<cache::CacheConfig> cache_config_;
6668
std::shared_ptr<infinilm::config::ModelConfig> model_config_;
6769
backends::AttentionBackend attention_backend_ = backends::AttentionBackend::Default;
70+
bool weights_finalized_ = false;
6871
};
6972

7073
} // namespace infinilm::engine

csrc/engine/rank_worker.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,17 @@ std::unordered_map<std::string, infinicore::nn::Parameter> RankWorker::state_dic
149149
return model_->state_dict();
150150
}
151151

152+
std::vector<std::string> RankWorker::state_dict_keys() {
153+
std::unique_lock<std::mutex> lk(mutex_);
154+
cv_.wait(lk, [&] { return init_done_ || should_exit_; });
155+
156+
if (!model_) {
157+
throw std::runtime_error("state_dict_keys called before model initialization");
158+
}
159+
160+
return model_->state_dict_keys();
161+
}
162+
152163
//------------------------------------------------------
153164
// run -- asynchronous
154165
//------------------------------------------------------
@@ -365,6 +376,8 @@ void RankWorker::thread_loop() {
365376
// Handle preprocess command
366377
try {
367378
model_->process_weights_after_loading();
379+
infinicore::context::syncStream();
380+
infinicore::context::trimMemory();
368381
} catch (const std::exception &e) {
369382
{
370383
std::lock_guard<std::mutex> lk(mutex_);

csrc/engine/rank_worker.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ class RankWorker {
9292
// return the parameters (i.e. weights and biases).
9393
std::unordered_map<std::string, infinicore::nn::Parameter> state_dict();
9494

95+
std::vector<std::string> state_dict_keys();
96+
9597
// Submit a run (forward) job.
9698
void run(const Input &args);
9799

csrc/layers/attention/attention.hpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,14 @@ class Attention : public infinicore::nn::Module {
2020
infinicore::Tensor forward(const infinicore::Tensor &positions,
2121
const infinicore::Tensor &hidden_states) const;
2222

23-
void process_fused_weights_after_loading() {
23+
void process_weights_after_loading() override {
2424
qkv_proj_->process_weights_after_loading();
2525
}
2626

27+
void reset_runtime_state() const override {
28+
qkv_proj_->reset_runtime_state();
29+
}
30+
2731
size_t layer_idx() const { return layer_idx_; }
2832
size_t num_heads() const { return num_attention_heads_; }
2933
size_t num_kv_heads() const { return num_key_value_heads_; }
@@ -55,7 +59,7 @@ class Attention : public infinicore::nn::Module {
5559
INFINICORE_NN_PARAMETER(kv_cache_v_scale);
5660
};
5761
void init_kv_cache_quant_params(std::function<void(const std::string &, infinicore::nn::Parameter)> register_fn,
58-
const infinicore::Device &device,
59-
infinicore::nn::Parameter &kv_cache_k_scale,
60-
infinicore::nn::Parameter &kv_cache_v_scale);
62+
const infinicore::Device &device,
63+
infinicore::nn::Parameter &kv_cache_k_scale,
64+
infinicore::nn::Parameter &kv_cache_v_scale);
6165
} // namespace infinilm::layers::attention

csrc/layers/linear/base_linear.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,22 +59,22 @@ void BaseLinear::process_weights_after_loading() {
5959
params[name] = static_cast<const infinicore::Tensor &>(param);
6060
}
6161

62-
auto new_quant = quantization_->process_weights_after_loading(params, device_);
62+
auto new_quant = quantization_->process_weights_after_loading(params, device_, split_dim_);
6363
if (!new_quant) return;
6464

65-
for (auto &[name, param] : parameters_) {
66-
param = infinicore::nn::Parameter();
67-
}
68-
65+
parameters_.clear();
6966
for (const auto &[name, tensor] : params) {
70-
auto it = parameters_.find(name);
71-
if (it == parameters_.end()) continue;
72-
it->second = infinicore::nn::Parameter(tensor);
67+
parameters_.emplace(name, infinicore::nn::Parameter(tensor));
7368
}
69+
params.clear();
7470

7571
quantization_ = std::move(new_quant);
7672
}
7773

74+
void BaseLinear::reset_runtime_state() const {
75+
quantization_->reset_runtime_state();
76+
}
77+
7878
// Backward compatible accessors
7979

8080
infinicore::Tensor BaseLinear::weight() const {

csrc/layers/linear/base_linear.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#pragma once
22

3-
#include "infinicore/ops.hpp"
43
#include "../quantization/quantization.hpp"
54
#include "infinicore/nn/module.hpp"
5+
#include "infinicore/ops.hpp"
66
#include <infiniccl.h>
77
#include <optional>
88

@@ -45,7 +45,8 @@ class BaseLinear : public infinicore::nn::Module {
4545
infinicore::Tensor get_param(const std::string &name) const;
4646

4747
std::shared_ptr<infinilm::quantization::BaseQuantization> get_quantization() const { return quantization_; }
48-
virtual void process_weights_after_loading();
48+
void process_weights_after_loading() override;
49+
void reset_runtime_state() const override;
4950

5051
// Split fused linear parameters into named sub-parameters
5152
std::vector<infinilm::quantization::SplitParam> split_params(

csrc/layers/mlp/mlp.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,14 @@ class MLP : public infinicore::nn::Module {
3636
*/
3737
infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const;
3838

39-
void process_fused_weights_after_loading() {
39+
void process_weights_after_loading() override {
4040
gate_up_proj_->process_weights_after_loading();
4141
}
4242

43+
void reset_runtime_state() const override {
44+
gate_up_proj_->reset_runtime_state();
45+
}
46+
4347
// Module information
4448
size_t hidden_size() const { return hidden_size_; }
4549
size_t intermediate_size() const { return intermediate_size_; }

csrc/layers/quantization/awq.cpp

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
#include "awq.hpp"
2+
#include "awq_marlin.hpp"
23
#include "infinicore/ops/linear_w4a16_awq.hpp"
4+
#include "marlin_support.hpp"
5+
#include "marlin_utils.hpp"
36
#include <optional>
47

58
namespace infinilm::quantization {
@@ -21,12 +24,9 @@ std::vector<ParamDescriptor> AWQ::get_param_layout(
2124
int packing_num = get_packing_num();
2225

2326
std::vector<ParamDescriptor> descs;
24-
descs.push_back({"qweight", {in_features, out_features / packing_num},
25-
infinicore::DataType::I32, awq_tp_dim, tp_rank, tp_size});
26-
descs.push_back({"scales", {in_features / group_size, out_features},
27-
dtype, awq_tp_dim, tp_rank, tp_size});
28-
descs.push_back({"qzeros", {in_features / group_size, out_features / packing_num},
29-
infinicore::DataType::I32, awq_tp_dim, tp_rank, tp_size});
27+
descs.push_back({"qweight", {in_features, out_features / packing_num}, infinicore::DataType::I32, awq_tp_dim, tp_rank, tp_size});
28+
descs.push_back({"scales", {in_features / group_size, out_features}, dtype, awq_tp_dim, tp_rank, tp_size});
29+
descs.push_back({"qzeros", {in_features / group_size, out_features / packing_num}, infinicore::DataType::I32, awq_tp_dim, tp_rank, tp_size});
3030
if (bias) {
3131
descs.push_back({"bias", {out_features}, dtype, -1, 0, 1});
3232
}
@@ -52,6 +52,55 @@ infinicore::Tensor AWQ::forward(
5252
return infinicore::op::linear_w4a16_awq(input_contiguous->contiguous(), qweight, scales, qzeros, bias_opt);
5353
}
5454

55+
std::shared_ptr<BaseQuantization> AWQ::process_weights_after_loading(
56+
ParamsMap &params,
57+
const infinicore::Device &device,
58+
int /*split_dim*/) const {
59+
if (device.getType() != infinicore::Device::Type::NVIDIA) {
60+
return nullptr;
61+
}
62+
63+
#if INFINILM_ENABLE_MARLIN
64+
const int bits = get_or<int>("bits", get_or<int>("w_bit", 4));
65+
if (bits != 4) {
66+
return nullptr;
67+
}
68+
69+
auto qweight = params.at("qweight");
70+
const size_t input_size_per_partition = qweight->size(0);
71+
const size_t output_size_per_partition = qweight->size(1) * get_packing_num();
72+
const int group_size = get_group_size();
73+
if (!marlin::supports_shape(input_size_per_partition, output_size_per_partition, group_size)) {
74+
return nullptr;
75+
}
76+
77+
params["qweight"] = marlin::awq_marlin_repack(
78+
qweight,
79+
input_size_per_partition,
80+
output_size_per_partition,
81+
bits);
82+
params["scales"] = marlin::permute_scales(
83+
params.at("scales"),
84+
input_size_per_partition,
85+
output_size_per_partition,
86+
group_size);
87+
params["qzeros"] = marlin::awq_to_marlin_zero_points(
88+
params.at("qzeros"),
89+
input_size_per_partition / static_cast<size_t>(group_size == -1 ? input_size_per_partition : group_size),
90+
output_size_per_partition,
91+
bits);
92+
params["g_idx"] = marlin::make_empty_i32(device);
93+
params["perm"] = marlin::make_empty_i32(device);
94+
params["a_scales"] = marlin::make_empty_i32(device);
95+
params["global_scales"] = marlin::make_empty_i32(device);
96+
97+
return std::make_shared<AWQMarlin>(get_config(), input_size_per_partition, output_size_per_partition);
98+
#else
99+
(void)params;
100+
return nullptr;
101+
#endif
102+
}
103+
55104
std::vector<SplitParam> AWQ::split_params(
56105
const std::unordered_map<std::string, infinicore::nn::Parameter> &params,
57106
const std::vector<SplitInfo> &splits,

0 commit comments

Comments
 (0)