@staticmethod
async def _wait_for_tasks(
    running_tasks: dict[asyncio.Task, str], scheduled_components: set[str], *, return_when: str
) -> AsyncIterator[dict[str, Any]]:
    """
    Waits for running tasks to finish and yields their partial outputs.

    A single `asyncio.wait` call is issued on the tasks currently registered in `running_tasks`. If no task is
    registered, nothing is awaited and nothing is yielded.

    :param running_tasks: Mapping of in-flight tasks to the name of the component they run. Finished tasks are
        removed in place.
    :param scheduled_components: Set of component names that are scheduled but not yet finished. Finished
        components are discarded in place.
    :param return_when: Either `asyncio.FIRST_COMPLETED` to wait for a single task or `asyncio.ALL_COMPLETED` to
        wait for every running task.
    :returns: An async iterator of partial outputs, one per finished component that produced an output.
    """
    if not running_tasks:
        return

    done, _pending = await asyncio.wait(running_tasks.keys(), return_when=return_when)
    for finished in done:
        finished_component_name = running_tasks.pop(finished)
        # `Task.result()` re-raises an exception raised inside the task, so a failing component surfaces here.
        partial_result = finished.result()
        # `discard` rather than `remove`: the task body may already have removed the name itself.
        scheduled_components.discard(finished_component_name)
        if partial_result:
            yield {finished_component_name: _deepcopy_with_exceptions(partial_result)}


def _record_component_outputs(
    self,
    *,
    component_name: str,
    component_outputs: Any,
    inputs: dict[str, dict[str, list[dict[str, Any]]]],
    pipeline_outputs: dict[str, Any],
    cached_receivers: dict[str, Any],
    include_outputs_from: set[str],
) -> Mapping[str, Any]:
    """
    Writes a component's outputs to the shared input state and records what is left as a pipeline output.

    This is the synchronous post-processing step shared by `_run_component_in_isolation` and
    `_schedule_component`. It contains no `await`, so it runs without yielding control to the event loop.

    :param component_name: The name of the component that produced `component_outputs`.
    :param component_outputs: The value returned by `_run_component_async` for this component.
    :param inputs: The global input state shared by all components. Passed to `_write_component_outputs`.
    :param pipeline_outputs: The accumulated pipeline outputs. Mutated in place.
    :param cached_receivers: Precomputed mapping of component name to its downstream receivers.
    :param include_outputs_from: Set of component names whose outputs should always be included in the output.
    :returns: The pruned outputs returned by `_write_component_outputs`.
    """
    pruned = self._write_component_outputs(
        component_name=component_name,
        component_outputs=component_outputs,
        inputs=inputs,
        receivers=cached_receivers[component_name],
        include_outputs_from=include_outputs_from,
    )
    # Only record an entry when something is left after pruning or the caller explicitly asked for this component.
    if pruned or component_name in include_outputs_from:
        pipeline_outputs[component_name] = pruned
    return pruned


async def _run_component_in_isolation(
    self,
    *,
    component_name: str,
    inputs: dict[str, dict[str, list[dict[str, Any]]]],
    pipeline_outputs: dict[str, Any],
    component_visits: dict[str, int],
    running_tasks: dict[asyncio.Task, str],
    scheduled_components: set[str],
    cached_receivers: dict[str, Any],
    include_outputs_from: set[str],
    parent_span: tracing.Span | None,
) -> AsyncIterator[dict[str, Any]]:
    """
    Runs a component with HIGHEST priority in isolation.

    We need to run components with HIGHEST priority (i.e. components with a GreedyVariadic input socket) by
    themselves, without any other components running concurrently. Otherwise, downstream components could produce
    additional inputs for the GreedyVariadic socket.

    :param component_name: The name of the component to run.
    :param inputs: The global input state shared by all components. Mutated in place.
    :param pipeline_outputs: The accumulated pipeline outputs. Mutated in place.
    :param component_visits: Current state of component visits. Mutated in place.
    :param running_tasks: Mapping of in-flight tasks to component names. Drained in place before running.
    :param scheduled_components: Set of scheduled-but-unfinished component names. Mutated in place.
    :param cached_receivers: Precomputed mapping of component name to its downstream receivers.
    :param include_outputs_from: Set of component names whose outputs should always be included in the output.
    :param parent_span: The parent tracing span for the pipeline run.
    :returns: An async iterator of partial outputs: first those of the drained in-flight tasks, then the output of
        the isolated component itself (if it has one to surface).
    """
    # 1) Wait for all in-flight tasks to finish so the HIGHEST component runs alone.
    async for partial_outputs in self._wait_for_tasks(
        running_tasks, scheduled_components, return_when=asyncio.ALL_COMPLETED
    ):
        yield partial_outputs

    if component_name in scheduled_components:
        # If it's already scheduled for some reason, skip.
        return

    # 2) Run the HIGHEST component by itself.
    scheduled_components.add(component_name)
    component = self._get_component_with_graph_metadata_and_visits(component_name, component_visits[component_name])
    component_inputs = self._consume_component_inputs(component_name, component, inputs)
    component_inputs = self._add_missing_input_defaults(component_inputs, component["input_sockets"])

    component_outputs = await self._run_component_async(
        component_name=component_name,
        component=component,
        component_inputs=component_inputs,
        component_visits=component_visits,
        parent_span=parent_span,
    )

    pruned = self._record_component_outputs(
        component_name=component_name,
        component_outputs=component_outputs,
        inputs=inputs,
        pipeline_outputs=pipeline_outputs,
        cached_receivers=cached_receivers,
        include_outputs_from=include_outputs_from,
    )

    scheduled_components.remove(component_name)
    if pruned or component_name in include_outputs_from:
        yield {component_name: _deepcopy_with_exceptions(pruned)}


def _schedule_component(
    self,
    *,
    component_name: str,
    inputs: dict[str, dict[str, list[dict[str, Any]]]],
    pipeline_outputs: dict[str, Any],
    component_visits: dict[str, int],
    running_tasks: dict[asyncio.Task, str],
    scheduled_components: set[str],
    ready_sem: asyncio.Semaphore,
    cached_receivers: dict[str, Any],
    include_outputs_from: set[str],
    parent_span: tracing.Span | None,
) -> None:
    """
    Schedules a component to run as a background task without waiting for it to finish.

    Inputs are consumed synchronously here (before the task is created) so that other components scheduled in the
    same iteration of the scheduling loop observe the updated input state and don't race for the same inputs.

    :param component_name: The name of the component to schedule.
    :param inputs: The global input state shared by all components. Mutated in place.
    :param pipeline_outputs: The accumulated pipeline outputs. Mutated in place by the task once it finishes.
    :param component_visits: Current state of component visits. Mutated in place by the task once it finishes.
    :param running_tasks: Mapping of in-flight tasks to component names. The new task is registered here.
    :param scheduled_components: Set of scheduled-but-unfinished component names. Mutated in place.
    :param ready_sem: Semaphore bounding how many components run concurrently.
    :param cached_receivers: Precomputed mapping of component name to its downstream receivers.
    :param include_outputs_from: Set of component names whose outputs should always be included in the output.
    :param parent_span: The parent tracing span for the pipeline run.
    """
    if component_name in scheduled_components:
        return  # already scheduled, do nothing

    scheduled_components.add(component_name)

    component = self._get_component_with_graph_metadata_and_visits(component_name, component_visits[component_name])
    component_inputs = self._consume_component_inputs(component_name, component, inputs)
    component_inputs = self._add_missing_input_defaults(component_inputs, component["input_sockets"])

    async def _runner() -> Mapping[str, Any]:
        # The semaphore is held only while the component itself runs; the bookkeeping below happens after release.
        async with ready_sem:
            component_outputs = await self._run_component_async(
                component_name=component_name,
                component=component,
                component_inputs=component_inputs,
                component_visits=component_visits,
                parent_span=parent_span,
            )

        pruned = self._record_component_outputs(
            component_name=component_name,
            component_outputs=component_outputs,
            inputs=inputs,
            pipeline_outputs=pipeline_outputs,
            cached_receivers=cached_receivers,
            include_outputs_from=include_outputs_from,
        )

        scheduled_components.remove(component_name)
        return pruned

    task = asyncio.create_task(_runner())
    running_tasks[task] = component_name
include_outputs_from: set[str] | None = None, concurrency_limit: int = 4 ) -> AsyncIterator[dict[str, Any]]: @@ -276,203 +437,75 @@ async def process_results(): it to get stuck and fail running. Or if a Component fails or returns output in an unsupported type. """ + pipeline_running(self) # telemetry + + # warm up the pipeline by running each component's warm_up method + self.warm_up() + if include_outputs_from is None: include_outputs_from = set() - # 0) Basic pipeline init - pipeline_running(self) # telemetry - self.warm_up() # optional warm-up (if needed) - - # 1) Prepare ephemeral state - ready_sem = asyncio.Semaphore(max(1, concurrency_limit)) - inputs_state: dict[str, dict[str, list[dict[str, Any]]]] = {} pipeline_outputs: dict[str, Any] = {} - running_tasks: dict[asyncio.Task, str] = {} - # A set of component names that have been scheduled but not finished: - scheduled_components: set[str] = set() + # Normalize `data` and raise ValueError if the input is malformed in some way. + data = self._prepare_component_input_data(data) - # 2) Convert input data - prepared_data = self._prepare_component_input_data(data) + # Raise ValueError if input is malformed in some way + self.validate_input(data) - # raises ValueError if input is malformed in some way - self.validate_input(prepared_data) - inputs_state = self._convert_to_internal_format(prepared_data) + # We create a list of components in the pipeline sorted by name, so that the algorithm runs + # deterministically and independent of insertion order into the pipeline. + ordered_component_names = sorted(self.graph.nodes.keys()) + + # We track component visits to decide if a component can run. 
+ component_visits = dict.fromkeys(ordered_component_names, 0) - # For quick lookup of downstream receivers - ordered_names = sorted(self.graph.nodes.keys()) - cached_receivers = {n: self._find_receivers_from(n) for n in ordered_names} - component_visits = dict.fromkeys(ordered_names, 0) cached_topological_sort = None + # We need to access a component's receivers multiple times during a pipeline run. + # We store them here for easy access. + cached_receivers = {name: self._find_receivers_from(name) for name in ordered_component_names} - # We fill the queue once and raise if all components are BLOCKED - self.validate_pipeline(self._fill_queue(ordered_names, inputs_state, component_visits)) + # Ephemeral concurrency state shared (and mutated in place) by the scheduling helpers below. + ready_sem = asyncio.Semaphore(max(1, concurrency_limit)) + running_tasks: dict[asyncio.Task, str] = {} + # A set of component names that have been scheduled but not finished. + scheduled_components: set[str] = set() - # Single parent span for entire pipeline execution with tracing.tracer.trace( "haystack.async_pipeline.run", tags={ - "haystack.pipeline.input_data": prepared_data, + "haystack.pipeline.input_data": data, "haystack.pipeline.output_data": pipeline_outputs, "haystack.pipeline.metadata": self.metadata, "haystack.pipeline.max_runs_per_component": self._max_runs_per_component, }, ) as parent_span: - # ------------------------------------------------- - # We define some functions here so that they have access to local runtime state - # (inputs, tasks, scheduled components) via closures. - # ------------------------------------------------- - async def _run_highest_in_isolation(component_name: str) -> AsyncIterator[dict[str, Any]]: - """ - Runs a component with HIGHEST priority in isolation. - - We need to run components with HIGHEST priority (i.e. components with GreedyVariadic input socket) - by themselves, without any other components running concurrently. 
Otherwise, downstream components - could produce additional inputs for the GreedyVariadic socket. - - :param component_name: The name of the component. - :return: An async iterator of partial outputs. - """ - # 1) Wait for all in-flight tasks to finish - while running_tasks: - done, _pending = await asyncio.wait(running_tasks.keys(), return_when=asyncio.ALL_COMPLETED) - for finished in done: - finished_component_name = running_tasks.pop(finished) - partial_result = finished.result() - scheduled_components.discard(finished_component_name) - if partial_result: - yield_dict = {finished_component_name: _deepcopy_with_exceptions(partial_result)} - yield yield_dict # partial outputs - - if component_name in scheduled_components: - # If it's already scheduled for some reason, skip - return - - # 2) Run the HIGHEST component by itself - scheduled_components.add(component_name) - comp_dict = self._get_component_with_graph_metadata_and_visits( - component_name, component_visits[component_name] - ) - component_inputs = self._consume_component_inputs(component_name, comp_dict, inputs_state) - component_inputs = self._add_missing_input_defaults(component_inputs, comp_dict["input_sockets"]) - - component_pipeline_outputs = await self._run_component_async( - component_name=component_name, - component=comp_dict, - component_inputs=component_inputs, - component_visits=component_visits, - parent_span=parent_span, - ) - - # Distribute outputs to downstream inputs; also prune outputs based on `include_outputs_from` - pruned = self._write_component_outputs( - component_name=component_name, - component_outputs=component_pipeline_outputs, - inputs=inputs_state, - receivers=cached_receivers[component_name], - include_outputs_from=include_outputs_from, - ) - if pruned or component_name in include_outputs_from: - pipeline_outputs[component_name] = pruned - - scheduled_components.remove(component_name) - if pruned or component_name in include_outputs_from: - yield {component_name: 
_deepcopy_with_exceptions(pruned)} - - async def _schedule_task(component_name: str) -> None: - """ - Schedule a component to run. - - We do NOT wait for it to finish here. This allows us to run other components concurrently. - - :param component_name: The name of the component. - """ - - if component_name in scheduled_components: - return # already scheduled, do nothing - - scheduled_components.add(component_name) - - comp_dict = self._get_component_with_graph_metadata_and_visits( - component_name, component_visits[component_name] - ) - component_inputs = self._consume_component_inputs(component_name, comp_dict, inputs_state) - component_inputs = self._add_missing_input_defaults(component_inputs, comp_dict["input_sockets"]) + inputs = self._convert_to_internal_format(pipeline_inputs=data) - async def _runner() -> Mapping[str, Any]: - async with ready_sem: - component_pipeline_outputs = await self._run_component_async( - component_name=component_name, - component=comp_dict, - component_inputs=component_inputs, - component_visits=component_visits, - parent_span=parent_span, - ) + # check if pipeline is blocked before execution + self.validate_pipeline(self._fill_queue(ordered_component_names, inputs, component_visits)) - # Distribute outputs to downstream inputs; also prune outputs based on `include_outputs_from` - pruned = self._write_component_outputs( - component_name=component_name, - component_outputs=component_pipeline_outputs, - inputs=inputs_state, - receivers=cached_receivers[component_name], - include_outputs_from=include_outputs_from, - ) - if pruned or component_name in include_outputs_from: - pipeline_outputs[component_name] = pruned - - scheduled_components.remove(component_name) - return pruned - - task = asyncio.create_task(_runner()) - running_tasks[task] = component_name - - async def _wait_for_one_task_to_complete() -> AsyncIterator[dict[str, Any]]: - """ - Wait for exactly one running task to finish, yield partial outputs. 
- - If no tasks are running, does nothing. - """ - if running_tasks: - done, _ = await asyncio.wait(running_tasks.keys(), return_when=asyncio.FIRST_COMPLETED) - for finished in done: - finished_component_name = running_tasks.pop(finished) - partial_result = finished.result() - scheduled_components.discard(finished_component_name) - if partial_result: - yield {finished_component_name: _deepcopy_with_exceptions(partial_result)} - - async def _wait_for_all_tasks_to_complete() -> AsyncIterator[dict[str, Any]]: - """ - Wait for all running tasks to finish, yield partial outputs. - """ - if running_tasks: - done, _ = await asyncio.wait(running_tasks.keys(), return_when=asyncio.ALL_COMPLETED) - for finished in done: - finished_component_name = running_tasks.pop(finished) - partial_result = finished.result() - scheduled_components.discard(finished_component_name) - if partial_result: - yield {finished_component_name: _deepcopy_with_exceptions(partial_result)} - - # ------------------------------------------------- - # MAIN SCHEDULING LOOP - # ------------------------------------------------- while True: - # 2) Build the priority queue of candidates - priority_queue = self._fill_queue(ordered_names, inputs_state, component_visits) + # We rebuild the priority queue every iteration: each iteration waits for one or more concurrent tasks + # to finish, which mutates `inputs` and can change many components' priorities at once, so we rebuild + # to give every scheduling decision an up-to-date view. + priority_queue = self._fill_queue(ordered_component_names, inputs, component_visits) candidate = self._get_next_runnable_component(priority_queue, component_visits) + # If we can't make progress with the queue but tasks are running, we wait for one to finish and retry + # to potentially unblock the priority queue. 
if (candidate is None or candidate[0] == ComponentPriority.BLOCKED) and running_tasks: - # We need to wait for one task to finish to make progress and potentially unblock the priority_queue - async for partial_res in _wait_for_one_task_to_complete(): - yield partial_res + async for partial_outputs in self._wait_for_tasks( + running_tasks, scheduled_components, return_when=asyncio.FIRST_COMPLETED + ): + yield partial_outputs continue + # If there are no runnable components left and nothing is running, we can exit the loop. if candidate is None and not running_tasks: - # done break - priority, comp_name, comp = candidate # type: ignore + priority, component_name, component = candidate # type: ignore # If the next component is blocked, we do a check to see if the pipeline is possibly blocked and raise # a warning if it is. @@ -480,68 +513,109 @@ async def _wait_for_all_tasks_to_complete() -> AsyncIterator[dict[str, Any]]: if self._is_pipeline_possibly_blocked(current_pipeline_outputs=pipeline_outputs): # Pipeline is most likely blocked (most likely a configuration issue) so we raise a warning. self._find_components_blocking_pipeline( - priority_queue=priority_queue, component_visits=component_visits, inputs=inputs_state + priority_queue=priority_queue, component_visits=component_visits, inputs=inputs ) # We always exit the loop since we cannot run the next component. break - if comp_name in scheduled_components: - # We need to wait for one task to finish to make progress - async for partial_res in _wait_for_one_task_to_complete(): - yield partial_res + # If the next component is already scheduled, we wait for a task to finish to make progress. 
+ if component_name in scheduled_components: + async for partial_outputs in self._wait_for_tasks( + running_tasks, scheduled_components, return_when=asyncio.FIRST_COMPLETED + ): + yield partial_outputs continue if priority == ComponentPriority.HIGHEST: - # 1) run alone - async for partial_res in _run_highest_in_isolation(comp_name): - yield partial_res - # then continue the loop + # A HIGHEST priority component must run alone, so we hand off to the isolation helper. + async for partial_outputs in self._run_component_in_isolation( + component_name=component_name, + inputs=inputs, + pipeline_outputs=pipeline_outputs, + component_visits=component_visits, + running_tasks=running_tasks, + scheduled_components=scheduled_components, + cached_receivers=cached_receivers, + include_outputs_from=include_outputs_from, + parent_span=parent_span, + ): + yield partial_outputs continue if priority == ComponentPriority.READY: - # 1) schedule this one - await _schedule_task(comp_name) + # Schedule this component, then schedule as many additional READY components as concurrency allows. 
+ self._schedule_component( + component_name=component_name, + inputs=inputs, + pipeline_outputs=pipeline_outputs, + component_visits=component_visits, + running_tasks=running_tasks, + scheduled_components=scheduled_components, + ready_sem=ready_sem, + cached_receivers=cached_receivers, + include_outputs_from=include_outputs_from, + parent_span=parent_span, + ) - # 2) Possibly schedule more READY tasks if concurrency not fully used + # Possibly schedule more READY tasks if concurrency not fully used while len(priority_queue) > 0 and not ready_sem.locked(): - peek_prio, peek_name = priority_queue.peek() - if peek_prio in (ComponentPriority.BLOCKED, ComponentPriority.HIGHEST): - # can't run or must run alone => skip + peek_priority, peek_name = priority_queue.peek() + if peek_priority != ComponentPriority.READY: + # We stop scheduling: the next component is BLOCKED (can't run), HIGHEST (must run alone), + # or DEFER (waiting for more inputs - we only schedule it once it becomes READY). break - if peek_prio == ComponentPriority.READY: - priority_queue.pop() - await _schedule_task(peek_name) - # keep adding while concurrency is not locked - continue - - # The next is DEFER => we only schedule it if it "becomes READY" - # We'll handle it in the next iteration or with incremental waiting - break + priority_queue.pop() + self._schedule_component( + component_name=peek_name, + inputs=inputs, + pipeline_outputs=pipeline_outputs, + component_visits=component_visits, + running_tasks=running_tasks, + scheduled_components=scheduled_components, + ready_sem=ready_sem, + cached_receivers=cached_receivers, + include_outputs_from=include_outputs_from, + parent_span=parent_span, + ) - # We only schedule components with priority DEFER when no other tasks are running + # We only schedule components with priority DEFER when no other tasks are running. 
elif priority == ComponentPriority.DEFER and not running_tasks: if len(priority_queue) > 0: - comp_name, topological_sort = self._tiebreak_waiting_components( - component_name=comp_name, + component_name, cached_topological_sort = self._tiebreak_waiting_components( + component_name=component_name, priority=priority, priority_queue=priority_queue, topological_sort=cached_topological_sort, ) - cached_topological_sort = topological_sort - await _schedule_task(comp_name) - - # To make progress, we wait for one task to complete before re-starting the loop - async for partial_res in _wait_for_one_task_to_complete(): - yield partial_res - - # End main loop + self._schedule_component( + component_name=component_name, + inputs=inputs, + pipeline_outputs=pipeline_outputs, + component_visits=component_visits, + running_tasks=running_tasks, + scheduled_components=scheduled_components, + ready_sem=ready_sem, + cached_receivers=cached_receivers, + include_outputs_from=include_outputs_from, + parent_span=parent_span, + ) - # 3) Drain leftover tasks - async for partial_res in _wait_for_all_tasks_to_complete(): - yield partial_res + # To make progress, we wait for one task to complete before restarting the loop. + async for partial_outputs in self._wait_for_tasks( + running_tasks, scheduled_components, return_when=asyncio.FIRST_COMPLETED + ): + yield partial_outputs + + # Safety net: drain any leftover tasks once the scheduling loop has finished. With the current loop both + # `break` paths require `running_tasks` to be empty, so this is a no-op. We keep it so that a future change + # adding a `break` that leaves tasks in flight doesn't lose outputs. + async for partial_outputs in self._wait_for_tasks( + running_tasks, scheduled_components, return_when=asyncio.ALL_COMPLETED + ): + yield partial_outputs - # 4) Yield final pipeline outputs + # Yield the final pipeline outputs. 
import pytest

from haystack import AsyncPipeline, Document, component
from haystack.components.joiners import BranchJoiner


@component
class _Doubler:
    """Toy component that doubles its integer input; used to drive the isolation helper in tests."""

    @component.output_types(value=int)
    def run(self, value: int) -> dict[str, int]:
        doubled = value * 2
        return {"value": doubled}


def _build_isolation_state(pipeline: AsyncPipeline, data: dict) -> dict:
    """
    Assemble the keyword arguments (other than `component_name`) taken by `_run_component_in_isolation`.

    Mirrors the setup `run_async_generator` performs before the scheduling loop: the input data is prepared and
    converted to the internal format, visits start at zero, and no task is running or scheduled.
    """
    prepared = pipeline._prepare_component_input_data(data)
    component_names = sorted(pipeline.graph.nodes.keys())
    receivers = {name: pipeline._find_receivers_from(name) for name in component_names}
    return {
        "inputs": pipeline._convert_to_internal_format(prepared),
        "pipeline_outputs": {},
        "component_visits": dict.fromkeys(component_names, 0),
        "running_tasks": {},
        "scheduled_components": set(),
        "cached_receivers": receivers,
        "include_outputs_from": set(),
        "parent_span": None,
    }


async def _run_isolated(pipeline: AsyncPipeline, component_name: str, state: dict) -> list:
    """Exhaust `_run_component_in_isolation` for `component_name` and return everything it yielded, in order."""
    return [out async for out in pipeline._run_component_in_isolation(component_name=component_name, **state)]


class TestRunComponentInIsolation:
    @pytest.mark.asyncio
    async def test_runs_component_and_yields_output(self):
        pipeline = AsyncPipeline()
        pipeline.add_component("doubler", _Doubler())
        state = _build_isolation_state(pipeline, {"doubler": {"value": 3}})

        yielded = await _run_isolated(pipeline, "doubler", state)

        expected = {"doubler": {"value": 6}}
        assert yielded == [expected]
        assert state["pipeline_outputs"] == expected
        assert state["component_visits"]["doubler"] == 1
        # The name is added to scheduled_components for the duration of the run and removed afterwards.
        assert state["scheduled_components"] == set()

    @pytest.mark.asyncio
    async def test_runs_greedy_component_consuming_single_input(self):
        pipeline = AsyncPipeline()
        pipeline.add_component("joiner", BranchJoiner(type_=int))
        state = _build_isolation_state(pipeline, {})
        # Queue two values on the greedy variadic socket; greedy consumption keeps only the first.
        queued = [{"sender": None, "value": 1}, {"sender": None, "value": 2}]
        state["inputs"]["joiner"] = {"value": queued}

        yielded = await _run_isolated(pipeline, "joiner", state)

        assert yielded == [{"joiner": {"value": 1}}]
        assert state["component_visits"]["joiner"] == 1

    @pytest.mark.asyncio
    async def test_drains_in_flight_tasks_before_running(self):
        pipeline = AsyncPipeline()
        pipeline.add_component("doubler", _Doubler())
        state = _build_isolation_state(pipeline, {"doubler": {"value": 3}})

        async def _in_flight() -> dict:
            return {"value": 99}

        # Register a fake in-flight task under the name "other", as the scheduler would.
        state["running_tasks"][asyncio.create_task(_in_flight())] = "other"
        state["scheduled_components"].add("other")

        yielded = await _run_isolated(pipeline, "doubler", state)

        # The in-flight task is drained (and its output yielded) before the isolated component runs.
        drained = {"other": {"value": 99}}
        isolated = {"doubler": {"value": 6}}
        assert drained in yielded
        assert isolated in yielded
        assert yielded.index(drained) < yielded.index(isolated)
        assert state["running_tasks"] == {}
        assert "other" not in state["scheduled_components"]

    @pytest.mark.asyncio
    async def test_skips_when_component_already_scheduled(self):
        pipeline = AsyncPipeline()
        pipeline.add_component("doubler", _Doubler())
        state = _build_isolation_state(pipeline, {"doubler": {"value": 3}})
        state["scheduled_components"].add("doubler")

        yielded = await _run_isolated(pipeline, "doubler", state)

        # Already scheduled: the component is not run and the bookkeeping is left as it was.
        assert yielded == []
        assert state["component_visits"]["doubler"] == 0
        assert state["pipeline_outputs"] == {}
        assert "doubler" in state["scheduled_components"]

    @pytest.mark.asyncio
    async def test_distributes_outputs_downstream_and_prunes_consumed(self):
        pipeline = AsyncPipeline()
        pipeline.add_component("first", _Doubler())
        pipeline.add_component("second", _Doubler())
        pipeline.connect("first.value", "second.value")
        state = _build_isolation_state(pipeline, {"first": {"value": 3}})

        yielded = await _run_isolated(pipeline, "first", state)

        # `first`'s output is consumed by `second`, so it is pruned: nothing is yielded or stored as a pipeline output.
        assert yielded == []
        assert state["pipeline_outputs"] == {}
        # `second` can now consume the distributed value.
        second = pipeline._get_component_with_graph_metadata_and_visits("second", 0)
        consumed = pipeline._consume_component_inputs("second", second, state["inputs"])
        assert consumed == {"value": 6}

    @pytest.mark.asyncio
    async def test_include_outputs_from_yields_even_when_consumed(self):
        pipeline = AsyncPipeline()
        pipeline.add_component("first", _Doubler())
        pipeline.add_component("second", _Doubler())
        pipeline.connect("first.value", "second.value")
        state = _build_isolation_state(pipeline, {"first": {"value": 3}})
        state["include_outputs_from"] = {"first"}

        yielded = await _run_isolated(pipeline, "first", state)

        # Even though `first`'s output is consumed by `second`, include_outputs_from forces it to be surfaced.
        expected = {"first": {"value": 6}}
        assert yielded == [expected]
        assert state["pipeline_outputs"] == expected