Skip to content

Commit 433cf59

Browse files
committed
New version, code modification so that the network can be processed by torchscript.
1 parent 65bb60a commit 433cf59

3 files changed

Lines changed: 57 additions & 30 deletions

File tree

nasbench_pytorch/model/model.py

Lines changed: 52 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def __init__(self, spec, num_labels=10,
4343
if isinstance(spec, tuple):
4444
spec = ModelSpec(spec[0], spec[1])
4545

46-
self.spec = spec
4746
self.cell_indices = set()
4847

4948
self.layers = nn.ModuleList([])
@@ -106,60 +105,72 @@ class Cell(nn.Module):
106105
def __init__(self, spec, in_channels, out_channels):
107106
super(Cell, self).__init__()
108107

109-
self.spec = spec
110-
self.num_vertices = np.shape(self.spec.matrix)[0]
108+
self.matrix = spec.matrix
109+
self.num_vertices = np.shape(self.matrix)[0]
111110

112111
# vertex_channels[i] = number of output channels of vertex i
113-
self.vertex_channels = ComputeVertexChannels(in_channels, out_channels, self.spec.matrix)
112+
self.vertex_channels = ComputeVertexChannels(in_channels, out_channels, self.matrix)
114113
#self.vertex_channels = [in_channels] + [out_channels] * (self.num_vertices - 1)
115114

116115
# operation for each node
117-
self.vertex_op = nn.ModuleList([None])
116+
self.vertex_op = nn.ModuleList([Placeholder()])
118117
for t in range(1, self.num_vertices-1):
119118
op = OP_MAP[spec.ops[t]](self.vertex_channels[t], self.vertex_channels[t])
120119
self.vertex_op.append(op)
121120

122121
# operation for input on each vertex
123-
self.input_op = nn.ModuleList([None])
122+
self.input_op = nn.ModuleList([Placeholder()])
124123
for t in range(1, self.num_vertices):
125-
if self.spec.matrix[0, t]:
124+
if self.matrix[0, t]:
126125
self.input_op.append(Projection(in_channels, self.vertex_channels[t]))
127126
else:
128-
self.input_op.append(None)
127+
self.input_op.append(Placeholder())
128+
129+
self.last_inop : Projection = self.input_op[self.num_vertices-1]
129130

130131
def forward(self, x):
131132
tensors = [x]
132133

133134
out_concat = []
134-
for t in range(1, self.num_vertices-1):
135-
fan_in = [Truncate(tensors[src], self.vertex_channels[t]) for src in range(1, t) if self.spec.matrix[src, t]]
136-
137-
if self.spec.matrix[0, t]:
138-
fan_in.append(self.input_op[t](x))
139-
140-
# perform operation on node
141-
#vertex_input = torch.stack(fan_in, dim=0).sum(dim=0)
142-
vertex_input = sum(fan_in)
143-
#vertex_input = sum(fan_in) / len(fan_in)
144-
vertex_output = self.vertex_op[t](vertex_input)
145-
146-
tensors.append(vertex_output)
147-
if self.spec.matrix[t, self.num_vertices-1]:
148-
out_concat.append(tensors[t])
135+
# range(1, self.num_vertices - 1),
136+
for t, (inmod, outmod) in enumerate(zip(self.input_op, self.vertex_op)):
137+
if 0 < t < (self.num_vertices - 1):
138+
139+
fan_in = []
140+
for src in range(1, t):
141+
if self.matrix[src, t]:
142+
fan_in.append(Truncate(tensors[src], torch.tensor(self.vertex_channels[t])))
143+
144+
if self.matrix[0, t]:
145+
l = inmod(x)
146+
fan_in.append(l)
147+
148+
# perform operation on node
149+
#vertex_input = torch.stack(fan_in, dim=0).sum(dim=0)
150+
vertex_input = torch.zeros(fan_in[0].shape)
151+
for val in fan_in:
152+
vertex_input += val
153+
#vertex_input = sum(fan_in)
154+
#vertex_input = sum(fan_in) / len(fan_in)
155+
vertex_output = outmod(vertex_input)
156+
157+
tensors.append(vertex_output)
158+
if self.matrix[t, self.num_vertices-1]:
159+
out_concat.append(tensors[t])
149160

150161
if not out_concat:
151-
assert self.spec.matrix[0, self.num_vertices-1]
152-
outputs = self.input_op[self.num_vertices-1](tensors[0])
162+
assert self.matrix[0, self.num_vertices-1]
163+
outputs = self.last_inop(tensors[0])
153164
else:
154165
if len(out_concat) == 1:
155166
outputs = out_concat[0]
156167
else:
157168
outputs = torch.cat(out_concat, 1)
158169

159-
if self.spec.matrix[0, self.num_vertices-1]:
160-
outputs += self.input_op[self.num_vertices-1](tensors[0])
170+
if self.matrix[0, self.num_vertices-1]:
171+
outputs += self.last_inop(tensors[0])
161172

162-
#if self.spec.matrix[0, self.num_vertices-1]:
173+
#if self.matrix[0, self.num_vertices-1]:
163174
# out_concat.append(self.input_op[self.num_vertices-1](tensors[0]))
164175
#outputs = sum(out_concat) / len(out_concat)
165176

@@ -196,6 +207,9 @@ def ComputeVertexChannels(in_channels, out_channels, matrix):
196207
Returns:
197208
list of channel counts, in order of the vertices.
198209
"""
210+
if isinstance(matrix, torch.Tensor):
211+
matrix = matrix.numpy()
212+
199213
num_vertices = np.shape(matrix)[0]
200214

201215
vertex_channels = [0] * num_vertices
@@ -241,4 +255,13 @@ def ComputeVertexChannels(in_channels, out_channels, matrix):
241255
assert final_fan_in == out_channels or num_vertices == 2
242256
# num_vertices == 2 means only input/output nodes, so 0 fan-in
243257

244-
return vertex_channels
258+
return [int(v) for v in vertex_channels]
259+
260+
261+
class Placeholder(torch.nn.Module):
262+
def __init__(self):
263+
super().__init__()
264+
self.a = torch.nn.Parameter(torch.randn(()))
265+
266+
def forward(self, x):
267+
return x

nasbench_pytorch/model/model_spec.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import copy
2727
import numpy as np
28+
import torch
2829

2930
from nasbench_pytorch.model import graph_util
3031

@@ -52,6 +53,7 @@ def __init__(self, matrix, ops, data_format='channels_last'):
5253
Raises:
5354
ValueError: invalid matrix or ops
5455
"""
56+
5557
if not isinstance(matrix, np.ndarray):
5658
matrix = np.array(matrix)
5759
shape = np.shape(matrix)
@@ -73,6 +75,8 @@ def __init__(self, matrix, ops, data_format='channels_last'):
7375
self.valid_spec = True
7476
self._prune()
7577

78+
self.matrix = torch.tensor(self.matrix)
79+
7680
self.data_format = data_format
7781

7882
def _prune(self):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setuptools.setup(
44
name='nasbench_pytorch',
5-
version='1.1',
5+
version='1.2',
66
license='Apache License 2.0',
77
author='Romulus Hong, Gabriela Suchopárová',
88
packages=setuptools.find_packages()

0 commit comments

Comments
 (0)