-
Notifications
You must be signed in to change notification settings - Fork 48
/
utils.py
118 lines (102 loc) · 4.02 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import matplotlib.pyplot as plt
import torch
def tensor2img(tensor, ax=plt):
    """Convert an image tensor into a NumPy array that ``imshow`` can draw.

    Singleton dimensions are squeezed away first. If more than two dimensions
    remain, the result is permuted with ``(1, 2, 0)``, so a channels-first
    ``(C, H, W)`` tensor becomes ``(H, W, C)``. The permute assumes the
    squeezed tensor is 3-D. The data is detached from autograd and copied to
    the CPU before conversion.

    ``ax`` is not used by this function.
    """
    squeezed = tensor.squeeze()
    if squeezed.dim() > 2:
        # channels-first -> channels-last for matplotlib
        squeezed = squeezed.permute(1, 2, 0)
    return squeezed.detach().cpu().numpy()
def subplot(images, parse=lambda x: x, rows_titles=None, cols_titles=None, title='', *args, **kwargs):
    """Draw ``images`` on a new grid of axes and show the figure.

    Args:
        images: indexable, ``len``-able collection of images, consumed in
            row-major order.
        parse: callable applied to each image before ``imshow`` (for example
            ``tensor2img``). Identity by default.
        rows_titles: optional titles indexed by row. When the grid is a single
            row or column, each axis gets its own title. In a 2-D grid the
            title goes on the first axis of the row; if ``cols_titles`` is also
            given, the column title replaces it on that axis.
        cols_titles: optional titles indexed by column, set on every row of a
            2-D grid.
        title: figure title (``fig.suptitle``).
        *args, **kwargs: forwarded to ``plt.subplots`` (e.g. ``nrows``,
            ``ncols``).

    A 1x1 grid shows only ``images[0]`` and ignores both title lists. If there
    are fewer images than axes, drawing stops when the images run out and the
    remaining axes are left as they are.
    """
    fig, ax = plt.subplots(*args, **kwargs)
    fig.suptitle(title)

    def _is_iterable(obj):
        # A bare Axes is not iterable; the arrays of Axes from plt.subplots are.
        try:
            iter(obj)
        except TypeError:
            return False
        return True

    def _show(axis, image):
        # Draw one image without ticks/frame and with square pixels.
        axis.imshow(parse(image))
        axis.axis('off')
        axis.set_aspect('equal')

    if not _is_iterable(ax):
        # 1x1 grid: plt.subplots returned a single Axes.
        _show(ax, images[0])
    else:
        n_images = len(images)
        i = 0  # index of the next image to draw
        for r, row in enumerate(ax):
            if i >= n_images:
                break
            if _is_iterable(row):
                # 2-D grid: ``row`` is an array of Axes.
                if rows_titles is not None:
                    row[0].set_title(rows_titles[r])
                for j, col in enumerate(row):
                    if i >= n_images:
                        break
                    if cols_titles is not None:
                        col.set_title(cols_titles[j])
                    _show(col, images[i])
                    i += 1
            else:
                # Single row or column: ``row`` is one Axes.
                if rows_titles is not None:
                    row.set_title(rows_titles[r])
                _show(row, images[i])
                i += 1

    fig.tight_layout()
    # Target ``fig`` explicitly rather than whatever pyplot's current figure is.
    fig.subplots_adjust(top=0.88, wspace=0.0, hspace=0.0)
    plt.show()
def module2traced(module, inputs):
    """Return the leaf sub-modules of ``module`` in the order they run on ``inputs``.

    A forward hook is registered on every leaf module (one with no children).
    ``module(inputs)`` is then executed once, and each leaf is appended to the
    result every time its forward is called. The hooks are always removed
    afterwards, even if the forward pass raises.

    Args:
        module: the ``torch.nn.Module`` to trace.
        inputs: passed as the single positional argument of ``module``.

    Returns:
        list of leaf modules in call order. A leaf that runs several times
        appears several times.
    """
    handles, modules = [], []

    def trace(module, inputs, outputs):
        modules.append(module)

    def traverse(module):
        for m in module.children():
            traverse(m)  # recursion is love
        is_leaf = len(list(module.children())) == 0
        if is_leaf:
            handles.append(module.register_forward_hook(trace))

    traverse(module)
    try:
        module(inputs)
    finally:
        # Never leave tracing hooks attached to the caller's model.
        for handle in handles:
            handle.remove()

    return modules
def run_vis_plot(vis, x, layer, ncols=1, nrows=1):
    """Visualise ``x`` at ``layer`` with ``vis`` and plot the first results.

    ``vis(x, layer)`` must return a pair ``(images, info)``; ``info`` is
    ignored. Only the first ``nrows * ncols`` images are kept. The shape of
    the first one is printed, then all kept images are drawn with ``subplot``
    using ``tensor2img`` and ``str(layer)`` as the figure title.
    """
    outputs, _ = vis(x, layer)
    shown = outputs[:ncols * nrows]
    print(shown[0].shape)
    subplot(shown, tensor2img, title=str(layer), ncols=ncols, nrows=nrows)
def run_vis_plot_across_models(modules, input, layer_id, Vis, title,
                               device,
                               inputs=None,
                               nrows=3,
                               ncols=2,
                               row_wise=True,
                               parse=tensor2img,
                               annotations=None,
                               idx2label=None,
                               rows_name=None, *args, **kwargs):
    """Plot one model per row of a subplot grid, using the same visualiser.

    For each row, one model is taken from ``modules``, switched to eval mode
    and moved to ``device``. The visualiser is built with
    ``Vis(module, device)`` and called as ``vis(x, layer, *args, **kwargs)``;
    it must return ``(images, info)``.

    Args:
        modules: iterable or iterator of ``torch.nn.Module``. One model is
            consumed per row, and plotting stops early if it runs out.
        input: input used to trace ``layer_id`` and, when ``inputs`` is None,
            passed (cloned) to the visualiser once per model.
        layer_id: index into ``module2traced(module, input)`` selecting the
            layer to visualise, or None to pass ``layer=None``.
        Vis: visualiser factory, called as ``Vis(module, device)``.
        title: figure title.
        device: device each model is moved to.
        inputs: optional per-column inputs. When given, column ``j`` shows
            ``vis(inputs[j], ...)`` instead of ``images[j]`` from a single call
            on ``input``.
        nrows, ncols: grid size passed to ``plt.subplots``.
        row_wise: unused; kept for backward compatibility.
        parse: converts each visualiser output into an ``imshow``-able image.
        annotations: optional per-row text drawn left of the first axis
            (2-D grids only).
        idx2label: maps ``int(info['prediction'])`` to a column title; needed
            whenever the visualiser reports ``'prediction'`` (2-D grids only).
        rows_name: optional per-row titles that replace the model class name.
        *args, **kwargs: forwarded to every ``vis`` call.

    If the grid is a single row/column (or 1x1), every axis acts as a "row":
    it shows only the first image and is titled with the row title. The
    figure is laid out but not shown.
    """
    pad = 0  # in points
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols)
    fig.suptitle(title)
    # iter() is a no-op for iterators/generators and also lets lists work.
    modules = iter(modules)

    def _is_iterable(obj):
        # A bare Axes is not iterable; the arrays of Axes from plt.subplots are.
        try:
            iter(obj)
        except TypeError:
            return False
        return True

    def _show(axis, image):
        # Draw one image without ticks/frame and with square pixels.
        axis.imshow(parse(image))
        axis.axis('off')
        axis.set_aspect('equal')

    # plt.subplots returns a bare Axes for a 1x1 grid; wrap it so the loop
    # below handles it like the 1-D/2-D arrays returned otherwise.
    rows = ax if _is_iterable(ax) else [ax]

    for i, row in enumerate(rows):
        try:
            module = next(modules)
        except StopIteration:
            break
        module.eval()
        module = module.to(device)
        layer = None
        if layer_id is not None:
            layer = module2traced(module, input)[layer_id]
        vis = Vis(module, device)
        info = {}
        if inputs is None:
            images, info = vis(input.clone(), layer, *args, **kwargs)
        row_title = module.__class__.__name__
        # Drop this function's reference before clearing the CUDA cache
        # (``vis`` may still hold the model).
        del module
        torch.cuda.empty_cache()
        if rows_name is not None:
            row_title = rows_name[i]

        if not _is_iterable(row):
            # Single row/column grid: ``row`` is one Axes showing the first image.
            row.set_title(row_title)
            if inputs is None:
                image = images[0]
            else:
                image, _ = vis(inputs[0], layer, *args, **kwargs)
            _show(row, image)
            continue

        row[0].set_title(row_title)
        if annotations is not None:
            row[0].annotate(annotations[i], xy=(0, 0.5), xytext=(-row[0].yaxis.labelpad - pad, 0),
                            xycoords=row[0].yaxis.label, textcoords='offset points',
                            size='medium', ha='right', va='center', rotation=90)
        for j, col in enumerate(row):
            if inputs is None:
                image = images[j]
            else:
                image, info = vis(inputs[j], layer, *args, **kwargs)
            if 'prediction' in info:
                col.set_title(idx2label[int(info['prediction'])])
            _show(col, image)

    fig.tight_layout()
    # Target ``fig`` explicitly rather than whatever pyplot's current figure is.
    fig.subplots_adjust(wspace=0.1, hspace=0.2)