-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathtest_pykeen.py
More file actions
68 lines (63 loc) · 2.78 KB
/
test_pykeen.py
File metadata and controls
68 lines (63 loc) · 2.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from dicee.executer import Execute
import sys
import pytest
from dicee.config import Namespace
def template(model_name):
    """Build the baseline configuration shared by the PyKEEN test cases.

    Parameters
    ----------
    model_name : str
        Model identifier (e.g. ``"Pykeen_DistMult"``); stored as ``model``.

    Returns
    -------
    Namespace
        A fresh ``dicee.config.Namespace`` with the dataset, trainer,
        optimisation and data-loading options set below.
    """
    settings = {
        "dataset_dir": "KGs/UMLS",
        "trainer": "PL",
        "model": model_name,
        "num_epochs": 20,
        "batch_size": 256,
        "lr": 0.1,
        "num_workers": 1,
        "num_core": 1,
        "scoring_technique": "KvsAll",
        # Explicitly left as None; presumably this turns off triple
        # sub-sampling, partial reads and cross-validation — confirm against
        # dicee.config.
        "sample_triples_ratio": None,
        "read_only_few": None,
        "num_folds_for_cv": None,
    }
    config = Namespace()
    # Apply in insertion order so the attribute assignment sequence is fixed.
    for option, value in settings.items():
        setattr(config, option, value)
    return config
@pytest.mark.parametrize("model_name", ["Pykeen_DistMult", "Pykeen_ComplEx", "Pykeen_HolE", "Pykeen_CP",
                                        "Pykeen_ProjE", "Pykeen_TuckER", "Pykeen_TransR", "Pykeen_TransH",
                                        "Pykeen_TransD", "Pykeen_TransE", "Pykeen_QuatE", "Pykeen_MuRE",
                                        "Pykeen_BoxE", "Pykeen_RotatE", "Pykeen_TransF"])
class TestClass:
    """Train each PyKEEN-backed model through ``Execute`` using ``template()``.

    The class-level ``parametrize`` decorator runs every test method once per
    model name listed above.
    """

    # Expected minimum train-split MRR per model. Only enforced when
    # CHECK_TRAIN_MRR is True. Trailing comments are the parameter counts
    # recorded next to the thresholds.
    MIN_TRAIN_MRR = {
        "Pykeen_DistMult": 0.78,
        "Pykeen_ComplEx": 0.76,
        "Pykeen_QuatE": 0.83,
        "Pykeen_MuRE": 0.84,
        "Pykeen_BoxE": 0.77,
        "Pykeen_RotatE": 0.59,
        "Pykeen_CP": 0.97,  # 1.5M params
        "Pykeen_HolE": 0.87,  # 14.k params
        "Pykeen_ProjE": 0.77,  # 14.k params
        "Pykeen_TuckER": 0.30,  # 276.k params
        "Pykeen_TransR": 0.45,  # 188.k params
        "Pykeen_TransF": 0.13,  # 14.5 k params
        "Pykeen_TransH": 0.37,  # 20.4 k params
        "Pykeen_TransD": 0.31,  # 29.1 k params
        "Pykeen_TransE": 0.06,  # 29.1 k params
    }

    # Off by default so the suite only checks that training completes; set to
    # True to enforce MIN_TRAIN_MRR. NOTE(review): the thresholds may need
    # re-validating against the current settings in template().
    CHECK_TRAIN_MRR = False

    def test_defaultParameters_case(self, model_name):
        """Train ``model_name`` with the default template configuration."""
        args = template(model_name)
        result = Execute(args).start()
        if self.CHECK_TRAIN_MRR:
            # Assumes result["Train"]["MRR"] is the train-split MRR, as the
            # threshold table expects — confirm against Execute.start().
            assert result["Train"]["MRR"] >= self.MIN_TRAIN_MRR[model_name]

    def test_perturb_callback_case(self, model_name):
        """Train ``model_name`` with the ``Perturb`` callback enabled."""
        args = template(model_name)
        args.callbacks = {"Perturb": {"level": "out", "ratio": 0.2, "method": "Soft", "scaler": 0.3}}
        Execute(args).start()