Skip to content

Commit 08c0e1d

Browse files
chen2021673claude
authored andcommitted
fix(lora): fix dimension mismatch and refactor TP helper functions
- Fix RowParallel/ColumnParallel LoRA input handling to match base module behavior - Add shape-based defensive checks for TP/SP consistency - Move TP/SP communication helper function declarations to utils.h - Move getter implementations from header to .cc file - Add unit test for SaveLoRAWeights/LoadLoRAWeights functionality Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 26f6d7d commit 08c0e1d

21 files changed

Lines changed: 606 additions & 317 deletions

docs/lora_usage.md

Lines changed: 109 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
5. [API 参考](#api-参考)
1212
6. [使用示例](#使用示例)
1313
7. [最佳实践](#最佳实践)
14+
8. [常见问题](#常见问题)
1415

1516
## 快速开始
1617

@@ -22,6 +23,8 @@
2223
#include "nn/lora/lora_utils.h"
2324
// 如果使用张量并行
2425
#include "nn/lora/lora_parallel_linear.h"
26+
// 如果使用 LoRAModel 包装器
27+
#include "nn/lora/lora_model.h"
2528
```
2629

2730
### 最简示例
@@ -34,23 +37,23 @@ LoRAConfig config;
3437
config.rank = 8; // 低秩维度
3538
config.alpha = 16.0f; // 缩放因子
3639

37-
// 2. 获取 LoRA 模型
38-
auto* lora_model = GetLoRAModel(model, config);
40+
// 2. 获取 LoRA 模型 (原地修改,自动冻结基础模型)
41+
auto lora_model = GetLoRAModel(model, config);
3942

40-
// 3. 获取 LoRA 参数用于优化器
41-
auto lora_params = lora_model->TrainableParameters();
42-
auto optimizer = std::make_shared<Adam>(lora_params, lr);
43+
// 3. 获取可训练参数用于优化器
44+
auto trainable_params = lora_model->TrainableParameters();
45+
auto optimizer = std::make_shared<Adam>(trainable_params, lr);
4346

4447
// 4. 训练循环
4548
for (int step = 0; step < num_steps; ++step) {
46-
auto loss = (*model)(inputs);
49+
auto loss = (*lora_model)(inputs);
4750
loss->Backward();
4851
optimizer->Step();
4952
optimizer->ZeroGrad();
5053
}
5154

52-
// 6. 保存 LoRA 权重
53-
SaveLoRAWeights(model, "lora_weights.bin");
55+
// 5. 保存 LoRA 权重
56+
SaveLoRAWeights(lora_model, "lora_weights.bin");
5457
```
5558
5659
## 核心概念
@@ -134,20 +137,22 @@ std::shared_ptr<LoRAModel> CreateLoRAModel(
134137
```cpp
135138
struct LoRAConfig {
136139
int64_t rank = 8; // 低秩维度 r
137-
float alpha = 16.0f; // 缩放因子 α
138-
float dropout = 0.0f; // Dropout 概率(暂未实现)
140+
float alpha = 16.0f; // 缩放因子 α
141+
float dropout = 0.0f; // Dropout 概率(暂未实现)
139142
140143
// 目标模块名称(默认只对 attention 层应用)
141-
std::unordered_set<std::string> target_modules = {"c_attn", "attn.c_proj"};
144+
// 注意:匹配模块名的后缀,而非完整路径
145+
std::unordered_set<std::string> target_modules = {"c_attn", "c_proj"};
142146
143147
// 初始化参数
144-
bool use_kaiming_a = true; // A 矩阵使用 Kaiming 初始化
145-
float kaiming_a_param = 1.0f; // Kaiming 初始化参数
148+
bool use_kaiming_a = true; // A 矩阵使用 Kaiming 初始化
149+
float kaiming_a_param = sqrtf(5.0f); // Kaiming 初始化参数 (默认值与 PyTorch 一致)
146150
147151
// 计算缩放因子
148152
float Scaling() const; // 返回 alpha / rank
149153
150154
// 检查模块是否应该应用 LoRA
155+
// 匹配规则:模块名以 target_modules 中的任意一个结尾
151156
bool ShouldApplyLoRA(const std::string &module_name) const;
152157
};
153158
```
@@ -159,7 +164,7 @@ struct LoRAConfig {
159164
PEFT-style 运行时包装器,使用 `NamedModules()` 自动遍历模型层次结构,创建 LoRA 模型。
160165

161166
```cpp
162-
LoRAModel* GetLoRAModel(
167+
std::shared_ptr<Module> GetLoRAModel(
163168
std::shared_ptr<Module> model, // 目标模型
164169
const LoRAConfig &config // LoRA 配置
165170
);
@@ -169,17 +174,18 @@ LoRAModel* GetLoRAModel(
169174
- `model`: 要包装的模型
170175
- `config`: LoRA 配置(包含 `target_modules` 指定目标层)
171176

172-
**返回值:** `LoRAModel*`,可用于调用 `LoadLoRA()`, `SaveLoRA()`, `PrintSummary()` 等方法
177+
**返回值:** `std::shared_ptr<Module>`,返回原地修改后的模型(已注入 LoRA 层并冻结基础模型)
173178

174179
**使用示例:**
175180
```cpp
176181
// 配置 LoRA
177182
LoRAConfig config{8, 16.0f};
178-
config.SetTargetModules("c_attn,attn.c_proj"); // 只对 attention
179-
// config.SetTargetModules("c_attn,attn.c_proj,c_fc,c_fc2,mlp.c_proj"); // 包含 MLP
183+
// 使用 ParseLoRATargetModules 解析逗号分隔的字符串
184+
config.target_modules = ParseLoRATargetModules("c_attn,c_proj"); // 只对 attention
185+
// config.target_modules = ParseLoRATargetModules("c_attn,c_proj,c_fc,c_proj"); // 包含 MLP
180186

181-
// 一行启用 LoRA
182-
auto* lora_model = nn::lora::GetLoRAModel(model, config);
187+
// 一行启用 LoRA (原地修改,自动冻结基础模型)
188+
auto lora_model = nn::lora::GetLoRAModel(model, config);
183189
```
184190
185191
#### InjectLoRALayers
@@ -271,6 +277,46 @@ int64_t CountTrainableParameters(const std::shared_ptr<Module> &model);
271277
int64_t CountTotalParameters(const std::shared_ptr<Module> &model);
272278
```
273279
280+
### 工具函数
281+
282+
```cpp
283+
// 解析逗号分隔的目标模块字符串
284+
std::unordered_set<std::string> ParseLoRATargetModules(const std::string &targets);
285+
286+
// 示例: "c_attn,c_proj" -> {"c_attn", "c_proj"}
287+
```
288+
289+
### 张量并行 LoRA 类
290+
291+
当使用张量并行 (TP) 时,`GetLoRAModel` 会自动检测并使用对应的 LoRA 包装器:
292+
293+
```cpp
294+
// LoRA for ColumnParallelLinear (e.g., QKV projection)
295+
// LoRA A: [rank, in_features] - replicated across TP ranks
296+
// LoRA B: [out_features_per_partition, rank] - sharded like base weight
297+
class LoRAColumnParallelLinear;
298+
299+
// LoRA for RowParallelLinear (e.g., output projection)
300+
// LoRA A: [rank, in_features_per_partition] - sharded like base weight
301+
// LoRA B: [out_features, rank] - replicated across TP ranks
302+
class LoRARowParallelLinear;
303+
```
304+
305+
**注意**: 使用张量并行时无需手动创建这些类,`GetLoRAModel` 会自动处理。
306+
307+
### TP=1 自动退化
308+
309+
`ColumnParallelLinear``RowParallelLinear` 在 TP=1 时会自动退化为普通 Linear,无需在模型代码中条件分支:
310+
311+
```cpp
312+
// 模型代码可以统一使用 ColumnParallelLinear/RowParallelLinear
313+
// TP=1 时自动走 fast-path,等价于普通 Linear
314+
modules_["c_attn"] = std::make_shared<nn::parallel::ColumnParallelLinear>(...);
315+
modules_["c_proj"] = std::make_shared<nn::parallel::RowParallelLinear>(...);
316+
```
317+
318+
这使得 LoRA 包装可以统一工作,无论是否使用张量并行。
319+
274320
## 使用示例
275321

276322
### 示例 1: GPT2 微调
@@ -290,10 +336,10 @@ int main() {
290336
LoRAConfig lora_config;
291337
lora_config.rank = 8;
292338
lora_config.alpha = 16.0f;
293-
lora_config.SetTargetModules("c_attn,attn.c_proj"); // 只对 attention 层
339+
lora_config.target_modules = ParseLoRATargetModules("c_attn,c_proj"); // 只对 attention 层
294340

295-
// 获取 LoRA 模型
296-
auto* lora_model = GetLoRAModel(model, lora_config);
341+
// 获取 LoRA 模型 (原地修改,自动冻结基础模型)
342+
auto lora_model = GetLoRAModel(model, lora_config);
297343

298344
// 打印参数统计
299345
PrintLoRASummary(lora_model);
@@ -305,8 +351,8 @@ int main() {
305351
// =========================================
306352

307353
// 创建优化器(只优化 LoRA 参数)
308-
auto lora_params = lora_model->TrainableParameters();
309-
auto optimizer = std::make_shared<Adam>(lora_params, /*lr=*/1e-4);
354+
auto trainable_params = lora_model->TrainableParameters();
355+
auto optimizer = std::make_shared<Adam>(trainable_params, /*lr=*/1e-4);
310356

311357
// 训练循环
312358
for (int step = 0; step < num_steps; ++step) {
@@ -321,7 +367,7 @@ int main() {
321367
}
322368

323369
// 保存 LoRA 权重(仅几 MB)
324-
lora_model->SaveLoRA("gpt2_lora.bin");
370+
SaveLoRAWeights(lora_model, "gpt2_lora.bin");
325371

326372
return 0;
327373
}
@@ -350,18 +396,18 @@ int main(int argc, char **argv) {
350396
351397
// 配置 LoRA(包含 MLP 层以获得更好效果)
352398
LoRAConfig lora_config{16, 32.0f};
353-
lora_config.SetTargetModules("c_attn,attn.c_proj,c_fc,c_fc2,mlp.c_proj");
399+
lora_config.target_modules = ParseLoRATargetModules("c_attn,c_proj,c_fc,c_fc2,c_proj");
354400
355-
// 获取 LoRA 模型(通过 target_modules 配置包含 MLP 层
356-
auto* lora_model = GetLoRAModel(model, lora_config);
401+
// 获取 LoRA 模型(原地修改,自动冻结基础模型
402+
auto lora_model = GetLoRAModel(model, lora_config);
357403
358404
PrintLoRASummary(lora_model);
359405
360406
// 训练...
361407
362408
// 保存
363409
if (GetRank() == 0) {
364-
SaveLoRAWeights(model, "llama3_lora.bin");
410+
SaveLoRAWeights(lora_model, "llama3_lora.bin");
365411
}
366412
367413
return 0;
@@ -375,30 +421,37 @@ int main(int argc, char **argv) {
375421
auto model = std::make_shared<GPT2>(config);
376422
model->LoadWeights("gpt2_weights.bin");
377423

378-
// 获取 LoRA 模型
379-
auto* lora_model = GetLoRAModel(model, lora_config);
424+
// 配置并获取 LoRA 模型
425+
LoRAConfig lora_config;
426+
lora_config.rank = 8;
427+
lora_config.alpha = 16.0f;
428+
lora_config.target_modules = ParseLoRATargetModules("c_attn,c_proj");
429+
auto lora_model = GetLoRAModel(model, lora_config);
380430

381431
// 加载 LoRA 权重
382-
lora_model->LoadLoRA("gpt2_lora.bin");
432+
LoadLoRAWeights(lora_model, "gpt2_lora.bin");
383433

384434
// 合并权重(推理时无额外开销)
385-
lora_model->Merge();
435+
MergeLoRAWeights(lora_model);
386436

387437
// 现在可以像普通模型一样推理
388438
auto output = (*lora_model)({input_ids});
389439

390440
// 如果需要继续训练,先解除合并
391-
lora_model->Unmerge();
441+
UnmergeLoRAWeights(lora_model);
392442
```
393443
394444
### 示例 4: 自定义目标层
395445
396446
```cpp
397-
// 或者对所有线性层应用
398-
config.SetTargetModules("c_attn,attn.c_proj,c_fc,c_fc2,mlp.c_proj,lm_head");
447+
// 对所有线性层应用
448+
LoRAConfig config;
449+
config.rank = 8;
450+
config.alpha = 16.0f;
451+
config.target_modules = ParseLoRATargetModules("c_attn,c_proj,c_fc,c_proj,lm_head");
399452
400453
// 获取 LoRA 模型
401-
auto* lora_model = GetLoRAModel(model, config);
454+
auto lora_model = GetLoRAModel(model, config);
402455
```
403456

404457
## 最佳实践
@@ -421,10 +474,10 @@ auto* lora_model = GetLoRAModel(model, config);
421474

422475
```cpp
423476
// 推荐:只对 attention 层(参数效率最高)
424-
config.SetTargetModules("c_attn,attn.c_proj");
477+
config.target_modules = ParseLoRATargetModules("c_attn,c_proj");
425478

426479
// 可选:包含 MLP 层(效果可能更好,但参数更多)
427-
config.SetTargetModules("c_attn,attn.c_proj,c_fc,c_fc2,mlp.c_proj");
480+
config.target_modules = ParseLoRATargetModules("c_attn,c_proj,c_fc,c_fc2,c_proj");
428481
```
429482

430483
### 4. 学习率
@@ -490,7 +543,7 @@ lm_head # Language Model Head
490543
--learning_rate 1e-5 \
491544
--lora_rank 8 \
492545
--lora_alpha 16.0 \
493-
--lora_target_modules "c_attn,attn.c_proj" \
546+
--lora_target_modules "c_attn,c_proj" \
494547
--lora_save_path data/lora_weights
495548
```
496549

@@ -500,7 +553,7 @@ lm_head # Language Model Head
500553
|------|--------|------|
501554
| `--lora_rank` | 0 | LoRA 秩 (0 = 禁用) |
502555
| `--lora_alpha` | 16.0 | LoRA 缩放因子 |
503-
| `--lora_target_modules` | "c_attn,attn.c_proj" | 目标模块 (逗号分隔: c_attn,attn.c_proj,c_fc,c_fc2,mlp.c_proj) |
556+
| `--lora_target_modules` | "c_attn,c_proj" | 目标模块 (逗号分隔: c_attn,c_proj,c_fc,c_proj) |
504557
| `--lora_load_path` | "" | 加载已有 LoRA 权重 |
505558
| `--lora_save_path` | "" | 保存 LoRA 权重路径 |
506559

@@ -540,7 +593,7 @@ int main() {
540593

541594
// 2. 创建 LoRA 配置
542595
LoRAConfig lora_config{8, 16.0f};
543-
lora_config.SetTargetModules("c_attn,attn.c_proj"); // 只对 attention 层
596+
lora_config.target_modules = ParseLoRATargetModules("c_attn,c_proj"); // 只对 attention 层
544597

545598
// 3. 创建 LoRA 包装器 (一行代码)
546599
auto lora_model = std::make_shared<LoRAModel>(base_model, lora_config);
@@ -587,6 +640,20 @@ auto lora_model = CreateLoRAModel<GPT2, GPT2Config>(
587640
);
588641
```
589642

643+
### 两种使用方式的区别
644+
645+
| 特性 | `GetLoRAModel` | `LoRAModel` 包装器 |
646+
|------|---------------|-------------------|
647+
| 返回类型 | `std::shared_ptr<Module>` | `std::shared_ptr<LoRAModel>` |
648+
| 修改方式 | 原地修改模型 | 创建新包装器 |
649+
| 自动冻结 |||
650+
| 适用场景 | 简单场景,直接修改原模型 | 需要更精细控制 |
651+
652+
### 推荐场景
653+
654+
- **使用 `GetLoRAModel`**: 想要最小化代码改动,直接在原模型上启用 LoRA
655+
- **使用 `LoRAModel`**: 需要更灵活的 API(如 `Merge()`/`Unmerge()` 方法),或者需要保留原始模型的引用
656+
590657
## 常见问题
591658

592659
### Q: LoRA 权重文件有多大?

example/gpt2/net.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,11 @@ GPT2::GPT2(const GPT2Config &config)
307307
modules_[kTransformerLayerName] = std::make_shared<nn::ModuleDict>(std::move(transformer));
308308

309309
// FIXME(jym): Assigning the parameter values of wte to LMHead, which is not real tying operation
310+
// TODO: Implement real GPT-2 weight tying: make lm_head.weight share the exact same Parameter/Tensor (same
311+
// shared_ptr/storage) as transformer.wte.weight (pointer aliasing, not value copy), and ensure the tie is applied
312+
// after loading weights so it won't be overwritten. Also fix GPT2::FromLLMC() loading logic to respect weight tying
313+
// (do not create/load a separate lm_head.weight tensor; load once into the tied weight) so parameter counting
314+
// matches PyTorch/PEFT.
310315
if (nn::parallel::global::GetPipelineParallelSize() == 1) {
311316
// https://paperswithcode.com/method/weight-tying
312317
*mutable_module(kTransformerLayerName)

example/llama3/main.cc

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -263,15 +263,8 @@ void Train(const nn::parallel::Rank &rank) {
263263
auto optimizer_creator = optimizers::Adam::Create(FLAGS_learning_rate);
264264
std::shared_ptr<Optimizer> optimizer = nullptr;
265265

266-
// Create optimizer - use GetLoRAParameters() if LoRA is enabled
267-
std::vector<std::shared_ptr<Tensor>> params_to_optimize;
268-
if (lora_enabled) {
269-
params_to_optimize = nn::lora::GetLoRAParameters(model);
270-
LOG(INFO) << "Optimizing " << params_to_optimize.size() << " LoRA parameters";
271-
} else {
272-
params_to_optimize = model->Parameters();
273-
LOG(INFO) << "Optimizing " << params_to_optimize.size() << " model parameters";
274-
}
266+
// Create optimizer - use TrainableParameters() as single source of truth
267+
std::vector<std::shared_ptr<Tensor>> params_to_optimize = model->TrainableParameters();
275268

276269
if (FLAGS_use_distributed_optimizer) {
277270
auto model_chunks = (pp_world_size > 1)

infini_train/include/nn/lora/lora_config.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <cmath>
34
#include <cstdint>
45
#include <string>
56
#include <unordered_set>
@@ -17,8 +18,8 @@ struct LoRAConfig {
1718
std::unordered_set<std::string> target_modules = {"c_attn", "c_proj"};
1819

1920
// Initialization parameters
20-
bool use_kaiming_a = true; // Use Kaiming init for A matrix
21-
float kaiming_a_param = 1.0f; // Parameter 'a' for Kaiming init
21+
bool use_kaiming_a = true; // Use Kaiming init for A matrix
22+
float kaiming_a_param = sqrtf(5.0f); // Parameter 'a' for Kaiming init
2223

2324
// Default constructor (uses default target_modules = {"c_attn", "c_proj"})
2425
LoRAConfig() = default;

infini_train/include/nn/lora/lora_linear.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,12 @@ class LoRALinear : public nn::CloneableModule<LoRALinear> {
4343
// Get only LoRA parameters (for optimizer)
4444
std::vector<std::shared_ptr<Tensor>> LoRAParameters() const;
4545

46-
// Override Parameters() to return only trainable (LoRA) parameters
46+
// Override Parameters() to return all parameters (frozen base + trainable LoRA)
4747
std::vector<std::shared_ptr<Tensor>> Parameters() const override;
4848

49+
// Get trainable parameters (requires_grad == true)
50+
std::vector<std::shared_ptr<Tensor>> TrainableParameters() const;
51+
4952
// Get all parameters including frozen base weights (for state dict)
5053
std::vector<std::shared_ptr<Tensor>> AllParameters() const;
5154

0 commit comments

Comments
 (0)