forked from AlexanderSlivinskiy/FUNIT
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdebugUtils.py
More file actions
80 lines (71 loc) · 2.76 KB
/
debugUtils.py
File metadata and controls
80 lines (71 loc) · 2.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
from skimage.io import imsave
import os
import numpy as np
def printCheckpoint(index, funcName, className="", prefix=""):
    """Print a one-line checkpoint banner.

    The line has the form
    ``===<prefix> : <className>.<funcName>, checkpoint: <index> ===``.

    Args:
        index: checkpoint number; printed as-is (any printable value).
        funcName: function name shown after the dot.
        className: class name shown before the dot (may be empty).
        prefix: free-form tag shown right after the leading ``===``.
    """
    # All text parts must be str; join raises TypeError otherwise, just as
    # plain concatenation would.
    banner = "".join(("===", prefix, " : ", className, ".", funcName, ", checkpoint:"))
    print(banner, index, "===")
class Debugger():
    """Per-call-site debugging helper that prints numbered checkpoint banners.

    Expected call: ``debug = Debugger(self.function, self, prefix)``.
    Every call to :meth:`printCheckpoint` prints
    ``===<prefix> : <className>.<funcName>, checkpoint: <index> ===`` and then
    advances the running index.
    """

    def __init__(self, func=None, classSelf=None, prefix=""):
        """Capture the names shown in every banner.

        Args:
            func: callable whose ``__name__`` becomes the default function name;
                ``None`` leaves it empty.
            classSelf: object whose class name is shown in banners; ``None``
                leaves it empty.
            prefix: free-form tag shown at the start of each banner.
        """
        # Use identity checks: arrays/tensors overload ``!=`` element-wise,
        # which would make ``if obj != None`` ambiguous and raise.
        # Default function name for banners.
        self.funcName = func.__name__ if func is not None else ""
        # Class name for banners.
        self.className = classSelf.__class__.__name__ if classSelf is not None else ""
        # Tag printed at the start of each banner.
        self.prefix = prefix
        # Running checkpoint counter, advanced by printCheckpoint.
        self.index = 0

    def printCheckpoint(self, function=None, content=""):
        """Print the current checkpoint banner, then advance the index.

        Args:
            function: callable whose ``__name__`` replaces the stored function
                name for this banner only; ``None`` uses the stored name.
            content: extra text printed on its own line after the banner;
                nothing extra is printed when it is the empty string.
        """
        f = self.funcName if function is None else function.__name__
        print("===" + self.prefix + " : " + self.className + "." + f + ", checkpoint:", self.index, "===")
        if content != "":
            print(content)
        self.index += 1

    def checkForNaNandInf(self, tensor, msg=""):
        """Report NaN and Inf entries found in ``tensor``.

        For each kind of non-finite value present, prints its count followed by
        a checkpoint banner with ``msg`` as content. Each report advances the
        checkpoint index, so a tensor holding both NaN and Inf advances it
        twice; a fully finite tensor prints nothing.
        """
        # Reduce each mask once and reuse the count for both test and message.
        nan_count = torch.sum(torch.isnan(tensor))
        inf_count = torch.sum(torch.isinf(tensor))
        if nan_count != 0:
            print("INPUT IS NAN IN FORWARD PASS!!", nan_count)
            self.printCheckpoint(content=msg)
        if inf_count != 0:
            print("INPUT IS INF IN FORWARD PASS!!", inf_count)
            self.printCheckpoint(content=msg)

    def printgradnorm(self, cls, grad_input, grad_output):
        """Print the maximum of the first output gradient.

        The ``(module, grad_input, grad_output)`` signature matches PyTorch's
        module backward-hook signature, so this is presumably registered as
        such a hook; ``cls`` is the module object (only its class name is
        used). ``grad_input`` is accepted for signature compatibility but is
        not used.
        """
        print('Inside ' + cls.__class__.__name__ + ' backward')
        first = grad_output[0]
        # Backward hooks pass None for outputs without a gradient; report that
        # instead of failing on ``None.max()``.
        print('grad_output_max:', None if first is None else first.max())
class DebugNet():
    """Class-level switchboard for dumping intermediate tensors as images.

    All state lives on the class itself. Configure and use it through the
    class: ``DebugNet.setName(tag)``, then ``DebugNet.safeImgSwitch = True``,
    then ``DebugNet.safeImage(tensor)``. "safe" in the names is presumably a
    misspelling of "save"; the names are kept for compatibility.
    """

    # Tag embedded in every saved file name (set via setName).
    name = ""
    # Master switch: safeImage writes nothing unless this is True.
    safeImgSwitch = False
    # Output directory for saved images, relative to the working directory.
    outDir = "pics"

    @staticmethod
    def setName(n):
        """Set the tag embedded in subsequently saved image file names."""
        DebugNet.name = n

    @staticmethod
    def safeImage(pic):
        """Save ``pic`` as ``<outDir>/pic_<name>_<i>_0.png`` when saving is enabled.

        ``i`` is the smallest index whose file does not exist yet, so earlier
        dumps are never overwritten. A 4-D input is first reduced to its first
        entry along the leading axis. Any input that still has more than two
        dimensions is then collapsed by a max over its leading axis, which is
        assumed to be the channel axis (TODO confirm with callers). A 2-D input
        is saved unchanged.

        Args:
            pic: a torch tensor; it is detached and copied to CPU before saving.
        """
        if not DebugNet.safeImgSwitch:
            return
        if DebugNet.outDir:
            # Create the output directory on first use rather than failing at save time.
            os.makedirs(DebugNet.outDir, exist_ok=True)
        stem = os.path.join(DebugNet.outDir, "pic_" + DebugNet.name + "_")
        i = 0
        while os.path.exists(stem + str(i) + "_0.png"):
            i += 1
        arr = pic.detach().cpu().numpy()
        print(arr.shape)
        if arr.ndim == 4:
            # Presumably (batch, C, H, W): keep only the first sample.
            arr = arr[0]
        if arr.ndim > 2:
            # Max-project across the leading axis to get a single 2-D image.
            arr = np.max(arr, axis=0)
        imsave(stem + str(i) + "_0.png", arr)