-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathsystem.py
More file actions
53 lines (43 loc) · 1.93 KB
/
system.py
File metadata and controls
53 lines (43 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import inspect
from typing import Any, Dict, Tuple, Union

import torch
from lighter import LighterModule
class LighterSystem(LighterModule):
    """Step logic for ``LighterModule`` covering train/val/test/predict.

    Each of the train/val/test steps runs ``self.network`` on the batch
    inputs, evaluates ``self.criterion`` and returns a dict with the keys
    ``"loss"``, ``"pred"`` and ``"target"``.
    """

    @staticmethod
    def _resolve_signature(fn: Any) -> Union[inspect.Signature, None]:
        """Return the call signature of *fn*, or ``None`` if it cannot be inspected.

        For ``torch.nn.Module`` instances the signature of ``forward`` is used,
        because ``Module.__call__`` typically advertises only ``*args, **kwargs``.
        """
        target = fn.forward if isinstance(fn, torch.nn.Module) else fn
        try:
            return inspect.signature(target)
        except (TypeError, ValueError):
            # Builtins/C callables or non-callables (e.g. None) have no signature.
            return None

    def _network_forward(self, inputs: Any) -> Any:
        """Run ``self.network`` on *inputs*, passing epoch/step when it accepts them.

        Networks such as SwaV declare an ``epoch`` parameter in ``forward``. For
        those, ``epoch=self.current_epoch`` is passed, and ``step=self.global_step``
        is passed too when ``forward`` names a ``step`` parameter or accepts
        ``**kwargs``. Parameter names are read from the signature rather than
        ``__code__.co_varnames``, which also lists local variables.
        """
        kwargs: Dict[str, Any] = {}
        sig = self._resolve_signature(self.network)
        if sig is not None and "epoch" in sig.parameters:
            params = sig.parameters
            kwargs["epoch"] = self.current_epoch
            accepts_var_kw = any(
                p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
            )
            if "step" in params or accepts_var_kw:
                kwargs["step"] = self.global_step
        return self.network(inputs, **kwargs)

    def _criterion_loss(self, pred: Any, target: Any) -> Any:
        """Evaluate ``self.criterion`` as ``criterion(pred, target)`` or ``criterion(pred)``.

        The arity is decided from the criterion's signature, so a ``TypeError``
        raised *inside* a two-argument loss propagates as-is. It is not treated
        as an arity mismatch and retried without the target. When the signature
        cannot be inspected, the trial call with a ``TypeError`` fallback is used.
        """
        criterion = self.criterion
        sig = self._resolve_signature(criterion)
        if sig is None:
            try:
                return criterion(pred, target)
            except TypeError:
                return criterion(pred)
        try:
            sig.bind(pred, target)
        except TypeError:
            # The criterion cannot take a target (e.g. a loss that only
            # consumes the predictions).
            return criterion(pred)
        return criterion(pred, target)

    def training_step(self, batch: Any, batch_idx: int) -> Dict[str, torch.Tensor]:
        """Forward an ``(inputs, target)`` batch and compute the training loss.

        Args:
            batch: A 2-item ``(inputs, target)`` sequence.
            batch_idx: Index of the batch (unused here).

        Returns:
            ``{"loss": ..., "pred": ..., "target": ...}``.
        """
        inputs, target = batch
        pred = self._network_forward(inputs)
        loss = self._criterion_loss(pred, target)
        return {"loss": loss, "pred": pred, "target": target}

    def validation_step(self, batch: Any, batch_idx: int) -> Dict[str, torch.Tensor]:
        """Forward an ``(inputs, target)`` batch and compute the validation loss.

        Args:
            batch: A 2-item ``(inputs, target)`` sequence.
            batch_idx: Index of the batch (unused here).

        Returns:
            ``{"loss": ..., "pred": ..., "target": ...}``.
        """
        inputs, target = batch
        pred = self._network_forward(inputs)
        loss = self._criterion_loss(pred, target)
        return {"loss": loss, "pred": pred, "target": target}

    def test_step(self, batch: Any, batch_idx: int) -> Dict[str, Any]:
        """Forward an ``(inputs, target)`` batch; compute a loss only if a criterion is set.

        The network is called without epoch/step, as in the original test logic.

        Args:
            batch: A 2-item ``(inputs, target)`` sequence.
            batch_idx: Index of the batch (unused here).

        Returns:
            ``{"loss": ..., "pred": ..., "target": ...}`` where ``loss`` is
            ``None`` when ``self.criterion`` is ``None``.
        """
        inputs, target = batch
        pred = self.network(inputs)
        # `is not None` rather than truthiness: container-like modules that
        # define __len__ would otherwise be treated as "no criterion" when empty.
        loss = None if self.criterion is None else self._criterion_loss(pred, target)
        return {"loss": loss, "pred": pred, "target": target}

    def predict_step(self, batch: Any, batch_idx: int) -> Any:
        """Return ``self.network`` applied to the batch inputs.

        A 2-item list/tuple batch is treated as ``(inputs, target)`` and the
        target is discarded. Any other batch is passed to the network as-is.
        """
        if isinstance(batch, (list, tuple)) and len(batch) == 2:
            inputs, _ = batch
        else:
            inputs = batch
        return self.network(inputs)