-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathrun_fusion.py
More file actions
104 lines (88 loc) · 4.15 KB
/
Copy pathrun_fusion.py
File metadata and controls
104 lines (88 loc) · 4.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import argparse
import json
import os
from config_utils import DEFAULT_TASK_MAP, get_dataset_config, restore_path
from itrain import ModelArguments, RunArguments, Setup
# Root output directories for fusion runs. run_fusion writes under the second one when
# --finetune_adapters is set, so the two kinds of experiments never share result files.
FUSION_OUTPUT_DIR, FUSION_FINETUNE_OUTPUT_DIR = "fusion_output", "fusion_finetune_output"
def _load_results(results_file):
    """Return the results dict stored in *results_file*, or an empty dict if the file is missing."""
    if not os.path.exists(results_file):
        return {}
    with open(results_file, "r", encoding="utf-8") as f:
        return json.load(f)


def _save_results(results_file, results):
    """Write *results* to *results_file* as JSON without risking a half-written file.

    The parent directory is created if necessary. The data is first written to a
    temporary sibling file and then moved over the target with ``os.replace`` (atomic on
    POSIX when both paths are on the same filesystem). An interrupted write therefore
    cannot truncate the results accumulated by earlier runs.
    """
    parent_dir = os.path.dirname(results_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    # per-process name so two concurrent writers never share one temporary file
    tmp_file = f"{results_file}.{os.getpid()}.tmp"
    try:
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(results, f)
        os.replace(tmp_file, results_file)
    except BaseException:
        # don't leave a stray temporary file behind if serialization or the move failed
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
        raise


def _merge_run_results(existing, run_results):
    """Extend *existing* in place with *run_results*; keys not present yet are added as-is."""
    for key, value in run_results.items():
        if key in existing:
            existing[key] += value
        else:
            existing[key] = value


def run_fusion(args, config_name=None):
    """Train an adapter-fusion layer over ``args["source_tasks"]`` on ``args["target_task"]``.

    Results are stored in ``<output_base>/eval_results.json``, keyed by the comma-joined
    source task names. ``<output_base>`` is ``fusion_output/to_<target>`` or, with
    ``finetune_adapters``, ``fusion_finetune_output/to_<target>``.

    Args:
        args: dict of run options. It is read for ``target_task``, ``train_size``,
            ``finetune_adapters``, ``learning_rate``, ``num_train_epochs``, ``task_map``
            (path to a JSON file; ``~`` is expanded), ``source_tasks`` (list of task
            names), ``overwrite_mode``, ``id`` and ``restarts``.
            ``overwrite_mode`` values:
              * 0 skips the run if its output directory already exists.
              * 1 appends new runs to an existing result entry.
              * 2 replaces the entry.
        config_name: optional run config name; defaults to
            ``fusion_<sources>_to_<target>``.

    Returns:
        None. The function returns early, without training, when ``overwrite_mode`` is 0
        and the output directory already exists.
    """
    # init setup
    dataset_manager, config = get_dataset_config(args["target_task"], train_size=args["train_size"])
    if args["train_size"] > 0:
        target_task_name = args["target_task"] + "_n" + str(args["train_size"])
    else:
        target_task_name = args["target_task"]
    output_root = FUSION_FINETUNE_OUTPUT_DIR if args["finetune_adapters"] else FUSION_OUTPUT_DIR
    output_base = os.path.join(output_root, "to_" + target_task_name)
    # patch training args
    config["training"]["learning_rate"] = args["learning_rate"]
    config["training"]["num_train_epochs"] = args["num_train_epochs"]
    # load results of earlier runs for this target, if any
    final_results_file = os.path.join(output_base, "eval_results.json")
    results = _load_results(final_results_file)
    with open(os.path.expanduser(args["task_map"]), "r", encoding="utf-8") as f:
        task_map = json.load(f)

    fusion_name = ",".join(args["source_tasks"])
    print(f"*** Running fusion from {fusion_name} to {target_task_name} ***")
    output_dir = os.path.join(output_base, fusion_name)
    # skip this iteration if no overwrites requested & existing
    if args["overwrite_mode"] == 0 and os.path.exists(output_dir):
        print(f"Skipping task {fusion_name} as it already exists.")
        return

    # setup the dataset and training params
    setup = Setup(id=args["id"])
    setup.dataset(dataset_manager)
    config["training"]["output_dir"] = output_dir
    setup.training(RunArguments(**config["training"]))
    if isinstance(config["evaluation"], str):
        setup.evaluation(split=config["evaluation"])
    else:
        setup.evaluation()
    setup.notify(config["notify"])
    setup._config_name = config_name or "fusion_" + fusion_name + "_to_" + target_task_name

    # setup model: resolve the trained adapter of every source task via the task map
    load_adapters_map = {}
    for source_task in args["source_tasks"]:
        source_dataset_manager, _ = get_dataset_config(source_task)
        load_adapters_map[source_task] = restore_path(task_map, source_task, source_dataset_manager)
    # patch model args
    config["model"]["load_adapters"] = load_adapters_map
    config["model"]["train_adapter_fusion"] = fusion_name
    config["model"]["drop_last_fusion_layer"] = True
    config["model"]["train_adapter"] = args["finetune_adapters"]
    setup.model(ModelArguments(**config["model"]))

    # start!
    if fusion_name in results and args["overwrite_mode"] == 1:
        # append mode: number the new runs after the ones already recorded
        run_results = setup.run(restarts=args["restarts"], first_run_index=len(results[fusion_name]["seeds"]))
        _merge_run_results(results[fusion_name], run_results)
    else:
        results[fusion_name] = setup.run(restarts=args["restarts"])
    # save results
    _save_results(final_results_file, results)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("target_task", type=str, help="Name of the target task training setup.")
parser.add_argument("--id", type=int, default=0, help="ID of this run.")
parser.add_argument("--task_map", type=str, default=DEFAULT_TASK_MAP)
parser.add_argument(
"--overwrite_mode", type=int, choices=[0, 1, 2], default=0, help="0: no overwrite; 1: append; 2: overwrite"
)
parser.add_argument("--source_tasks", type=lambda s: s.split(","), required=True)
parser.add_argument("--learning_rate", type=float, default=5e-5)
parser.add_argument("--num_train_epochs", type=int, default=15)
parser.add_argument("--train_size", type=int, default=-1)
parser.add_argument("--restarts", type=int, default=None)
parser.add_argument("--finetune_adapters", action="store_true")
args = vars(parser.parse_args())
run_fusion(args)