Skip to content

Commit 8446048

Browse files
authored
Stop fitting pipeline after last fit block (#132)
* initial early stop * change to stop after fitting the last block with attribute * test early-stop calls * remove comment * change to fit pending
1 parent 1af7b1b commit 8446048

2 files changed

Lines changed: 73 additions & 10 deletions

File tree

mlblocks/mlpipeline.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def _get_tunable_hyperparameters(self):
9696

9797
def _build_blocks(self):
9898
blocks = OrderedDict()
99+
last_fit_block = None
99100

100101
block_names_count = Counter()
101102
for primitive in self.primitives:
@@ -118,11 +119,14 @@ def _build_blocks(self):
118119
block = MLBlock(primitive, **block_params)
119120
blocks[block_name] = block
120121

122+
if bool(block._fit):
123+
last_fit_block = block_name
124+
121125
except Exception:
122126
LOGGER.exception('Exception caught building MLBlock %s', primitive)
123127
raise
124128

125-
return blocks
129+
return blocks, last_fit_block
126130

127131
@staticmethod
128132
def _get_pipeline_dict(pipeline, primitives):
@@ -207,7 +211,7 @@ def __init__(self, pipeline=None, primitives=None, init_params=None,
207211

208212
self.primitives = primitives or pipeline['primitives']
209213
self.init_params = init_params or pipeline.get('init_params', dict())
210-
self.blocks = self._build_blocks()
214+
self.blocks, self._last_fit_block = self._build_blocks()
211215
self._last_block_name = self._get_block_name(-1)
212216

213217
self.input_names = input_names or pipeline.get('input_names', dict())
@@ -767,7 +771,11 @@ def fit(self, X=None, y=None, output_=None, start_=None, debug=False, **kwargs):
767771
debug_info = defaultdict(dict)
768772
debug_info['debug'] = debug.lower() if isinstance(debug, str) else 'tmio'
769773

774+
fit_pending = True
770775
for block_name, block in self.blocks.items():
776+
if block_name == self._last_fit_block:
777+
fit_pending = False
778+
771779
if start_:
772780
if block_name == start_:
773781
start_ = False
@@ -777,7 +785,7 @@ def fit(self, X=None, y=None, output_=None, start_=None, debug=False, **kwargs):
777785

778786
self._fit_block(block, block_name, context, debug_info)
779787

780-
if (block_name != self._last_block_name) or (block_name in output_blocks):
788+
if fit_pending or output_blocks:
781789
self._produce_block(
782790
block, block_name, context, output_variables, outputs, debug_info)
783791

@@ -787,16 +795,23 @@ def fit(self, X=None, y=None, output_=None, start_=None, debug=False, **kwargs):
787795

788796
# If there was an output_ but there are no pending
789797
# outputs we are done.
790-
if output_variables is not None and not output_blocks:
791-
if len(outputs) > 1:
792-
result = tuple(outputs)
793-
else:
794-
result = outputs[0]
798+
if output_variables:
799+
if not output_blocks:
800+
if len(outputs) > 1:
801+
result = tuple(outputs)
802+
else:
803+
result = outputs[0]
804+
805+
if debug:
806+
return result, debug_info
807+
808+
return result
795809

810+
elif not fit_pending:
796811
if debug:
797-
return result, debug_info
812+
return debug_info
798813

799-
return result
814+
return
800815

801816
if start_:
802817
# We skipped all the blocks up to the end

tests/test_mlpipeline.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,54 @@ def test_get_inputs_no_fit(self):
681681

682682
assert inputs == expected
683683

684+
@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
def test_fit_pending_all_primitives(self):
    """Check ``fit`` when the last block in the pipeline is the last fit block.

    Both blocks are expected to be fitted, but only the first one is
    expected to be produced: once the last fit block has been fitted and no
    outputs were requested, there is nothing left to compute.
    """
    block_1 = get_mlblock_mock()
    block_2 = get_mlblock_mock()
    blocks = OrderedDict((
        ('a.primitive.Name#1', block_1),
        ('a.primitive.Name#2', block_2),
    ))

    # NOTE(review): ``autospec`` is a ``patch`` option, not a ``MagicMock``
    # one, so here it is only stored as a plain attribute on the mock.
    # ``spec=MLPipeline`` may have been intended -- confirm before changing.
    self_ = MagicMock(autospec=MLPipeline)
    self_.blocks = blocks
    self_._last_fit_block = 'a.primitive.Name#2'

    MLPipeline.fit(self_)

    # ``_fit_block`` and ``_produce_block`` receive the block name as their
    # second positional argument, so compare on that instead of on the full
    # call signature (which also carries the context and debug info).
    fitted = [args[1] for args, _ in self_._fit_block.call_args_list]
    assert fitted == ['a.primitive.Name#1', 'a.primitive.Name#2']

    produced = [args[1] for args, _ in self_._produce_block.call_args_list]
    assert produced == ['a.primitive.Name#1']
709+
710+
@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
def test_fit_pending_one_primitive(self):
    """Check ``fit`` stops early when the first block is the last fit block.

    Only the first block is expected to be fitted, and no block is expected
    to be produced, because nothing after the last fit block needs its
    output when no outputs were requested.
    """
    block_1 = get_mlblock_mock()
    block_2 = get_mlblock_mock()
    blocks = OrderedDict((
        ('a.primitive.Name#1', block_1),
        ('a.primitive.Name#2', block_2),
    ))

    # NOTE(review): ``autospec`` is a ``patch`` option, not a ``MagicMock``
    # one, so here it is only stored as a plain attribute on the mock.
    # ``spec=MLPipeline`` may have been intended -- confirm before changing.
    self_ = MagicMock(autospec=MLPipeline)
    self_.blocks = blocks
    self_._last_fit_block = 'a.primitive.Name#1'

    MLPipeline.fit(self_)

    # ``_fit_block`` receives the block name as its second positional
    # argument; assert on it rather than overwriting ``call_args_list``.
    fitted = [args[1] for args, _ in self_._fit_block.call_args_list]
    assert fitted == ['a.primitive.Name#1']

    assert not self_._produce_block.called
731+
684732
@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
685733
def test_fit_no_debug(self):
686734
mlpipeline = MLPipeline(['a_primitive'])

0 commit comments

Comments
 (0)