
Commit bdec219

chen2021673claude authored and committed
feat(lora): add GetLoRAParameters and MergeAndUnload APIs
- Refactor GetLoRAParameters() to retrieve only LoRA parameters for the optimizer
- Add MergeAndUnload() to merge weights and export as a standard model
- Update gpt2/llama3 examples to use the new GetLoRAParameters API
- Refactor LoRA linear modules and fix a dimension mismatch
- Improve LoRA tests and update documentation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 08c0e1d commit bdec219

18 files changed, with 692 additions and 528 deletions


docs/lora_usage.md

Lines changed: 40 additions & 7 deletions
````diff
@@ -41,7 +41,7 @@ config.alpha = 16.0f; // scaling factor
 auto lora_model = GetLoRAModel(model, config);
 
 // 3. Get the trainable parameters for the optimizer
-auto trainable_params = lora_model->TrainableParameters();
+auto trainable_params = nn::lora::GetLoRAParameters(lora_model);
 auto optimizer = std::make_shared<Adam>(trainable_params, lr);
 
 // 4. Training loop
@@ -94,9 +94,6 @@ public:
     LoRAModel(std::shared_ptr<Module> base_model,
               const LoRAConfig &config);
 
-    // Get the trainable parameters
-    std::vector<std::shared_ptr<Tensor>> TrainableParameters() const;
-
     // Get all parameters
     std::vector<std::shared_ptr<Tensor>> Parameters() const override;
@@ -237,11 +234,15 @@ void MergeLoRAWeights(std::shared_ptr<Module> model);
 
 // Restore the original base weights
 void UnmergeLoRAWeights(std::shared_ptr<Module> model);
+
+// Merge weights and unload the LoRA modules, returning a plain base model
+std::shared_ptr<Module> MergeAndUnload(std::shared_ptr<Module> model);
 ```
 
 **Use cases:**
 - Merging weights at inference time removes the extra compute overhead
 - Merging weights when exporting the model yields a standard model format
+- `MergeAndUnload`: exports a complete standard model, replacing every LoRA module with a plain Linear layer
 
 ### Save/load functions
@@ -351,7 +352,7 @@ int main() {
     // =========================================
 
     // Create the optimizer (optimizes only the LoRA parameters)
-    auto trainable_params = lora_model->TrainableParameters();
+    auto trainable_params = nn::lora::GetLoRAParameters(lora_model);
     auto optimizer = std::make_shared<Adam>(trainable_params, /*lr=*/1e-4);
 
     // Training loop
@@ -441,7 +442,39 @@ auto output = (*lora_model)({input_ids});
 UnmergeLoRAWeights(lora_model);
 ```
 
-### Example 4: Custom target layers
+### Example 4: Export a standard model (MergeAndUnload)
+
+Use `MergeAndUnload` to convert a LoRA model into a standard model that can be saved directly as a plain model file:
+
+```cpp
+// Load the base model and apply LoRA
+auto model = std::make_shared<GPT2>(config);
+model->LoadWeights("gpt2_weights.bin");
+
+LoRAConfig lora_config;
+lora_config.rank = 8;
+lora_config.alpha = 16.0f;
+lora_config.target_modules = ParseLoRATargetModules("c_attn,c_proj");
+auto lora_model = GetLoRAModel(model, lora_config);
+
+// Train...
+// ...
+
+// Load the trained LoRA weights
+LoadLoRAWeights(lora_model, "gpt2_lora.bin");
+
+// Merge and unload LoRA, returning a standard model
+// (every LoRALinear in lora_model is replaced with a plain Linear)
+auto merged_model = MergeAndUnload(lora_model);
+
+// Save as a standard model (same format as the original model)
+merged_model->SaveWeights("gpt2_finetuned.bin");
+
+// merged_model is now a plain model; inference needs no LoRA
+auto output = (*merged_model)({input_ids});
+```
+
+### Example 5: Custom target layers
 
 ```cpp
 // Apply LoRA to all linear layers
@@ -599,7 +632,7 @@ int main() {
     auto lora_model = std::make_shared<LoRAModel>(base_model, lora_config);
 
     // 4. Get the trainable parameters for the optimizer
-    auto trainable_params = lora_model->TrainableParameters();
+    auto trainable_params = nn::lora::GetLoRAParameters(lora_model);
     auto optimizer = std::make_shared<Adam>(trainable_params, 1e-5);
 
     // 5. Print a summary
````
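A note on the "merging removes the extra compute overhead" bullet above: before merging, every forward pass pays for the low-rank path B(Ax) on top of the base GEMM; after `MergeAndUnload`, only the base GEMM remains. A rough, self-contained cost comparison (the layer sizes are illustrative, not taken from this commit):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t in = 4096, out = 4096, r = 8;           // hypothetical layer sizes
    const int64_t base_flops = 2 * in * out;              // Wx
    const int64_t lora_flops = 2 * r * in + 2 * out * r;  // B(Ax), paid only when unmerged
    std::printf("base: %lld FLOPs, unmerged LoRA adds %lld (%.2f%% extra)\n",
                (long long)base_flops, (long long)lora_flops,
                100.0 * static_cast<double>(lora_flops) / static_cast<double>(base_flops));
    return 0;
}
```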

example/gpt2/main.cc

Lines changed: 1 addition & 1 deletion
```diff
@@ -239,7 +239,7 @@ void Train(const nn::parallel::Rank &rank) {
 
     auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);
 
-    // Create optimizer - use LoRAModel's TrainableParameters() if LoRA is enabled
+    // Create optimizer - use GetLoRAParameters if LoRA is enabled
     std::vector<std::shared_ptr<Tensor>> params_to_optimize;
     if (lora_enabled) {
         params_to_optimize = nn::lora::GetLoRAParameters(model);
```

example/llama3/main.cc

Lines changed: 8 additions & 2 deletions
```diff
@@ -263,8 +263,14 @@ void Train(const nn::parallel::Rank &rank) {
     auto optimizer_creator = optimizers::Adam::Create(FLAGS_learning_rate);
     std::shared_ptr<Optimizer> optimizer = nullptr;
 
-    // Create optimizer - use TrainableParameters() as single source of truth
-    std::vector<std::shared_ptr<Tensor>> params_to_optimize = model->TrainableParameters();
+    std::vector<std::shared_ptr<Tensor>> params_to_optimize;
+    if (lora_enabled) {
+        params_to_optimize = nn::lora::GetLoRAParameters(model);
+        LOG(INFO) << "Optimizing " << params_to_optimize.size() << " LoRA parameters";
+    } else {
+        params_to_optimize = model->Parameters();
+        LOG(INFO) << "Optimizing " << params_to_optimize.size() << " model parameters";
+    }
 
     if (FLAGS_use_distributed_optimizer) {
         auto model_chunks = (pp_world_size > 1)
```

infini_train/include/nn/lora/lora_linear.h

Lines changed: 22 additions & 26 deletions
```diff
@@ -4,7 +4,13 @@
 #include <vector>
 
 #include "infini_train/include/nn/lora/lora_config.h"
-#include "infini_train/include/nn/modules/module.h"
+#include "infini_train/include/nn/modules/linear.h"
+
+// Forward declarations for test functions (required for friend declarations)
+void test_lora_linear_init();
+void test_lora_linear_forward();
+void test_lora_linear_merge();
+void test_lora_utils();
 
 namespace infini_train {
 class Tensor;
@@ -16,19 +22,13 @@ namespace infini_train::nn::lora {
 // LoRA wrapper for standard Linear layer
 // Implements: y = Wx + b + (alpha/r) * x @ A^T @ B^T
 // Where W is frozen, A and B are trainable low-rank matrices
-class LoRALinear : public nn::CloneableModule<LoRALinear> {
+class LoRALinear : public nn::Linear {
 public:
     static constexpr char kType[] = "LoRALinear";
 
-    // Parameter names
-    static constexpr char kParamWeightName[] = "weight"; // Frozen base weight
-    static constexpr char kParamBiasName[] = "bias";     // Frozen base bias
-    static constexpr char kParamLoraAName[] = "lora_A";  // Trainable A matrix [rank, in_features]
-    static constexpr char kParamLoraBName[] = "lora_B";  // Trainable B matrix [out_features, rank]
-
-    // Constructor from scratch
-    LoRALinear(int64_t in_features, int64_t out_features, const LoRAConfig &config, bool bias = true,
-               const Device *device = nullptr);
+    // Parameter names for LoRA-specific parameters
+    static constexpr char kParamLoraAName[] = "lora_A"; // Trainable A matrix [rank, in_features]
+    static constexpr char kParamLoraBName[] = "lora_B"; // Trainable B matrix [out_features, rank]
 
     // Constructor wrapping existing Linear module (transfers ownership of parameters)
     LoRALinear(std::shared_ptr<nn::Module> base_linear, const LoRAConfig &config);
@@ -43,33 +43,29 @@ class LoRALinear : public nn::CloneableModule<LoRALinear> {
     // Get only LoRA parameters (for optimizer)
    std::vector<std::shared_ptr<Tensor>> LoRAParameters() const;
 
-    // Override Parameters() to return all parameters (frozen base + trainable LoRA)
-    std::vector<std::shared_ptr<Tensor>> Parameters() const override;
-
-    // Get trainable parameters (requires_grad == true)
-    std::vector<std::shared_ptr<Tensor>> TrainableParameters() const;
-
-    // Get all parameters including frozen base weights (for state dict)
-    std::vector<std::shared_ptr<Tensor>> AllParameters() const;
-
     // Accessors
     int64_t in_features() const;
     int64_t out_features() const;
     int64_t rank() const;
     float scaling() const;
 
 private:
+    // Test-only: Create LoRA module from scratch (normal usage goes through InjectLoRALayers)
+    LoRALinear(int64_t in_features, int64_t out_features, const LoRAConfig &config, bool bias, const Device *device);
+
+    // Test access
+    friend void ::test_lora_linear_init();
+    friend void ::test_lora_linear_forward();
+    friend void ::test_lora_linear_merge();
+    friend void ::test_lora_utils();
+
     void InitLoRAWeights();
     void FreezeBaseWeights();
 
     LoRAConfig config_;
-    int64_t in_features_;
-    int64_t out_features_;
-    bool bias_;
+    int64_t in_features_ = 0;
+    int64_t out_features_ = 0;
     bool merged_ = false;
-
-    // Store original weight for unmerge
-    std::shared_ptr<Tensor> original_weight_;
 };
 
 } // namespace infini_train::nn::lora
```
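For reference, the forward rule in the header comment above, y = Wx + b + (alpha/r) * x @ A^T @ B^T, can be sketched with plain row-major buffers. This is a standalone illustration of the math only; the `Mat`/`Vec` helpers are hypothetical and nothing here is infini_train API:

```cpp
#include <cstddef>
#include <vector>

using Vec = std::vector<float>;
using Mat = std::vector<Vec>; // row-major: Mat[row][col]

// y = M @ x
Vec MatVec(const Mat &m, const Vec &x) {
    Vec y(m.size(), 0.0f);
    for (std::size_t i = 0; i < m.size(); ++i) {
        for (std::size_t j = 0; j < x.size(); ++j) {
            y[i] += m[i][j] * x[j];
        }
    }
    return y;
}

// y = W x + b + (alpha/r) * B (A x)
// W: [out, in] (frozen), A: [r, in], B: [out, r] (trainable), b: [out]
Vec LoRAForward(const Mat &W, const Vec &b, const Mat &A, const Mat &B, float alpha, const Vec &x) {
    Vec y = MatVec(W, x);       // frozen base path
    Vec low = MatVec(A, x);     // down-projection to rank r
    Vec delta = MatVec(B, low); // up-projection back to out_features
    const float scaling = alpha / static_cast<float>(A.size()); // alpha / r, cf. LoRALinear::scaling()
    for (std::size_t i = 0; i < y.size(); ++i) {
        y[i] += b[i] + scaling * delta[i];
    }
    return y;
}

int main() {
    Mat W = {{1, 0}, {0, 1}}; // 2x2 identity base weight
    Vec b = {0, 0};
    Mat A = {{1, 1}};         // r = 1
    Mat B = {{1}, {2}};
    Vec y = LoRAForward(W, b, A, B, /*alpha=*/2.0f, {3, 4});
    // y = x + (2/1) * B(Ax) = {3, 4} + 2 * {7, 14} = {17, 32}
    return (y[0] == 17.0f && y[1] == 32.0f) ? 0 : 1;
}
```

Merging folds `scaling * B @ A` into `W` once, after which a plain `Linear` forward reproduces the same output.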

infini_train/include/nn/lora/lora_parallel_linear.h

Lines changed: 15 additions & 42 deletions
```diff
@@ -4,7 +4,7 @@
 #include <vector>
 
 #include "infini_train/include/nn/lora/lora_config.h"
-#include "infini_train/include/nn/modules/module.h"
+#include "infini_train/include/nn/parallel/tensor_parallel.h"
 
 namespace infini_train {
 class Tensor;
@@ -18,21 +18,19 @@ namespace infini_train::nn::lora {
 // LoRA A: [rank, in_features] - replicated across TP ranks (implemented as Linear)
 // LoRA B: [out_features_per_partition, rank] - sharded like base weight (implemented as ColumnParallelLinear with
 // gather_output)
-class LoRAColumnParallelLinear : public nn::CloneableModule<LoRAColumnParallelLinear> {
+class LoRAColumnParallelLinear : public nn::parallel::ColumnParallelLinear {
 public:
     static constexpr char kType[] = "LoRAColumnParallelLinear";
 
-    static constexpr char kParamWeightName[] = "weight";
-    static constexpr char kParamBiasName[] = "bias";
     static constexpr char kParamLoraAName[] = "lora_A";
     static constexpr char kParamLoraBName[] = "lora_B";
 
     // Constructor wrapping existing ColumnParallelLinear
-    LoRAColumnParallelLinear(std::shared_ptr<nn::Module> base_module, const LoRAConfig &config, int64_t in_features,
-                             int64_t out_features);
+    LoRAColumnParallelLinear(std::shared_ptr<parallel::ColumnParallelLinear> base_module, const LoRAConfig &config,
+                             int64_t in_features, int64_t out_features);
 
     // Constructor wrapping existing ColumnParallelLinear (auto-infer dimensions from weight)
-    LoRAColumnParallelLinear(std::shared_ptr<nn::Module> base_module, const LoRAConfig &config);
+    LoRAColumnParallelLinear(std::shared_ptr<parallel::ColumnParallelLinear> base_module, const LoRAConfig &config);
 
     std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
 
@@ -41,9 +39,6 @@ class LoRAColumnParallelLinear : public nn::CloneableModule<LoRAColumnParallelLinear> {
     bool IsMerged() const;
 
     std::vector<std::shared_ptr<Tensor>> LoRAParameters() const;
-    std::vector<std::shared_ptr<Tensor>> Parameters() const override;
-    std::vector<std::shared_ptr<Tensor>> TrainableParameters() const;
-    std::vector<std::shared_ptr<Tensor>> AllParameters() const;
 
     int64_t in_features() const;
     int64_t out_features() const;
@@ -54,39 +49,28 @@ class LoRAColumnParallelLinear : public nn::CloneableModule<LoRAColumnParallelLinear> {
     void FreezeBaseWeights();
 
     LoRAConfig config_;
-    int64_t in_features_;
-    int64_t out_features_;
-    int64_t out_features_per_partition_;
-    bool bias_;
-    bool gather_output_;
-    bool input_is_parallel_;
-    bool skip_bias_add_;
-    bool sequence_parallel_;
+    int64_t in_features_ = 0;
+    int64_t out_features_ = 0;
+    int64_t out_features_per_partition_ = 0;
     bool merged_ = false;
-
-    std::shared_ptr<Tensor> original_weight_;
-    std::shared_ptr<nn::Module> base_module_; // Not registered in modules_ to avoid double-counting
 };
 
 // LoRA wrapper for RowParallelLinear
 // Weight shape: [out_features, in_features_per_partition]
 // LoRA A: [rank, in_features_per_partition] - sharded like base weight (implemented as RowParallelLinear with
 // input_is_parallel) LoRA B: [out_features, rank] - replicated (implemented as Linear)
-class LoRARowParallelLinear : public nn::CloneableModule<LoRARowParallelLinear> {
+class LoRARowParallelLinear : public nn::parallel::RowParallelLinear {
 public:
     static constexpr char kType[] = "LoRARowParallelLinear";
-
-    static constexpr char kParamWeightName[] = "weight";
-    static constexpr char kParamBiasName[] = "bias";
     static constexpr char kParamLoraAName[] = "lora_A";
     static constexpr char kParamLoraBName[] = "lora_B";
 
     // Constructor wrapping existing RowParallelLinear
-    LoRARowParallelLinear(std::shared_ptr<nn::Module> base_module, const LoRAConfig &config, int64_t in_features,
-                          int64_t out_features);
+    LoRARowParallelLinear(std::shared_ptr<parallel::RowParallelLinear> base_module, const LoRAConfig &config,
+                          int64_t in_features, int64_t out_features);
 
     // Constructor wrapping existing RowParallelLinear (auto-infer dimensions from weight)
-    LoRARowParallelLinear(std::shared_ptr<nn::Module> base_module, const LoRAConfig &config);
+    LoRARowParallelLinear(std::shared_ptr<parallel::RowParallelLinear> base_module, const LoRAConfig &config);
 
     std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
 
@@ -95,9 +79,6 @@ class LoRARowParallelLinear : public nn::CloneableModule<LoRARowParallelLinear>
     bool IsMerged() const;
 
     std::vector<std::shared_ptr<Tensor>> LoRAParameters() const;
-    std::vector<std::shared_ptr<Tensor>> Parameters() const override;
-    std::vector<std::shared_ptr<Tensor>> TrainableParameters() const;
-    std::vector<std::shared_ptr<Tensor>> AllParameters() const;
 
     int64_t in_features() const;
     int64_t out_features() const;
@@ -108,18 +89,10 @@ class LoRARowParallelLinear : public nn::CloneableModule<LoRARowParallelLinear>
     void FreezeBaseWeights();
 
     LoRAConfig config_;
-    int64_t in_features_;
-    int64_t out_features_;
-    int64_t in_features_per_partition_;
-    bool bias_;
-    bool reduce_output_;
-    bool input_is_parallel_;
-    bool skip_bias_add_;
-    bool sequence_parallel_;
+    int64_t in_features_ = 0;
+    int64_t out_features_ = 0;
+    int64_t in_features_per_partition_ = 0;
     bool merged_ = false;
-
-    std::shared_ptr<Tensor> original_weight_;
-    std::shared_ptr<nn::Module> base_module_; // Not registered in modules_ to avoid double-counting
 };
 
 } // namespace infini_train::nn::lora
```
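The shape comments above fully determine how each LoRA factor is sharded under tensor parallelism: for a column-parallel layer the base weight splits along `out`, so B shards with it while A is replicated; for a row-parallel layer the split is along `in`, so A shards and B is replicated. A minimal shape walk-through, assuming the partitioned dimension divides evenly; the helper names are illustrative, not infini_train API:

```cpp
#include <cstdint>
#include <cstdio>

struct LoRAShardShapes {
    int64_t a_rows, a_cols; // lora_A shard on one TP rank
    int64_t b_rows, b_cols; // lora_B shard on one TP rank
};

// ColumnParallelLinear: base weight [out, in] is split along out.
LoRAShardShapes ColumnParallelShard(int64_t in, int64_t out, int64_t rank, int64_t tp_world_size) {
    return {rank, in,                   // A: [rank, in_features], replicated
            out / tp_world_size, rank}; // B: [out_features_per_partition, rank]
}

// RowParallelLinear: base weight [out, in] is split along in.
LoRAShardShapes RowParallelShard(int64_t in, int64_t out, int64_t rank, int64_t tp_world_size) {
    return {rank, in / tp_world_size,   // A: [rank, in_features_per_partition]
            out, rank};                 // B: [out_features, rank], replicated
}

int main() {
    // e.g. a hypothetical 4096x4096 projection, rank 8, TP world size 2
    auto col = ColumnParallelShard(4096, 4096, 8, 2);
    std::printf("column-parallel: A=[%lld,%lld] B=[%lld,%lld]\n", (long long)col.a_rows,
                (long long)col.a_cols, (long long)col.b_rows, (long long)col.b_cols);
    auto row = RowParallelShard(4096, 4096, 8, 2);
    std::printf("row-parallel:    A=[%lld,%lld] B=[%lld,%lld]\n", (long long)row.a_rows,
                (long long)row.a_cols, (long long)row.b_rows, (long long)row.b_cols);
    return 0;
}
```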

infini_train/include/nn/lora/lora_utils.h

Lines changed: 13 additions & 0 deletions
```diff
@@ -74,6 +74,19 @@ void MergeLoRAWeights(std::shared_ptr<Module> model);
  */
 void UnmergeLoRAWeights(std::shared_ptr<Module> model);
 
+/**
+ * Merge LoRA weights and remove LoRA modules, returning a clean base model.
+ * Similar to PEFT's merge_and_unload().
+ *
+ * For each LoRA module:
+ * 1. Merge weights: W += (alpha/r) * B @ A
+ * 2. Replace LoRA module with a base module sharing the merged weight/bias
+ *
+ * After this call, the model contains no LoRA parameters.
+ * Root module may be replaced (same pattern as InjectLoRALayers).
+ */
+std::shared_ptr<Module> MergeAndUnload(std::shared_ptr<Module> model);
+
 /**
  * Return a state dict containing only LoRA parameters.
  */
```
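Step 1 of the doc comment above, `W += (alpha/r) * B @ A`, is plain dense math. A minimal sketch over row-major buffers, assuming shapes W: [out, in], B: [out, r], A: [r, in]; this is an illustration, not the library's tensor code:

```cpp
#include <cstdint>
#include <vector>

// Fold the scaled low-rank update into the base weight in place.
void MergeInPlace(std::vector<float> &W, const std::vector<float> &B,
                  const std::vector<float> &A, int64_t out, int64_t in, int64_t r, float alpha) {
    const float scaling = alpha / static_cast<float>(r);
    for (int64_t i = 0; i < out; ++i) {
        for (int64_t j = 0; j < in; ++j) {
            float delta = 0.0f;
            for (int64_t k = 0; k < r; ++k) {
                delta += B[i * r + k] * A[k * in + j]; // (B @ A)[i][j]
            }
            W[i * in + j] += scaling * delta;
        }
    }
}

int main() {
    // 1x1 check: W = [1], B = [2], A = [3], alpha = 4, r = 1 -> W += 4 * 6 = 25
    std::vector<float> W = {1}, B = {2}, A = {3};
    MergeInPlace(W, B, A, /*out=*/1, /*in=*/1, /*r=*/1, /*alpha=*/4.0f);
    return W[0] == 25.0f ? 0 : 1;
}
```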

infini_train/include/nn/modules/linear.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -22,6 +22,8 @@ class Linear : public CloneableModule<Linear> {
     Linear(int64_t in_features, int64_t out_features, bool bias = true, Device device = Device());
     std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
 
+    bool has_bias() const { return bias_; }
+
 private:
     void ResetParameters();
     bool bias_ = true;
```

infini_train/include/nn/modules/module.h

Lines changed: 2 additions & 4 deletions
```diff
@@ -47,19 +47,17 @@ class Module : public std::enable_shared_from_this<Module> {
 
     const std::string &type() const;
 
+    // TODO: Change return type to filterable iterator (like PyTorch's named_parameters with prefix matching)
     virtual std::vector<std::shared_ptr<Tensor>> Parameters() const;
-    // Get parameters with requires_grad == true (trainable parameters)
-    std::vector<std::shared_ptr<Tensor>> TrainableParameters() const;
     bool has_parameter(const std::string &name) const;
     std::shared_ptr<Tensor> *mutable_parameter(const std::string &name);
     const std::shared_ptr<Tensor> &parameter(const std::string &name) const;
 
     virtual std::vector<std::shared_ptr<Tensor>> Buffers() const;
 
     std::vector<std::shared_ptr<Module>> modules();
-    std::shared_ptr<Module> mutable_module(const std::string &name);
+    std::shared_ptr<Module> &mutable_module(const std::string &name);
     const Module &module(const std::string &name) const;
-    void replace_module(const std::string &name, std::shared_ptr<Module> new_module);
 
     std::unordered_map<std::string, std::shared_ptr<Tensor>> StateDict() const;
 
```
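With `mutable_module` now returning a reference, callers can replace a child module by assigning into the returned slot, which is presumably why the separate `replace_module` was dropped and how `MergeAndUnload` swaps LoRA modules for plain ones. A standalone analogue of the pattern; the `Node` type here is illustrative, not infini_train's `Module`:

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>

struct Node {
    std::map<std::string, std::shared_ptr<Node>> children;
    // Returning a reference lets callers assign a replacement in place.
    std::shared_ptr<Node> &mutable_child(const std::string &name) { return children.at(name); }
};

int main() {
    auto parent = std::make_shared<Node>();
    parent->children["c_attn"] = std::make_shared<Node>();
    auto replacement = std::make_shared<Node>();
    parent->mutable_child("c_attn") = replacement; // in-place swap via the reference
    assert(parent->children["c_attn"] == replacement);
    return 0;
}
```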
infini_train/include/nn/parallel/tensor_parallel.h

Lines changed: 2 additions & 2 deletions
```diff
@@ -37,7 +37,7 @@ class ColumnParallelLinear : public nn::CloneableModule<ColumnParallelLinear> {
     bool skip_bias_add() const;
     bool sequence_parallel() const;
 
-private:
+protected:
     bool bias_ = true;
     bool gather_output_ = false;     // whether to return full local output tensor after forward (need gather)
     bool input_is_parallel_ = false; // will perform an autograd-aware copy when false
@@ -66,7 +66,7 @@ class RowParallelLinear : public nn::CloneableModule<RowParallelLinear> {
     bool skip_bias_add() const;
     bool sequence_parallel() const;
 
-private:
+protected:
     bool bias_ = true;
     bool reduce_output_ = false;     // whether to return full local output tensor after forward (need reduce)
     bool input_is_parallel_ = false; // will perform an autograd-aware copy when false
```
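The `private` to `protected` change exists so the new LoRA wrappers, which now derive from these classes, can read the base layer's layout flags. A standalone analogue of why the access level matters (the types here are illustrative, not the real infini_train classes):

```cpp
#include <cstdio>

class BaseLinear {
public:
    explicit BaseLinear(bool gather_output) : gather_output_(gather_output) {}

protected:
    bool gather_output_ = false; // was private; subclasses could not read it
};

class LoRAWrapper : public BaseLinear {
public:
    using BaseLinear::BaseLinear;
    void Describe() const {
        // Legal only because gather_output_ is protected in the base class.
        std::printf("gather_output=%d\n", gather_output_ ? 1 : 0);
    }
};

int main() {
    LoRAWrapper w(/*gather_output=*/true);
    w.Describe();
    return 0;
}
```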
