-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy path test_trainers.py
More file actions
32 lines (30 loc) · 987 Bytes
/
test_trainers.py
File metadata and controls
32 lines (30 loc) · 987 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from dicee.executer import Execute
from dicee.config import Namespace
import pytest
class TestCallback:
    """Smoke tests that run AConEx (KvsAll, KGs/UMLS) under two trainers.

    Each test only checks that ``Execute(args).start()`` completes without
    raising; no metrics or outputs are asserted.
    """

    @staticmethod
    def _make_args(trainer):
        """Build the configuration shared by every test in this class.

        Args:
            trainer: Value assigned to ``args.trainer``
                (``'torchCPUTrainer'`` or ``'PL'`` in this class).

        Returns:
            A ``dicee.config.Namespace`` with the model, dataset and
            hyper-parameters used by these tests.
        """
        args = Namespace()
        args.model = 'AConEx'
        # The original tests assigned num_epochs twice (1, then 10); only
        # the later value of 10 ever took effect, so it is the one kept.
        args.num_epochs = 10
        args.scoring_technique = 'KvsAll'
        args.dataset_dir = 'KGs/UMLS'
        args.batch_size = 1024
        args.lr = 0.01
        args.embedding_dim = 32
        args.trainer = trainer
        return args

    # NOTE(review): the name says "conex" but the model is AConEx; the name is
    # kept as-is so the pytest test ID does not change.
    @pytest.mark.filterwarnings('ignore::UserWarning')
    def test_conex_torch_cpu_trainer(self):
        """Run AConEx with the ``torchCPUTrainer`` trainer."""
        Execute(self._make_args('torchCPUTrainer')).start()

    @pytest.mark.filterwarnings('ignore::UserWarning')
    def test_aconex_pl_trainer(self):
        """Run AConEx with the ``PL`` trainer."""
        Execute(self._make_args('PL')).start()