11115 . [ API 参考] ( #api-参考 )
12126 . [ 使用示例] ( #使用示例 )
13137 . [ 最佳实践] ( #最佳实践 )
14+ 8 . [ 常见问题] ( #常见问题 )
1415
1516## 快速开始
1617
2223#include " nn/lora/lora_utils.h"
2324// 如果使用张量并行
2425#include " nn/lora/lora_parallel_linear.h"
26+ // 如果使用 LoRAModel 包装器
27+ #include " nn/lora/lora_model.h"
2528```
2629
2730### 最简示例
@@ -34,23 +37,23 @@ LoRAConfig config;
3437config.rank = 8; // 低秩维度
3538config.alpha = 16.0f; // 缩放因子
3639
37- // 2. 获取 LoRA 模型
38- auto* lora_model = GetLoRAModel(model, config);
40+ // 2. 获取 LoRA 模型 (原地修改,自动冻结基础模型)
41+ auto lora_model = GetLoRAModel(model, config);
3942
40- // 3. 获取 LoRA 参数用于优化器
41- auto lora_params = lora_model->TrainableParameters();
42- auto optimizer = std::make_shared<Adam >(lora_params , lr);
43+ // 3. 获取可训练参数用于优化器
44+ auto trainable_params = lora_model->TrainableParameters();
45+ auto optimizer = std::make_shared<Adam >(trainable_params , lr);
4346
4447// 4. 训练循环
4548for (int step = 0; step < num_steps; ++step) {
46- auto loss = (* model )(inputs);
49+ auto loss = (* lora_model )(inputs);
4750 loss->Backward();
4851 optimizer->Step();
4952 optimizer->ZeroGrad();
5053}
5154
52- // 6 . 保存 LoRA 权重
53- SaveLoRAWeights(model , "lora_weights.bin");
55+ // 5 . 保存 LoRA 权重
56+ SaveLoRAWeights(lora_model , "lora_weights.bin");
5457```
5558
5659## 核心概念
@@ -134,20 +137,22 @@ std::shared_ptr<LoRAModel> CreateLoRAModel(
134137```cpp
135138struct LoRAConfig {
136139 int64_t rank = 8; // 低秩维度 r
137- float alpha = 16.0f; // 缩放因子 α
138- float dropout = 0.0f; // Dropout 概率(暂未实现)
140+ float alpha = 16.0f; // 缩放因子 α
141+ float dropout = 0.0f; // Dropout 概率(暂未实现)
139142
140143 // 目标模块名称(默认只对 attention 层应用)
141- std::unordered_set<std::string> target_modules = {"c_attn", "attn.c_proj"};
144+ // 注意:匹配模块名的后缀,而非完整路径
145+ std::unordered_set<std::string> target_modules = {"c_attn", "c_proj"};
142146
143147 // 初始化参数
144- bool use_kaiming_a = true; // A 矩阵使用 Kaiming 初始化
145- float kaiming_a_param = 1 .0f; // Kaiming 初始化参数
148+ bool use_kaiming_a = true; // A 矩阵使用 Kaiming 初始化
149+ float kaiming_a_param = sqrtf(5 .0f); // Kaiming 初始化参数 (默认值与 PyTorch 一致)
146150
147151 // 计算缩放因子
148152 float Scaling() const; // 返回 alpha / rank
149153
150154 // 检查模块是否应该应用 LoRA
155+ // 匹配规则:模块名以 target_modules 中的任意一个结尾
151156 bool ShouldApplyLoRA(const std::string &module_name) const;
152157};
153158```
@@ -159,7 +164,7 @@ struct LoRAConfig {
159164PEFT-style 运行时包装器,使用 ` NamedModules() ` 自动遍历模型层次结构,创建 LoRA 模型。
160165
161166``` cpp
162- LoRAModel* GetLoRAModel (
167+ std::shared_ptr<Module> GetLoRAModel (
163168 std::shared_ptr<Module> model, // 目标模型
164169 const LoRAConfig &config // LoRA 配置
165170);
@@ -169,17 +174,18 @@ LoRAModel* GetLoRAModel(
169174- ` model ` : 要包装的模型
170175- ` config ` : LoRA 配置(包含 ` target_modules ` 指定目标层)
171176
172- ** 返回值:** ` LoRAModel* ` ,可用于调用 ` LoadLoRA() ` , ` SaveLoRA() ` , ` PrintSummary() ` 等方法
177+ ** 返回值:** ` std::shared_ptr<Module> ` ,返回原地修改后的模型(已注入 LoRA 层并冻结基础模型)
173178
174179** 使用示例:**
175180``` cpp
176181// 配置 LoRA
177182LoRAConfig config{8, 16.0f};
178- config.SetTargetModules("c_attn,attn.c_proj"); // 只对 attention
179- // config.SetTargetModules("c_attn,attn.c_proj,c_fc,c_fc2,mlp.c_proj"); // 包含 MLP
183+ // 使用 ParseLoRATargetModules 解析逗号分隔的字符串
184+ config.target_modules = ParseLoRATargetModules("c_attn,c_proj"); // 只对 attention
185+ // config.target_modules = ParseLoRATargetModules("c_attn,c_proj,c_fc,c_proj"); // 包含 MLP
180186
181- // 一行启用 LoRA
182- auto* lora_model = nn::lora::GetLoRAModel(model, config);
187+ // 一行启用 LoRA (原地修改,自动冻结基础模型)
188+ auto lora_model = nn::lora::GetLoRAModel(model, config);
183189```
184190
185191#### InjectLoRALayers
@@ -271,6 +277,46 @@ int64_t CountTrainableParameters(const std::shared_ptr<Module> &model);
271277int64_t CountTotalParameters(const std::shared_ptr<Module > &model);
272278```
273279
280+ ### 工具函数
281+
282+ ```cpp
283+ // 解析逗号分隔的目标模块字符串
284+ std::unordered_set<std::string> ParseLoRATargetModules(const std::string &targets);
285+
286+ // 示例: "c_attn,c_proj" -> {"c_attn", "c_proj"}
287+ ```
288+
289+ ### 张量并行 LoRA 类
290+
291+ 当使用张量并行 (TP) 时,` GetLoRAModel ` 会自动检测并使用对应的 LoRA 包装器:
292+
293+ ``` cpp
294+ // LoRA for ColumnParallelLinear (e.g., QKV projection)
295+ // LoRA A: [rank, in_features] - replicated across TP ranks
296+ // LoRA B: [out_features_per_partition, rank] - sharded like base weight
297+ class LoRAColumnParallelLinear ;
298+
299+ // LoRA for RowParallelLinear (e.g., output projection)
300+ // LoRA A: [rank, in_features_per_partition] - sharded like base weight
301+ // LoRA B: [out_features, rank] - replicated across TP ranks
302+ class LoRARowParallelLinear ;
303+ ```
304+
305+ ** 注意** : 使用张量并行时无需手动创建这些类,` GetLoRAModel ` 会自动处理。
306+
307+ ### TP=1 自动退化
308+
309+ ` ColumnParallelLinear ` 和 ` RowParallelLinear ` 在 TP=1 时会自动退化为普通 Linear,无需在模型代码中条件分支:
310+
311+ ``` cpp
312+ // 模型代码可以统一使用 ColumnParallelLinear/RowParallelLinear
313+ // TP=1 时自动走 fast-path,等价于普通 Linear
314+ modules_[" c_attn" ] = std::make_shared<nn::parallel::ColumnParallelLinear>(...);
315+ modules_[" c_proj" ] = std::make_shared<nn::parallel::RowParallelLinear>(...);
316+ ```
317+
318+ 这使得 LoRA 包装可以统一工作,无论是否使用张量并行。
319+
274320## 使用示例
275321
276322### 示例 1: GPT2 微调
@@ -290,10 +336,10 @@ int main() {
290336 LoRAConfig lora_config;
291337 lora_config.rank = 8;
292338 lora_config.alpha = 16.0f;
293- lora_config.SetTargetModules ("c_attn,attn. c_proj"); // 只对 attention 层
339+ lora_config.target_modules = ParseLoRATargetModules ("c_attn,c_proj"); // 只对 attention 层
294340
295- // 获取 LoRA 模型
296- auto* lora_model = GetLoRAModel(model, lora_config);
341+ // 获取 LoRA 模型 (原地修改,自动冻结基础模型)
342+ auto lora_model = GetLoRAModel(model, lora_config);
297343
298344 // 打印参数统计
299345 PrintLoRASummary(lora_model);
@@ -305,8 +351,8 @@ int main() {
305351 // =========================================
306352
307353 // 创建优化器(只优化 LoRA 参数)
308- auto lora_params = lora_model->TrainableParameters();
309- auto optimizer = std::make_shared<Adam>(lora_params , /*lr=*/1e-4);
354+ auto trainable_params = lora_model->TrainableParameters();
355+ auto optimizer = std::make_shared<Adam>(trainable_params , /*lr=*/1e-4);
310356
311357 // 训练循环
312358 for (int step = 0; step < num_steps; ++step) {
@@ -321,7 +367,7 @@ int main() {
321367 }
322368
323369 // 保存 LoRA 权重(仅几 MB)
324- lora_model->SaveLoRA( "gpt2_lora.bin");
370+ SaveLoRAWeights (lora_model, "gpt2_lora.bin");
325371
326372 return 0;
327373}
@@ -350,18 +396,18 @@ int main(int argc, char **argv) {
350396
351397 // 配置 LoRA(包含 MLP 层以获得更好效果)
352398 LoRAConfig lora_config{16, 32.0f};
353- lora_config.SetTargetModules ("c_attn,attn. c_proj,c_fc,c_fc2,mlp. c_proj");
399+ lora_config.target_modules = ParseLoRATargetModules ("c_attn,c_proj,c_fc,c_fc2,c_proj");
354400
355- // 获取 LoRA 模型(通过 target_modules 配置包含 MLP 层 )
356- auto* lora_model = GetLoRAModel(model, lora_config);
401+ // 获取 LoRA 模型(原地修改,自动冻结基础模型 )
402+ auto lora_model = GetLoRAModel(model, lora_config);
357403
358404 PrintLoRASummary(lora_model);
359405
360406 // 训练...
361407
362408 // 保存
363409 if (GetRank() == 0) {
364- SaveLoRAWeights(model , "llama3_lora.bin");
410+ SaveLoRAWeights(lora_model , "llama3_lora.bin");
365411 }
366412
367413 return 0;
@@ -375,30 +421,37 @@ int main(int argc, char **argv) {
375421auto model = std::make_shared<GPT2>(config);
376422model->LoadWeights ("gpt2_weights.bin");
377423
378- // 获取 LoRA 模型
379- auto* lora_model = GetLoRAModel(model, lora_config);
424+ // 配置并获取 LoRA 模型
425+ LoRAConfig lora_config;
426+ lora_config.rank = 8;
427+ lora_config.alpha = 16.0f;
428+ lora_config.target_modules = ParseLoRATargetModules("c_attn,c_proj");
429+ auto lora_model = GetLoRAModel(model, lora_config);
380430
381431// 加载 LoRA 权重
382- lora_model->LoadLoRA( "gpt2_lora.bin");
432+ LoadLoRAWeights(lora_model, "gpt2_lora.bin");
383433
384434// 合并权重(推理时无额外开销)
385- lora_model->Merge( );
435+ MergeLoRAWeights(lora_model );
386436
387437// 现在可以像普通模型一样推理
388438auto output = (* lora_model)({input_ids});
389439
390440// 如果需要继续训练,先解除合并
391- lora_model->Unmerge( );
441+ UnmergeLoRAWeights(lora_model );
392442```
393443
394444### 示例 4: 自定义目标层
395445
396446```cpp
397- // 或者对所有线性层应用
398- config.SetTargetModules("c_attn,attn.c_proj,c_fc,c_fc2,mlp.c_proj,lm_head");
447+ // 对所有线性层应用
448+ LoRAConfig config;
449+ config.rank = 8;
450+ config.alpha = 16.0f;
451+ config.target_modules = ParseLoRATargetModules("c_attn,c_proj,c_fc,c_proj,lm_head");
399452
400453// 获取 LoRA 模型
401- auto* lora_model = GetLoRAModel(model, config);
454+ auto lora_model = GetLoRAModel(model, config);
402455```
403456
404457## 最佳实践
@@ -421,10 +474,10 @@ auto* lora_model = GetLoRAModel(model, config);
421474
422475``` cpp
423476// 推荐:只对 attention 层(参数效率最高)
424- config.SetTargetModules (" c_attn,attn. c_proj" );
477+ config.target_modules = ParseLoRATargetModules (" c_attn,c_proj" );
425478
426479// 可选:包含 MLP 层(效果可能更好,但参数更多)
427- config.SetTargetModules (" c_attn,attn. c_proj,c_fc,c_fc2,mlp. c_proj" );
480+ config.target_modules = ParseLoRATargetModules (" c_attn,c_proj,c_fc,c_fc2,c_proj" );
428481```
429482
430483### 4. 学习率
@@ -490,7 +543,7 @@ lm_head # Language Model Head
490543 --learning_rate 1e-5 \
491544 --lora_rank 8 \
492545 --lora_alpha 16.0 \
493- --lora_target_modules "c_attn,attn. c_proj" \
546+ --lora_target_modules "c_attn,c_proj" \
494547 --lora_save_path data/lora_weights
495548```
496549
@@ -500,7 +553,7 @@ lm_head # Language Model Head
500553| ------| --------| ------|
501554| ` --lora_rank ` | 0 | LoRA 秩 (0 = 禁用) |
502555| ` --lora_alpha ` | 16.0 | LoRA 缩放因子 |
503- | ` --lora_target_modules ` | "c_attn,attn. c_proj" | 目标模块 (逗号分隔: c_attn,attn. c_proj,c_fc,c_fc2,mlp. c_proj) |
556+ | ` --lora_target_modules ` | "c_attn,c_proj" | 目标模块 (逗号分隔: c_attn,c_proj,c_fc,c_proj) |
504557| ` --lora_load_path ` | "" | 加载已有 LoRA 权重 |
505558| ` --lora_save_path ` | "" | 保存 LoRA 权重路径 |
506559
@@ -540,7 +593,7 @@ int main() {
540593
541594 // 2. 创建 LoRA 配置
542595 LoRAConfig lora_config{8, 16.0f};
543- lora_config.SetTargetModules ("c_attn,attn. c_proj"); // 只对 attention 层
596+ lora_config.target_modules = ParseLoRATargetModules ("c_attn,c_proj"); // 只对 attention 层
544597
545598 // 3. 创建 LoRA 包装器 (一行代码)
546599 auto lora_model = std::make_shared<LoRAModel>(base_model, lora_config);
@@ -587,6 +640,20 @@ auto lora_model = CreateLoRAModel<GPT2, GPT2Config>(
587640);
588641```
589642
643+ ### 两种使用方式的区别
644+
645+ | 特性 | ` GetLoRAModel ` | ` LoRAModel ` 包装器 |
646+ | ------| ---------------| -------------------|
647+ | 返回类型 | ` std::shared_ptr<Module> ` | ` std::shared_ptr<LoRAModel> ` |
648+ | 修改方式 | 原地修改模型 | 创建新包装器 |
649+ | 自动冻结 | 是 | 是 |
650+ | 适用场景 | 简单场景,直接修改原模型 | 需要更精细控制 |
651+
652+ ### 推荐场景
653+
654+ - ** 使用 ` GetLoRAModel ` ** : 想要最小化代码改动,直接在原模型上启用 LoRA
655+ - ** 使用 ` LoRAModel ` ** : 需要更灵活的 API(如 ` Merge() ` /` Unmerge() ` 方法),或者需要保留原始模型的引用
656+
590657## 常见问题
591658
592659### Q: LoRA 权重文件有多大?
0 commit comments