@@ -43,7 +43,6 @@ def __init__(self, spec, num_labels=10,
4343 if isinstance (spec , tuple ):
4444 spec = ModelSpec (spec [0 ], spec [1 ])
4545
46- self .spec = spec
4746 self .cell_indices = set ()
4847
4948 self .layers = nn .ModuleList ([])
@@ -106,60 +105,72 @@ class Cell(nn.Module):
106105 def __init__ (self , spec , in_channels , out_channels ):
107106 super (Cell , self ).__init__ ()
108107
109- self .spec = spec
110- self .num_vertices = np .shape (self .spec . matrix )[0 ]
108+ self .matrix = spec . matrix
109+ self .num_vertices = np .shape (self .matrix )[0 ]
111110
112111 # vertex_channels[i] = number of output channels of vertex i
113- self .vertex_channels = ComputeVertexChannels (in_channels , out_channels , self .spec . matrix )
112+ self .vertex_channels = ComputeVertexChannels (in_channels , out_channels , self .matrix )
114113 #self.vertex_channels = [in_channels] + [out_channels] * (self.num_vertices - 1)
115114
116115 # operation for each node
117- self .vertex_op = nn .ModuleList ([None ])
116+ self .vertex_op = nn .ModuleList ([Placeholder () ])
118117 for t in range (1 , self .num_vertices - 1 ):
119118 op = OP_MAP [spec .ops [t ]](self .vertex_channels [t ], self .vertex_channels [t ])
120119 self .vertex_op .append (op )
121120
122121 # operation for input on each vertex
123- self .input_op = nn .ModuleList ([None ])
122+ self .input_op = nn .ModuleList ([Placeholder () ])
124123 for t in range (1 , self .num_vertices ):
125- if self .spec . matrix [0 , t ]:
124+ if self .matrix [0 , t ]:
126125 self .input_op .append (Projection (in_channels , self .vertex_channels [t ]))
127126 else :
128- self .input_op .append (None )
127+ self .input_op .append (Placeholder ())
128+
129+ self .last_inop : Projection = self .input_op [self .num_vertices - 1 ]
129130
130131 def forward (self , x ):
131132 tensors = [x ]
132133
133134 out_concat = []
134- for t in range (1 , self .num_vertices - 1 ):
135- fan_in = [Truncate (tensors [src ], self .vertex_channels [t ]) for src in range (1 , t ) if self .spec .matrix [src , t ]]
136-
137- if self .spec .matrix [0 , t ]:
138- fan_in .append (self .input_op [t ](x ))
139-
140- # perform operation on node
141- #vertex_input = torch.stack(fan_in, dim=0).sum(dim=0)
142- vertex_input = sum (fan_in )
143- #vertex_input = sum(fan_in) / len(fan_in)
144- vertex_output = self .vertex_op [t ](vertex_input )
145-
146- tensors .append (vertex_output )
147- if self .spec .matrix [t , self .num_vertices - 1 ]:
148- out_concat .append (tensors [t ])
135+ # range(1, self.num_vertices - 1),
136+ for t , (inmod , outmod ) in enumerate (zip (self .input_op , self .vertex_op )):
137+ if 0 < t < (self .num_vertices - 1 ):
138+
139+ fan_in = []
140+ for src in range (1 , t ):
141+ if self .matrix [src , t ]:
142+ fan_in .append (Truncate (tensors [src ], torch .tensor (self .vertex_channels [t ])))
143+
144+ if self .matrix [0 , t ]:
145+ l = inmod (x )
146+ fan_in .append (l )
147+
148+ # perform operation on node
149+ #vertex_input = torch.stack(fan_in, dim=0).sum(dim=0)
150+ vertex_input = torch .zeros (fan_in [0 ].shape )
151+ for val in fan_in :
152+ vertex_input += val
153+ #vertex_input = sum(fan_in)
154+ #vertex_input = sum(fan_in) / len(fan_in)
155+ vertex_output = outmod (vertex_input )
156+
157+ tensors .append (vertex_output )
158+ if self .matrix [t , self .num_vertices - 1 ]:
159+ out_concat .append (tensors [t ])
149160
150161 if not out_concat :
151- assert self .spec . matrix [0 , self .num_vertices - 1 ]
152- outputs = self .input_op [ self . num_vertices - 1 ] (tensors [0 ])
162+ assert self .matrix [0 , self .num_vertices - 1 ]
163+ outputs = self .last_inop (tensors [0 ])
153164 else :
154165 if len (out_concat ) == 1 :
155166 outputs = out_concat [0 ]
156167 else :
157168 outputs = torch .cat (out_concat , 1 )
158169
159- if self .spec . matrix [0 , self .num_vertices - 1 ]:
160- outputs += self .input_op [ self . num_vertices - 1 ] (tensors [0 ])
170+ if self .matrix [0 , self .num_vertices - 1 ]:
171+ outputs += self .last_inop (tensors [0 ])
161172
162- #if self.spec. matrix[0, self.num_vertices-1]:
173+ #if self.matrix[0, self.num_vertices-1]:
163174 # out_concat.append(self.input_op[self.num_vertices-1](tensors[0]))
164175 #outputs = sum(out_concat) / len(out_concat)
165176
@@ -196,6 +207,9 @@ def ComputeVertexChannels(in_channels, out_channels, matrix):
196207 Returns:
197208 list of channel counts, in order of the vertices.
198209 """
210+ if isinstance (matrix , torch .Tensor ):
211+ matrix = matrix .numpy ()
212+
199213 num_vertices = np .shape (matrix )[0 ]
200214
201215 vertex_channels = [0 ] * num_vertices
@@ -241,4 +255,13 @@ def ComputeVertexChannels(in_channels, out_channels, matrix):
241255 assert final_fan_in == out_channels or num_vertices == 2
242256 # num_vertices == 2 means only input/output nodes, so 0 fan-in
243257
244- return vertex_channels
258+ return [int (v ) for v in vertex_channels ]
259+
260+
261+ class Placeholder (torch .nn .Module ):
262+ def __init__ (self ):
263+ super ().__init__ ()
264+ self .a = torch .nn .Parameter (torch .randn (()))
265+
266+ def forward (self , x ):
267+ return x
0 commit comments