 #include <vector>

 #include "infini_train/include/nn/lora/lora_config.h"
-#include "infini_train/include/nn/modules/module.h"
+#include "infini_train/include/nn/parallel/tensor_parallel.h"

 namespace infini_train {
 class Tensor;
@@ -18,21 +18,19 @@ namespace infini_train::nn::lora {
 // LoRA A: [rank, in_features] - replicated across TP ranks (implemented as Linear)
 // LoRA B: [out_features_per_partition, rank] - sharded like base weight (implemented as ColumnParallelLinear with
 // gather_output)
-class LoRAColumnParallelLinear : public nn::CloneableModule<LoRAColumnParallelLinear> {
+class LoRAColumnParallelLinear : public nn::parallel::ColumnParallelLinear {
 public:
     static constexpr char kType[] = "LoRAColumnParallelLinear";

-    static constexpr char kParamWeightName[] = "weight";
-    static constexpr char kParamBiasName[] = "bias";
     static constexpr char kParamLoraAName[] = "lora_A";
     static constexpr char kParamLoraBName[] = "lora_B";

     // Constructor wrapping existing ColumnParallelLinear
-    LoRAColumnParallelLinear(std::shared_ptr<nn::Module> base_module, const LoRAConfig &config, int64_t in_features,
-                             int64_t out_features);
+    LoRAColumnParallelLinear(std::shared_ptr<parallel::ColumnParallelLinear> base_module, const LoRAConfig &config,
+                             int64_t in_features, int64_t out_features);

     // Constructor wrapping existing ColumnParallelLinear (auto-infer dimensions from weight)
-    LoRAColumnParallelLinear(std::shared_ptr<nn::Module> base_module, const LoRAConfig &config);
+    LoRAColumnParallelLinear(std::shared_ptr<parallel::ColumnParallelLinear> base_module, const LoRAConfig &config);

     std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;

@@ -41,9 +39,6 @@ class LoRAColumnParallelLinear : public nn::CloneableModule<LoRAColumnParallelLi
     bool IsMerged() const;

     std::vector<std::shared_ptr<Tensor>> LoRAParameters() const;
-    std::vector<std::shared_ptr<Tensor>> Parameters() const override;
-    std::vector<std::shared_ptr<Tensor>> TrainableParameters() const;
-    std::vector<std::shared_ptr<Tensor>> AllParameters() const;

     int64_t in_features() const;
     int64_t out_features() const;
@@ -54,39 +49,28 @@ class LoRAColumnParallelLinear : public nn::CloneableModule<LoRAColumnParallelLi
     void FreezeBaseWeights();

     LoRAConfig config_;
-    int64_t in_features_;
-    int64_t out_features_;
-    int64_t out_features_per_partition_;
-    bool bias_;
-    bool gather_output_;
-    bool input_is_parallel_;
-    bool skip_bias_add_;
-    bool sequence_parallel_;
+    int64_t in_features_ = 0;
+    int64_t out_features_ = 0;
+    int64_t out_features_per_partition_ = 0;
     bool merged_ = false;
-
-    std::shared_ptr<Tensor> original_weight_;
-    std::shared_ptr<nn::Module> base_module_; // Not registered in modules_ to avoid double-counting
 };
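
A minimal sketch of the forward path this column-parallel wrapper implies, now that it inherits from ColumnParallelLinear instead of owning an unregistered base_module_. The member names lora_A_, lora_B_, and scaling_, and the Add/Mul tensor helpers, are illustrative assumptions, none of which appear in this header:

// Hypothetical Forward body; lora_A_, lora_B_, scaling_, Add, and Mul are
// illustrative names, not confirmed infini_train APIs.
std::vector<std::shared_ptr<Tensor>>
LoRAColumnParallelLinear::Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
    // Base path: the inherited ColumnParallelLinear produces this rank's
    // output shard (gathered across ranks if gather_output is set).
    auto output = ColumnParallelLinear::Forward(input_tensors);
    if (merged_) {
        return output; // LoRA delta already folded into the base weight
    }
    // LoRA A is replicated, so every TP rank computes the same [*, rank]
    // activation; LoRA B is sharded like the base weight
    // ([out_features_per_partition, rank]) with matching gather_output,
    // so its output aligns elementwise with the base output.
    auto mid = lora_A_->Forward(input_tensors);
    auto delta = lora_B_->Forward(mid);
    output[0] = output[0]->Add(delta[0]->Mul(scaling_)); // scaling_ = alpha / rank
    return output;
}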

 // LoRA wrapper for RowParallelLinear
 // Weight shape: [out_features, in_features_per_partition]
 // LoRA A: [rank, in_features_per_partition] - sharded like base weight (implemented as RowParallelLinear with
 // input_is_parallel) LoRA B: [out_features, rank] - replicated (implemented as Linear)
-class LoRARowParallelLinear : public nn::CloneableModule<LoRARowParallelLinear> {
+class LoRARowParallelLinear : public nn::parallel::RowParallelLinear {
 public:
     static constexpr char kType[] = "LoRARowParallelLinear";
-
-    static constexpr char kParamWeightName[] = "weight";
-    static constexpr char kParamBiasName[] = "bias";
     static constexpr char kParamLoraAName[] = "lora_A";
     static constexpr char kParamLoraBName[] = "lora_B";

     // Constructor wrapping existing RowParallelLinear
-    LoRARowParallelLinear(std::shared_ptr<nn::Module> base_module, const LoRAConfig &config, int64_t in_features,
-                          int64_t out_features);
+    LoRARowParallelLinear(std::shared_ptr<parallel::RowParallelLinear> base_module, const LoRAConfig &config,
+                          int64_t in_features, int64_t out_features);

     // Constructor wrapping existing RowParallelLinear (auto-infer dimensions from weight)
-    LoRARowParallelLinear(std::shared_ptr<nn::Module> base_module, const LoRAConfig &config);
+    LoRARowParallelLinear(std::shared_ptr<parallel::RowParallelLinear> base_module, const LoRAConfig &config);

     std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;

@@ -95,9 +79,6 @@ class LoRARowParallelLinear : public nn::CloneableModule<LoRARowParallelLinear>
     bool IsMerged() const;

     std::vector<std::shared_ptr<Tensor>> LoRAParameters() const;
-    std::vector<std::shared_ptr<Tensor>> Parameters() const override;
-    std::vector<std::shared_ptr<Tensor>> TrainableParameters() const;
-    std::vector<std::shared_ptr<Tensor>> AllParameters() const;

     int64_t in_features() const;
     int64_t out_features() const;
@@ -108,18 +89,10 @@ class LoRARowParallelLinear : public nn::CloneableModule<LoRARowParallelLinear>
     void FreezeBaseWeights();

     LoRAConfig config_;
-    int64_t in_features_;
-    int64_t out_features_;
-    int64_t in_features_per_partition_;
-    bool bias_;
-    bool reduce_output_;
-    bool input_is_parallel_;
-    bool skip_bias_add_;
-    bool sequence_parallel_;
+    int64_t in_features_ = 0;
+    int64_t out_features_ = 0;
+    int64_t in_features_per_partition_ = 0;
     bool merged_ = false;
-
-    std::shared_ptr<Tensor> original_weight_;
-    std::shared_ptr<nn::Module> base_module_; // Not registered in modules_ to avoid double-counting
 };
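
The row-parallel case mirrors the sharding, per the comment above the class: A is sharded over the input dimension and B is replicated. A sketch under the same assumptions (hypothetical lora_A_, lora_B_, scaling_, and Add/Mul helpers):

// Hypothetical Forward body for the row-parallel wrapper.
std::vector<std::shared_ptr<Tensor>>
LoRARowParallelLinear::Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
    // Base path: the inherited RowParallelLinear consumes the sharded input
    // and reduces the partial products into the full output.
    auto output = RowParallelLinear::Forward(input_tensors);
    if (merged_) {
        return output;
    }
    // LoRA A is itself row-parallel with input_is_parallel: each rank
    // multiplies its input shard by its [rank, in_features_per_partition]
    // slice, and the reduction yields the full [*, rank] activation.
    auto mid = lora_A_->Forward(input_tensors);
    // LoRA B is a replicated Linear, so every rank adds the same
    // [*, out_features] delta to the already-reduced base output.
    auto delta = lora_B_->Forward(mid);
    output[0] = output[0]->Add(delta[0]->Mul(scaling_));
    return output;
}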

 } // namespace infini_train::nn::lora
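
Typical call pattern implied by this interface, as a hedged sketch: only the constructors, IsMerged(), and LoRAParameters() are visible in this diff, so the base-layer construction arguments, the LoRAConfig fields, and the MergeWeights() method are assumptions.

// Hypothetical usage; MergeWeights() is an assumed name, inferred from the
// IsMerged()/merged_ pair declared in the header above.
auto base = std::make_shared<nn::parallel::ColumnParallelLinear>(/* ... */);
LoRAConfig config; // rank, alpha, etc., as defined in lora_config.h
// Two-argument constructor: dimensions are inferred from the wrapped weight.
auto lora = std::make_shared<nn::lora::LoRAColumnParallelLinear>(base, config);
// Hand only the adapters (lora_A / lora_B) to the optimizer; the base
// weights stay frozen.
auto trainable = lora->LoRAParameters();
// Before export, the scaled delta B * A can be folded into the base shard
// so inference runs at plain ColumnParallelLinear cost.
lora->MergeWeights(); // hypothetical; afterwards IsMerged() reports true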