@@ -225,6 +225,7 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoc
225225 # Train
226226 print ("Train on {0} samples, validate on {1} samples, {2} steps per epoch" .format (
227227 len (train_tensor_data ), len (val_y ), steps_per_epoch ))
228+ num_tasks = getattr (self , "num_tasks" , 1 )
228229 for epoch in range (initial_epoch , epochs ):
229230 callbacks .on_epoch_begin (epoch )
230231 epoch_logs = {}
@@ -239,15 +240,15 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoc
239240 y = y_train .to (self .device ).float ()
240241
241242 y_pred = model (x )
242- if self . num_tasks == 1 and y_pred .ndim > 1 and y_pred .shape [- 1 ] == 1 :
243+ if num_tasks == 1 and y_pred .ndim > 1 and y_pred .shape [- 1 ] == 1 :
243244 y_pred = y_pred .squeeze (- 1 )
244245
245246 optim .zero_grad ()
246247 if isinstance (loss_func , list ):
247- assert len (loss_func ) == self . num_tasks ,\
248- "the length of `loss_func` should be equal with `self. num_tasks`"
248+ assert len (loss_func ) == num_tasks ,\
249+ "the length of `loss_func` should be equal with `num_tasks`"
249250 loss = sum (
250- [loss_func [i ](y_pred [:, i ], y [:, i ], reduction = 'sum' ) for i in range (self . num_tasks )])
251+ [loss_func [i ](y_pred [:, i ], y [:, i ], reduction = 'sum' ) for i in range (num_tasks )])
251252 else :
252253 y_for_loss = y
253254 if y_for_loss .ndim > 1 and y_for_loss .shape [- 1 ] == 1 :
0 commit comments