-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
88 lines (74 loc) · 2.42 KB
/
evaluate.py
File metadata and controls
88 lines (74 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""
This script was made by soeque1 at 24/07/20.
Evaluate a trained ReCoSa checkpoint on the Ubuntu validation split and log the test result plus BLEU-2/BLEU-4 scores.
"""
import logging
from argparse import ArgumentParser, Namespace
from logging import getLogger
import pytorch_lightning as pl
from src.core.build_data import Config
from src.data import UbuntuDataLoader, UbuntuDataSet, collate
from src.metric import bleuS_2, bleuS_4
from src.utils.prepare import build
from train import RecoSAPL
logger = getLogger(__name__)
def main(
    config_data_file: str,
    config_model_file: str,
    config_trainer_file: str,
    config_api_file: str,
    version: str,
    num_workers: int = 8,
) -> None:
    """Evaluate a ReCoSa checkpoint on the validation split and log metrics.

    Builds the configuration from the four config files, wraps the
    validation data (``cfg.dataset.raw.val``) in an ``UbuntuDataLoader``,
    restores ``RecoSAPL`` from ``cfg.api.model_path`` and runs
    ``trainer.test`` over it. The test result and BLEU-4 / BLEU-2 scores
    computed from ``model.pred`` and ``model.target`` are written to the
    module logger at INFO level.

    Args:
        config_data_file: Path to the dataset config file.
        config_model_file: Path to the model config file.
        config_trainer_file: Path to the trainer config file.
        config_api_file: Path to the API config file (provides the
            checkpoint path ``cfg.api.model_path``).
        version: Data version string forwarded to ``build``.
        num_workers: Worker processes used by the validation dataloader.
            Defaults to 8, the previously hard-coded value.
    """
    # TODO: to be removed
    build({"data_config": config_data_file, "version": version})

    cfg = Config()
    cfg.add_dataset(config_data_file)
    cfg.add_model(config_model_file)
    cfg.add_api(config_api_file)
    cfg.add_trainer(config_trainer_file)

    val_data = UbuntuDataSet(
        cfg.dataset.root + cfg.dataset.target,
        cfg.dataset.raw.val,
        cfg.model.max_seq,
        cfg.dataset.target,
        cfg.model.max_turns,
    )
    val_dataloader = UbuntuDataLoader(
        val_data,
        batch_size=cfg.model.batch_size,
        shuffle=False,  # deterministic order for evaluation
        num_workers=num_workers,
        collate_fn=collate,
    )

    model = RecoSAPL.load_from_checkpoint(
        checkpoint_path=cfg.api.model_path, config=cfg
    )

    cfg.trainer.pl.max_epochs = 1
    # Evaluation only: disable the experiment logger and checkpointing so a
    # test run does not write training artifacts.
    trainer = pl.Trainer(**cfg.trainer.pl, logger=False, checkpoint_callback=False)
    test_result = trainer.test(model, test_dataloaders=val_dataloader)
    logger.info("test result: %s", test_result)

    # NOTE(review): model.pred / model.target are presumably filled in by
    # RecoSAPL's test step during trainer.test -- confirm in train.py.
    bleu_score_4 = bleuS_4(model.pred, model.target)
    bleu_score_2 = bleuS_2(model.pred, model.target)
    # Label each score; the two bare numbers were previously indistinguishable.
    logger.info("BLEU-4: %s", bleu_score_4)
    logger.info("BLEU-2: %s", bleu_score_2)
if __name__ == "__main__":
    # main() reports everything through logger.info, but nothing configures
    # logging, so the root logger's default WARNING level would drop those
    # messages. basicConfig is a no-op if the root logger already has handlers.
    logging.basicConfig(level=logging.INFO)

    parser = ArgumentParser(description="Evaluate a ReCoSa checkpoint.")
    parser.add_argument(
        "--config_data_file",
        default="./conf/dataset/ubuntu.yml",
        type=str,
        help="dataset config file",
    )
    parser.add_argument(
        "--config_model_file",
        default="./conf/model/ReCoSa.yml",
        type=str,
        help="model config file",
    )
    parser.add_argument(
        "--config_trainer_file",
        default="./conf/trainer/ReCoSa.yml",
        type=str,
        help="trainer config file",
    )
    parser.add_argument(
        "--config_api_file",
        default="./conf/api/ReCoSa.yml",
        type=str,
        help="API config file (holds the checkpoint path)",
    )
    parser.add_argument(
        "--version",
        default="v0.0.8.1",
        type=str,
        help="data version passed to build()",
    )
    args = parser.parse_args()

    main(
        args.config_data_file,
        args.config_model_file,
        args.config_trainer_file,
        args.config_api_file,
        args.version,
    )