
Commit c614ec6 (parent: 98dc840)

fix: broadcast lora_A init from TP rank 0 to ensure consistent replicated weights

3 files changed: 53 additions & 23 deletions

infini_train/src/nn/lora/lora_parallel_linear.cc
Lines changed: 26 additions & 9 deletions

@@ -11,6 +11,7 @@
 #include "infini_train/include/nn/init.h"
 #include "infini_train/include/nn/modules/linear.h"
 #include "infini_train/include/nn/parallel/global.h"
+#include "infini_train/include/nn/parallel/process_group.h"
 #include "infini_train/include/nn/parallel/tensor_parallel.h"
 #include "infini_train/include/nn/parallel/utils.h"
 #include "infini_train/include/tensor.h"
@@ -89,22 +90,38 @@ LoRAColumnParallelLinear::LoRAColumnParallelLinear(std::shared_ptr<parallel::Col
 }
 
 void LoRAColumnParallelLinear::InitLoRAWeights() {
-    // LoRA weights stored directly in parameters_
-    // Following PEFT pattern conceptually:
-    // lora_A: [rank, in_features] - replicated
+    // lora_A: [rank, in_features] - replicated across TP ranks
     // lora_B: [out_features_per_partition, rank] - sharded like base weight
-
-    // lora_A: [rank, in_features]
     parameters_[kParamLoraAName]
         = std::make_shared<Tensor>(std::vector<int64_t>{config_.rank, in_features_}, DataType::kFLOAT32, device_)
               ->RequiresGrad();
-    if (config_.use_kaiming_a) {
-        init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);
+
+    if (parallel::global::GetTensorParallelSize() > 1) {
+        const auto global_rank = device_.Rank().GlobalRank();
+        auto *tp_group = parallel::ProcessGroupFactory::Instance(device_.type())
+                             ->Get(parallel::GetTensorParallelProcessGroupName(global_rank));
+        const int tp_rank = tp_group->GetGroupRank(global_rank);
+
+        // Only TP rank 0 generates random values; others zero-init.
+        // AllReduce(sum) then broadcasts rank-0's values to all TP ranks.
+        if (tp_rank == 0) {
+            if (config_.use_kaiming_a) {
+                init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);
+            } else {
+                init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f);
+            }
+        } else {
+            init::Zeros(parameters_[kParamLoraAName]);
+        }
+        tp_group->AllReduce(parameters_[kParamLoraAName]);
     } else {
-        init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f);
+        if (config_.use_kaiming_a) {
+            init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);
+        } else {
+            init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f);
+        }
    }
 
-    // lora_B: [out_per_partition, rank] - sharded like base weight
     parameters_[kParamLoraBName]
         = std::make_shared<Tensor>(std::vector<int64_t>{out_features_per_partition_, config_.rank}, DataType::kFLOAT32,
                                    device_)
scripts/run_models_and_profile.bash
Lines changed: 5 additions & 2 deletions

@@ -154,8 +154,9 @@ run_and_log() {
             > "$log_path"
     fi
 
-    # Write the current run command to the log
-    echo "[COMMAND] $cmd" >> "$log_path"
+    # Write the current run command to the log (expand $LORA_WEIGHTS_DIR)
+    local expanded_cmd="${cmd//\$LORA_WEIGHTS_DIR/$LORA_WEIGHTS_DIR}"
+    echo "[COMMAND] $expanded_cmd" >> "$log_path"
 
     # Run the command and append both stdout and stderr to the log file
    if ! eval "$cmd" >> "$log_path" 2>&1; then
@@ -267,10 +268,12 @@ for ((id=0; id<num_builds; ++id)); do
        arg_str="$(args_string_for_test "$gi" "$ti")"
 
        # gpt2
+       LORA_WEIGHTS_DIR="$GPT2_LORA_WEIGHTS_DIR"
        gpt2_cmd="${prefix}./gpt2 --input_bin ${GPT2_INPUT_BIN} --llmc_filepath ${GPT2_LLMC_FILEPATH} --device cuda ${arg_str}"
        run_and_log "$gpt2_cmd" "gpt2_${test_id}${log_suffix}" "$profile_flag" "$group_tag"
 
        # llama3
+       LORA_WEIGHTS_DIR="$LLAMA3_LORA_WEIGHTS_DIR"
        llama3_cmd="${prefix}./llama3 --input_bin ${LLAMA3_INPUT_BIN} --llmc_filepath ${LLAMA3_LLMC_FILEPATH} --device cuda ${arg_str}"
        run_and_log "$llama3_cmd" "llama3_${test_id}${log_suffix}" "$profile_flag" "$group_tag"
 done
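
The expansion `${cmd//\$LORA_WEIGHTS_DIR/$LORA_WEIGHTS_DIR}` is bash pattern substitution: every literal occurrence of the string `$LORA_WEIGHTS_DIR` inside `$cmd` is replaced with the variable's current value. This matters because the config below keeps `$LORA_WEIGHTS_DIR` as an unexpanded placeholder (presumably threaded into `$cmd` via `args_string_for_test`): `eval "$cmd"` resolves it at run time against whichever per-model value was just assigned, but a plain `echo "$cmd"` would log the placeholder instead of the actual weights path used by that run.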

scripts/test_config.json
Lines changed: 22 additions & 12 deletions

@@ -5,6 +5,8 @@
     "GPT2_LLMC_FILEPATH": "/data/shared/InfiniTrain-dev/data/llmc/gpt2/gpt2_124M.bin",
     "LLAMA3_INPUT_BIN": "/data/shared/InfiniTrain-dev/data/llmc/llama3/tinyshakespeare/tiny_shakespeare_train.bin",
     "LLAMA3_LLMC_FILEPATH": "/data/shared/InfiniTrain-dev/data/llmc/llama3/llama3.2_1B_fp32.bin",
+    "LLAMA3_LORA_WEIGHTS_DIR": "/data/shared/InfiniTrain-dev/data/llmc/llama3/llama3.2_1B_lora_weights_rank8_alpha16.bin",
+    "GPT2_LORA_WEIGHTS_DIR": "/data/shared/InfiniTrain-dev/data/llmc/gpt2/gpt2_124M_lora_weights_rank8_alpha16.bin",
     "PROFILE_LOG_DIR": "./profile_logs",
     "LOG_DIR": "./logs",
     "COMPARE_LOG_DIR": ""
@@ -313,7 +315,8 @@
             "dtype": "float32",
             "lora_rank": 8,
             "lora_alpha": 16.0,
-            "lora_target_modules": "c_attn,attn.c_proj"
+            "lora_target_modules": "c_attn,attn.c_proj",
+            "lora_load_path": "$LORA_WEIGHTS_DIR"
         }
     },
     {
@@ -322,7 +325,8 @@
             "dtype": "bfloat16",
             "lora_rank": 8,
             "lora_alpha": 16.0,
-            "lora_target_modules": "c_attn,attn.c_proj"
+            "lora_target_modules": "c_attn,attn.c_proj",
+            "lora_load_path": "$LORA_WEIGHTS_DIR"
         }
     },
     {
@@ -334,7 +338,8 @@
            "total_batch_size": 5120,
            "lora_rank": 8,
            "lora_alpha": 16.0,
-            "lora_target_modules": "c_attn,attn.c_proj"
+            "lora_target_modules": "c_attn,attn.c_proj",
+            "lora_load_path": "$LORA_WEIGHTS_DIR"
        }
    },
    {
@@ -346,7 +351,8 @@
            "total_batch_size": 5120,
            "lora_rank": 8,
            "lora_alpha": 16.0,
-            "lora_target_modules": "c_attn,attn.c_proj"
+            "lora_target_modules": "c_attn,attn.c_proj",
+            "lora_load_path": "$LORA_WEIGHTS_DIR"
        }
    },
    {
@@ -359,7 +365,8 @@
            "total_batch_size": 5120,
            "lora_rank": 8,
            "lora_alpha": 16.0,
-            "lora_target_modules": "c_attn,attn.c_proj"
+            "lora_target_modules": "c_attn,attn.c_proj",
+            "lora_load_path": "$LORA_WEIGHTS_DIR"
        }
    },
    {
@@ -372,7 +379,8 @@
            "total_batch_size": 5120,
            "lora_rank": 8,
            "lora_alpha": 16.0,
-            "lora_target_modules": "c_attn,attn.c_proj"
+            "lora_target_modules": "c_attn,attn.c_proj",
+            "lora_load_path": "$LORA_WEIGHTS_DIR"
        }
    },
    {
@@ -384,9 +392,10 @@
            "batch_size": 40,
            "total_batch_size": 5120,
            "tensor_parallel": 4,
-            "lora_rank": 4,
-            "lora_alpha": 8.0,
-            "lora_target_modules": "c_attn,c_fc,c_proj"
+            "lora_rank": 8,
+            "lora_alpha": 16.0,
+            "lora_target_modules": "c_attn,attn.c_proj",
+            "lora_load_path": "$LORA_WEIGHTS_DIR"
        }
    },
    {
@@ -398,9 +407,10 @@
            "batch_size": 40,
            "total_batch_size": 5120,
            "tensor_parallel": 4,
-            "lora_rank": 16,
-            "lora_alpha": 32.0,
-            "lora_target_modules": "c_attn,c_fc,c_proj"
+            "lora_rank": 8,
+            "lora_alpha": 16.0,
+            "lora_target_modules": "c_attn,attn.c_proj",
+            "lora_load_path": "$LORA_WEIGHTS_DIR"
        }
    },
    {
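
Note that the last two hunks also normalize the tensor-parallel tests from rank 4 / alpha 8 and rank 16 / alpha 32 down to rank 8 / alpha 16, with the same target modules as the other tests. This is presumably forced by the new `lora_load_path`: the pre-generated weights files are named `...rank8_alpha16.bin`, and loading fixed LoRA weights only makes sense when each test's rank and target modules match the saved tensors.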
