main.py (forked from ziwliu8/LLM-EDT)

# imports: standard library, third-party, and project modules
import os
import argparse

import torch
import setproctitle

from generators.generator import CDSRRegSeq2SeqGeneratorUser
from generators.aug_generator import Aug_CDSRRegSeq2SeqGeneratorUser
from trainers.cdsr_trainer import CDSRTrainer
#from trainers.cold_cdsr_trainer import CDSRTrainer
from trainers.domain_adapter_trainer import DomainAdapterTrainer
from trainers.lora_trainer import LoRATrainer
from utils.utils import set_seed
from utils.logger import Logger
from utils.argument import *

# name the process so it is easy to spot in ps / nvidia-smi
setproctitle.setproctitle("LLM4CDSR")

parser = argparse.ArgumentParser()
parser = get_main_arguments(parser)
parser = get_model_arguments(parser)
parser = get_train_arguments(parser)

parser.add_argument('--augmented', action='store_true',
                    help='Whether to use augmented generator')
parser.add_argument('--domain_beta', type=float, default=2.0,
                    help='domain_beta')
parser.add_argument('--do_peft', action='store_true',
                    help='Whether to use PEFT')
parser.add_argument('--peft_type', type=str, default='adapter',
                    choices=['prompt', 'adapter', 'lora', 'interest_transfer'],
                    help='Type of PEFT to use')
parser.add_argument('--adapter_size', type=int, default=64,
                    help='Size of adapter hidden layer')
parser.add_argument('--peft_epochs', type=int, default=100,
                    help='Number of epochs for PEFT training')
parser.add_argument('--peft_lr', type=float, default=1e-4,
                    help='Learning rate for PEFT training')
parser.add_argument('--lora_rank', type=int, default=64,
                    help='Rank for LoRA decomposition')
parser.add_argument('--lora_alpha', type=int, default=16,
                    help='Alpha scaling factor for LoRA')
parser.add_argument('--lora_dropout', type=float, default=0.1,
                    help='Dropout rate for LoRA layers')
parser.add_argument('--lora_lr', type=float, default=1e-4,
                    help='Learning rate for LoRA training')
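
# Note: in the standard LoRA parametrization the low-rank update is scaled
# by alpha / rank (16 / 64 = 0.25 with the defaults above); whether
# trainers/lora_trainer.py follows that convention is not shown in this file.
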
parser.add_argument('--finetune_domain', type=str, default='0', choices=['0', '1'],
                    help='Target domain for PEFT')
parser.add_argument('--pretrain_path', type=str, default='One4All',
                    help='Path to pretrained model')
parser.add_argument('--hard_negative_weight', type=float, default=0.1,
                    help='Weight for hard negative samples')
# NOTE: argparse's type=bool treats any non-empty string as True, so a
# boolean action (Python 3.9+) is used to get a real on/off flag
parser.add_argument('--adaptive_weight', action=argparse.BooleanOptionalAction, default=True,
                    help='Whether to use adaptive weighting')
parser.add_argument('--l2_weight', type=float, default=0.01,
                    help='Weight for L2 regularization')
parser.add_argument('--align_weight', type=float, default=0.1,
                    help='Weight for the alignment loss')
parser.add_argument('--kd_weight', type=float, default=0.1,
                    help='Weight for knowledge distillation')
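
# How these weights are consumed lives in the trainers; a plausible sketch
# (an assumption, not code from this repo) is a combined objective like
#   loss = rec_loss + hard_negative_weight * hn_loss + l2_weight * l2_reg
#          + align_weight * align_loss + kd_weight * kd_loss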

args = parser.parse_args()
set_seed(args.seed)  # fix the random seed

# anomaly detection traces NaN/Inf gradients back to the op that produced
# them, at the cost of slower training
torch.autograd.set_detect_anomaly(True)

# resolve output / checkpoint / embedding paths from the parsed arguments
args.output_dir = os.path.join(args.output_dir, args.dataset)
args.pretrain_path = os.path.join(args.output_dir, args.pretrain_path)
args.output_dir = os.path.join(args.output_dir, args.model_name)
args.output_dir = os.path.join(args.output_dir, args.check_path)  # if check_path is empty, output_dir is effectively unchanged
args.llm_emb_path = os.path.join("data", args.dataset, "handled", "{}.pkl".format(args.llm_emb_file))
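
# For illustration, with the hypothetical values --output_dir save
# --dataset amazon --model_name One4All --check_path run1 --llm_emb_file
# itm_emb (none of these are defaults from this repo; the flags themselves
# are assumed to come from the get_*_arguments helpers), the lines above
# resolve to:
#   args.pretrain_path = "save/amazon/One4All"
#   args.output_dir    = "save/amazon/One4All/run1"
#   args.llm_emb_path  = "data/amazon/handled/itm_emb.pkl"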


def main():
    torch.cuda.empty_cache()

    log_manager = Logger(args)  # initialize the log manager
    logger, writer = log_manager.get_logger()  # get the logger
    args.now_str = log_manager.get_now_str()

    device = torch.device("cuda:" + str(args.gpu_id)
                          if torch.cuda.is_available() and not args.no_cuda
                          else "cpu")
    os.makedirs(args.output_dir, exist_ok=True)
    # the generator builds and manages the dataset / dataloaders
    if args.model_name in ["llm4cdsr", "One4All", "C2DSR", "One4AllAttentionOnly",
                           "One4AllEmbeddingOnly", "DomainSpecificAdapter"]:
        if args.augmented:
            generator = Aug_CDSRRegSeq2SeqGeneratorUser(args, logger, device)
        else:
            generator = CDSRRegSeq2SeqGeneratorUser(args, logger, device)
    else:
        raise ValueError("Unsupported model_name: {}".format(args.model_name))
    if args.model_name in ["llm4cdsr", "One4All", "C2DSR", "One4AllAttentionOnly",
                           "One4AllEmbeddingOnly"]:
        # only LoRA has a dedicated trainer; every other PEFT type, and the
        # non-PEFT path, falls back to CDSRTrainer
        if args.do_peft and args.peft_type == 'lora':
            trainer = LoRATrainer(args, logger, writer, device, generator)
        else:
            trainer = CDSRTrainer(args, logger, writer, device, generator)
    elif args.model_name == "DomainSpecificAdapter":
        trainer = DomainAdapterTrainer(args, logger, writer, device, generator)
    else:
        raise ValueError("Unsupported model_name: {}".format(args.model_name))
    # dispatch on the run-mode flags
    if args.do_test:
        trainer.test()
    elif args.do_emb:
        trainer.save_item_emb()
    elif args.do_group:
        trainer.test_group()
    else:
        trainer.train()

    log_manager.end_log()  # shut down the logger threads


if __name__ == "__main__":
    main()
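
# Illustrative invocations (a sketch: the dataset and flag values below are
# hypothetical, and --dataset, --model_name, and --do_test are assumed to be
# registered by the get_*_arguments helpers, matching the attribute names
# used above):
#   python main.py --dataset amazon --model_name One4All
#   python main.py --dataset amazon --model_name One4All --do_peft \
#       --peft_type lora --lora_rank 64 --lora_alpha 16 --finetune_domain 0
#   python main.py --dataset amazon --model_name One4All --do_test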