forked from AlexanderSlivinskiy/FUNIT
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcustomTransforms.py
More file actions
120 lines (96 loc) · 3.56 KB
/
customTransforms.py
File metadata and controls
120 lines (96 loc) · 3.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from skimage.exposure import rescale_intensity
import torch
import numpy as np
from imgaug import augmenters as iaa
from globalConstants import GlobalConstants
def transformTo3Tuple(image):
    """Coerce an array to exactly three dimensions and return it.

    While the array has more than three dimensions, the leading axis is
    collapsed with a max-projection (``np.max(image, axis=0)``).  While it has
    fewer than three dimensions, a new singleton axis is prepended
    (``np.expand_dims(image, axis=0)``).

    Examples of the resulting shapes:
        (A, B)          -> (1, A, B)
        (A, B, C)       -> unchanged
        (A, B, C, D)    -> (B, C, D)   (max over axis 0)
        (A, B, C, D, E) -> (C, D, E)   (max over axis 0, twice)

    Args:
        image: array-like object exposing ``.shape`` (e.g. numpy.ndarray).

    Returns:
        The 3-dimensional array.  A 3-D input is returned as-is.
    """
    # `while` instead of `if` so inputs that are more than one dimension away
    # from 3-D (e.g. 5-D or 1-D) still end up with exactly three dimensions.
    while len(image.shape) > 3:
        image = np.max(image, axis=0)
    while len(image.shape) < 3:
        image = np.expand_dims(image, axis=0)
    return image
class RescaleToZeroOne(object):
    """Divide an image by its own maximum value.

    The input range is NOT assumed to be [0, 255]: normalisation is relative
    to ``pic.max()``, so the largest value maps to 1 (for non-negative input
    the result lies in [0, 1]).
    """

    def __call__(self, pic):
        """
        Args:
            pic (tensor or numpy.ndarray): Image to be rescaled.
        Returns:
            ``pic / pic.max()`` (true division, so integer input yields a
            floating-point result).  An all-zero image is returned as zeros
            instead of NaN.
        """
        maximum = pic.max()
        # Guard against 0 / 0 -> NaN for blank images, which would otherwise
        # propagate silently into downstream computations.
        if maximum == 0:
            maximum = 1
        return pic / maximum

    def __repr__(self):
        return self.__class__.__name__ + '()'
class RescaleToOneOne(object):
    """Rescale an image so that 0 maps to -1 and the image's maximum maps to 1.

    The image is divided by its own maximum and then mapped with ``x * 2 - 1``.
    The input range is NOT assumed to be [0, 255].
    """

    def __call__(self, pic):
        """
        Args:
            pic (tensor or numpy.ndarray): Image to be rescaled.
        Returns:
            ``(pic / pic.max()) * 2 - 1``.  For a float16 torch tensor the
            result is a float16 tensor on the same device.  An all-zero image
            maps to all -1 instead of NaN.
        """
        if pic.dtype == torch.float16:
            # max() on float16 is not supported on CPU in some torch versions,
            # so compute in float32 and convert back.  Unlike a numpy
            # round-trip (.numpy()), this also works for CUDA tensors.
            pic32 = pic.float()
            maximum = pic32.max()
            if maximum == 0:  # avoid 0 / 0 -> NaN for blank images
                maximum = 1
            return ((pic32 / maximum) * 2 - 1).half()
        maximum = pic.max()
        if maximum == 0:  # avoid 0 / 0 -> NaN for blank images
            maximum = 1
        return ((pic / maximum) * 2) - 1

    def __repr__(self):
        return self.__class__.__name__ + '()'

    @staticmethod
    def reverse(tensor):
        """Map the first element along axis 0 from [-1, 1] back to [0, 255].

        Args:
            tensor (torch.Tensor): tensor in the [-1, 1] range; presumably a
                batch whose first item is wanted -- confirm against callers.
        Returns:
            torch.Tensor: detached CPU tensor ``(tensor[0] + 1) / 2 * 255``.
        """
        return (tensor.cpu().detach()[0] + 1) / 2 * 255
class ToTensor(object):
    """Convert a numpy image into a torch tensor.

    Steps:
      1. ``uint16`` input is cast to a signed dtype: ``int16`` when its maximum
         is below 32768, otherwise ``int32``.  (NOTE(review): presumably
         because the targeted torch version cannot build tensors from uint16
         arrays -- confirm.)
      2. With 3 configured input channels and a trailing axis of size 3, that
         axis is moved to the front: ``(A, B, 3) -> (3, A, B)``.
         With 1 configured input channel and a 2-D array, a leading singleton
         axis is added: ``(A, B) -> (1, A, B)``.
      3. The (copied) array is wrapped with ``torch.from_numpy`` and, unless
         ``GlobalConstants.usingApex`` is set, passed through
         ``GlobalConstants.setTensorToPrecision``.
    """

    def __call__(self, pic):
        if pic.dtype == 'uint16':
            signed_dtype = 'int16' if pic.max() < 32768 else 'int32'
            pic = pic.astype(signed_dtype)

        if GlobalConstants.getInputChannels() == 3 and pic.shape[2] == 3:
            pic = pic.transpose((2, 0, 1))
        elif GlobalConstants.getInputChannels() == 1 and len(pic.shape) == 2:
            rows, cols = pic.shape[0], pic.shape[1]
            pic = pic.reshape((1, rows, cols))

        # copy() hands torch a fresh, contiguous buffer (the transposed array
        # above is only a strided view).
        result = torch.from_numpy(pic.copy())
        if GlobalConstants.usingApex:
            return result
        return GlobalConstants.setTensorToPrecision(result)

    def __repr__(self):
        return self.__class__.__name__ + '()'
class DynamicResize(iaa.Resize):
    """Downscale an image by the largest integer factor that keeps it >= desired_size.

    Both spatial axes are divided by the same integer factor, so the aspect
    ratio is (up to integer rounding) preserved.

    NOTE(review): ``iaa.Resize.__init__`` is not called, so only the overridden
    ``__call__`` / ``__repr__`` are known to be safe on instances.
    """

    def __init__(self, desired_size):
        self.desired_size = desired_size

    def get_closest_factor(self, current_size):
        """Return the largest integer x >= 1 with ``current_size // x >= desired_size``.

        Returns 1 when ``current_size`` is already smaller than
        ``desired_size``.  The factor deliberately grows linearly rather than
        in powers of two: a power-of-two factor could overshoot and shrink the
        picture to half of what was needed.

        Raises:
            ValueError: if ``desired_size`` is not positive (the previous
                incremental search never terminated in that case).
        """
        import math

        if self.desired_size <= 0:
            raise ValueError(
                "desired_size must be positive, got %r" % (self.desired_size,))
        # Closed form of "increment x while desired_size <= current_size // (x + 1)":
        # for integer current_size that condition is equivalent to
        # (x + 1) * ceil(desired_size) <= current_size.
        return max(1, current_size // math.ceil(self.desired_size))

    def __call__(self, pic):
        """Resize ``pic`` by one shared integer factor on both spatial axes.

        NOTE(review): reads pic.shape[1] as height and pic.shape[2] as width,
        i.e. assumes a (C, H, W) layout, whereas imgaug augmenters normally
        expect (H, W, C) images -- confirm the resize hits the intended axes.
        """
        factor_x = self.get_closest_factor(pic.shape[2])
        factor_y = self.get_closest_factor(pic.shape[1])
        scalar = min(factor_x, factor_y)  # one shared factor keeps proportions
        resizer = iaa.Resize({"height": pic.shape[1] // scalar,
                              "width": pic.shape[2] // scalar})
        return resizer(pic)

    def __repr__(self):
        return self.__class__.__name__ + '()'
class PrintInputShape(object):
    """Debugging transform: print the input's ``shape`` and return the input unchanged."""

    def __call__(self, pic):
        shape = pic.shape
        print(shape)
        return pic

    def __repr__(self):
        return f"{type(self).__name__}()"