This repository was archived by the owner on May 13, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTester.py
More file actions
122 lines (97 loc) · 3.42 KB
/
Copy pathTester.py
File metadata and controls
122 lines (97 loc) · 3.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import json
import pandas as pd
from typing import List
import torch
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader
from ..metrics import MetricWrapper
from ..callbacks import CallbackWrapper
from ..utils import DotDict, get_services, Logger
from ..utils.runner_utils import ExportableState
from ..utils.BatchForwarder import BatchForwarder
from ..utils.runner_utils.tester import TesterState
# Public API of this module: only the Tester runner is exported.
__all__ = ["Tester"]
class Tester(object):
    """
    Run inference on a test dataset and compute metrics from the logged predictions.

    The forward pass is delegated to ``BatchForwarder``, configured per phase through
    ``config.Data[<phase>]``. ``compute_metrics`` then reads ``pred_result.csv`` from
    ``config.Global.log_path`` and feeds every logged row into the metric wrapper.
    """
    __config: DotDict
    __logger: Logger
    __metric: MetricWrapper
    __test_dataloader: DataLoader
    __callback: CallbackWrapper
    __model: Module

    def __init__(self,
                 config: DotDict,
                 model: Module,
                 metric: MetricWrapper,
                 dataloader: DataLoader
                 ) -> None:
        """
        :param config: experiment configuration; ``Global.log_path`` and ``Data[<phase>]``
            are read by this class, and it is also passed to ``get_services``.
        :param model: model to run inference with.
        :param metric: metric wrapper updated from the inferred results.
        :param dataloader: data loader over the test dataset.
        """
        from ..callbacks import CallbackWrapper  # tmp fix for cyclic import
        self.__config = config
        self.__model = model
        self.__metric = metric
        self.__test_dataloader = dataloader
        self.__logger = Logger("test")
        self.__callback = CallbackWrapper(self, get_services(config))
        # Callbacks implementing ExportableState are handed to the state object.
        self.state = TesterState(
            stateful_callbacks=[cb for cb in self.__callback.callback_lst
                                if isinstance(cb, ExportableState)]
        )
        self.__callback("on_init_end")

    @property
    def config(self) -> DotDict:
        """Configuration this tester was built with."""
        return self.__config

    @property
    def metric(self) -> MetricWrapper:
        """Metric wrapper updated by ``compute_metrics``."""
        return self.__metric

    @property
    def test_dataloader(self) -> DataLoader:
        """Data loader over the test dataset."""
        return self.__test_dataloader

    @property
    def callback(self) -> CallbackWrapper:
        """Callback dispatcher created in ``__init__``."""
        return self.__callback

    @property
    def model(self) -> Module:
        """Model used for inference."""
        return self.__model

    @property
    def logger(self) -> Logger:
        """Logger named ``"test"``."""
        return self.__logger

    def compute_metrics(self, delimiter: str = "; ", chunksize: int = 10 ** 6) -> None:
        """
        Compute metrics from the inferred results logged to ``pred_result.csv``.

        Each line of the file holds ``pred``, ``label`` and ``idx`` fields separated by
        ``delimiter``; ``pred`` and ``label`` are JSON-encoded values. The file is read
        ``chunksize`` rows at a time, so large result files need not be loaded at once.
        The final result is stored in ``self.state.metric_result``.

        :param delimiter: field separator used in the result file. With the python
            engine, separators longer than one character are treated as regular expressions.
        :param chunksize: number of rows read per chunk.
        """
        result_path = os.path.join(self.config.Global.log_path, "pred_result.csv")
        for chunk in pd.read_csv(
            result_path,
            delimiter=delimiter,
            header=None,
            names=["pred", "label", "idx"],
            # Keep the JSON columns as raw text. Otherwise pandas infers a numeric dtype
            # when every value is a scalar (e.g. "0.5"), and json.loads rejects non-str input.
            dtype={"pred": str, "label": str},
            chunksize=chunksize,
            engine="python",
        ):
            # Walk only the two needed columns; iterrows() would build a Series per row.
            for raw_pred, raw_label in zip(chunk["pred"], chunk["label"]):
                pred: Tensor = torch.tensor(json.loads(raw_pred), dtype=torch.float16)
                label: Tensor = torch.tensor(json.loads(raw_label), dtype=torch.uint8)
                self.metric.update(pred, label)
        self.metric.compute()
        self.state.metric_result = self.metric.get_result(True)

    def fit(self) -> None:
        """
        Run inference on the test dataset, then compute metrics from the logged results.

        The ``on_begin`` and ``on_end`` callbacks are fired around the whole run.
        """
        print("Start running inference on test dataset ...")
        self.__callback("on_begin")
        phase_config = self.__config.Data[self.state.phase]
        BatchForwarder(
            phase_config.forward_strategy,
            self,
            overridden_args=phase_config.get("overridden_args", DotDict({})).get_dict(),
        )()
        print("Start computing metrics from inferred results ...")
        self.compute_metrics()
        self.__callback("on_end")
        print("Testing finished")
        return None