-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcensusincome_main.py
More file actions
86 lines (72 loc) · 2.51 KB
/
censusincome_main.py
File metadata and controls
86 lines (72 loc) · 2.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import argparse
import random

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from config import CensusIncome_Vocabulary_Size
from multitaskrec.dataset import CensusIncomeDataset
from multitaskrec.model import MPTRec
from multitaskrec.train import MPTRecTrainManager
def main(args):
    """Train MPTRec on Census-Income and print the test AUC of both tasks.

    Steps: seed the RNGs, load the train/test files, split the test file
    50/50 into validation and test sets, build a two-task MPTRec model, train
    it with ``MPTRecTrainManager``, reload ``train_manager.best_weight`` and
    print the test AUC for the "Income" and "Marital" tasks.

    Args:
        args: namespace with ``seed`` (int RNG seed), ``gpu`` (int CUDA
            device index) and ``way`` (str, passed unchanged to
            ``train_manager.train`` and ``train_manager.evaluation``).
    """
    # One value shared by every DataLoader and the train manager so the
    # settings cannot silently drift apart.
    batch_size = 256

    # Seed every RNG the script may draw from, for reproducible runs.
    # torch.cuda.manual_seed_all also covers the current device, so a
    # separate torch.cuda.manual_seed call is unnecessary.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Load data; the test file is split in half (seeded) into validation and
    # test sets.
    train_dataset = CensusIncomeDataset("dataset/Census-income/train.gz")
    test_dataset = CensusIncomeDataset("dataset/Census-income/test.gz")
    val_dataset, test_dataset = train_test_split(
        test_dataset, test_size=0.5, random_state=args.seed
    )

    # One random environment id in {0, 1} per training sample.
    # NOTE(review): presumably the initial environment assignment that the
    # manager re-clusters every `clustering_interval` epochs -- confirm.
    env_ids = torch.randint(0, 2, size=(len(train_dataset),))

    # NOTE(review): the training loader is not shuffled, presumably so that
    # batch order stays aligned with env_ids -- confirm before enabling it.
    train_loader = DataLoader(train_dataset, batch_size=batch_size)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Build the two-task model.
    # NOTE(review): input_size=127 is a magic constant; presumably it must
    # match the concatenated feature/embedding width -- verify if the
    # vocabulary or embedding_size changes.
    model = MPTRec(
        num_tasks=2,
        feature_vocabulary=CensusIncome_Vocabulary_Size,
        embedding_size=4,
        input_size=127,
        expert_dnn_hidden_units=[256, 128],
        tower_dnn_hidden_units=[64, 32],
        reg_embedding=0.006,
        reg_dnn=3e-5,
    )

    # Without a fallback, model.to(...) raises on machines lacking CUDA.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{args.gpu}")
    else:
        # NOTE(review): confirm MPTRecTrainManager does not hard-code CUDA.
        print(f"CUDA is not available; using CPU instead of cuda:{args.gpu}.")
        device = torch.device("cpu")
    model.to(device)

    # Build the train manager; task_name order is assumed to match the order
    # of the model's task outputs.
    train_manager = MPTRecTrainManager(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        env_ids=env_ids,
        task_name=["Income", "Marital"],
        lr=1e-3,
        batch_size=batch_size,
        epochs=10,
        patience=5,
        gen_coe=0.9,
        env_coe=0.1,
        clustering_interval=2,
    )

    # Report model cost (parameter count and floating-point operations).
    train_manager.compute_cost()

    # Train, then restore the best weights recorded by the manager before
    # testing.
    train_manager.train(way=args.way)
    model.load_state_dict(train_manager.best_weight)

    # Assumes auc_test is indexed in task_name order (Income, Marital).
    auc_test = train_manager.evaluation(test_loader, way=args.way)
    print(
        f"AUC-Test-Income:{auc_test[0]:.4f}, AUC-Test-Marital:{auc_test[1]:.4f}"
    )
def build_parser():
    """Return the command-line parser for this training script.

    Options:
        --seed  int RNG seed passed to main() (default 1000).
        --gpu   int CUDA device index (default 0).
        --way   str mode passed to main() (default "all").
    """
    parser = argparse.ArgumentParser(
        description="Train and evaluate MPTRec on the Census-Income dataset."
    )
    parser.add_argument(
        "--seed", type=int, default=1000, help="random seed (default: 1000)"
    )
    parser.add_argument(
        "--gpu", type=int, default=0, help="CUDA device index (default: 0)"
    )
    parser.add_argument(
        "--way",
        type=str,
        default="all",
        help="mode forwarded to training and evaluation (default: all)",
    )
    return parser


if __name__ == "__main__":
    main(build_parser().parse_args())