@@ -96,6 +96,7 @@ def _get_tunable_hyperparameters(self):
9696
9797 def _build_blocks (self ):
9898 blocks = OrderedDict ()
99+ last_fit_block = None
99100
100101 block_names_count = Counter ()
101102 for primitive in self .primitives :
@@ -118,11 +119,14 @@ def _build_blocks(self):
118119 block = MLBlock (primitive , ** block_params )
119120 blocks [block_name ] = block
120121
122+ if bool (block ._fit ):
123+ last_fit_block = block_name
124+
121125 except Exception :
122126 LOGGER .exception ('Exception caught building MLBlock %s' , primitive )
123127 raise
124128
125- return blocks
129+ return blocks , last_fit_block
126130
127131 @staticmethod
128132 def _get_pipeline_dict (pipeline , primitives ):
@@ -207,7 +211,7 @@ def __init__(self, pipeline=None, primitives=None, init_params=None,
207211
208212 self .primitives = primitives or pipeline ['primitives' ]
209213 self .init_params = init_params or pipeline .get ('init_params' , dict ())
210- self .blocks = self ._build_blocks ()
214+ self .blocks , self . _last_fit_block = self ._build_blocks ()
211215 self ._last_block_name = self ._get_block_name (- 1 )
212216
213217 self .input_names = input_names or pipeline .get ('input_names' , dict ())
@@ -767,7 +771,11 @@ def fit(self, X=None, y=None, output_=None, start_=None, debug=False, **kwargs):
767771 debug_info = defaultdict (dict )
768772 debug_info ['debug' ] = debug .lower () if isinstance (debug , str ) else 'tmio'
769773
774+ fit_pending = True
770775 for block_name , block in self .blocks .items ():
776+ if block_name == self ._last_fit_block :
777+ fit_pending = False
778+
771779 if start_ :
772780 if block_name == start_ :
773781 start_ = False
@@ -777,7 +785,7 @@ def fit(self, X=None, y=None, output_=None, start_=None, debug=False, **kwargs):
777785
778786 self ._fit_block (block , block_name , context , debug_info )
779787
780- if ( block_name != self . _last_block_name ) or ( block_name in output_blocks ) :
788+ if fit_pending or output_blocks :
781789 self ._produce_block (
782790 block , block_name , context , output_variables , outputs , debug_info )
783791
@@ -787,16 +795,23 @@ def fit(self, X=None, y=None, output_=None, start_=None, debug=False, **kwargs):
787795
788796 # If there was an output_ but there are no pending
789797 # outputs we are done.
790- if output_variables is not None and not output_blocks :
791- if len (outputs ) > 1 :
792- result = tuple (outputs )
793- else :
794- result = outputs [0 ]
798+ if output_variables :
799+ if not output_blocks :
800+ if len (outputs ) > 1 :
801+ result = tuple (outputs )
802+ else :
803+ result = outputs [0 ]
804+
805+ if debug :
806+ return result , debug_info
807+
808+ return result
795809
810+ elif not fit_pending :
796811 if debug :
797- return result , debug_info
812+ return debug_info
798813
799- return result
814+ return
800815
801816 if start_ :
802817 # We skipped all the blocks up to the end
0 commit comments