diff --git a/tapeagents/agent.py b/tapeagents/agent.py index 1109e0de1..8ba07324b 100644 --- a/tapeagents/agent.py +++ b/tapeagents/agent.py @@ -29,7 +29,7 @@ Thought, TrainingText, ) -from tapeagents.llms import LLM, LLMCall, LLMEvent, LLMOutput, LLMStream, TrainableLLM +from tapeagents.llms import LLM, LLMCall, LLMEvent, LLMOutput, LLMStream from tapeagents.observe import observe_llm_call from tapeagents.tool_calling import ToolSpec from tapeagents.view import TapeViewStack @@ -718,39 +718,78 @@ def _run_implementation(): return AgentStream(_run_implementation()) - def run_batch(self: Agent[TapeType], tapes: list[TapeType]) -> list[Tape]: - """Run agent in parallel on tapes using batched LLM calls. - - This is faster than running agents in thread and having the LLM server batch the calls. - + def run_batch(self: Agent[TapeType], tapes: list[Tape], environment = None) -> list[Tape]: + """ + Run agent in parallel on tapes using batched LLM calls with optional environment integration. + + Args: + tapes: List of tapes to process in parallel + environment: Optional environment with a react(tape) method + + Returns: + List of processed tapes """ if len(self.llms) > 1: - raise NotImplementedError("For run_agent_batch the agent must have only one LLM for now") - if not isinstance(self.llm, TrainableLLM): - raise NotImplementedError("For run_agent_batch the LLM must be TrainableLLM") + raise NotImplementedError("For batch processing the agent must have only one LLM for now") + + llm = self.llms.get("default") or next(iter(self.llms.values())) + if not hasattr(llm, "batch_generate"): + raise NotImplementedError("The LLM must support batch_generate method") + + # Check environment has react method if provided + if environment is not None and (not hasattr(environment, "react") or not callable(environment.react)): + raise ValueError("Environment must have a callable react method") + original_tapes = list(tapes) n_iterations = 0 - active_indices = set(range(len(tapes))) - while n_iterations < self.max_iterations: - prompts = [] - current_subagents = [self.delegate(tapes[i]) for i in active_indices] - prompts = [subagent.make_prompt(tape) for subagent, tape in zip(current_subagents, tapes)] - llm_calls = self.llm.batch_generate(prompts) - for i in active_indices: - # Run the equivalent of agent.run_iteration + pending_tape_indices = set(range(len(tapes))) + + while n_iterations < self.max_iterations and pending_tape_indices: + # Create batch structure with tape indices, subagents, and prompts + batch = [] + for tape_idx in list(pending_tape_indices): + subagent = self.delegate(tapes[tape_idx]) + prompt = subagent.make_prompt(tapes[tape_idx]) + batch.append((tape_idx, subagent, prompt)) + + if not batch: + break + + # Batch LLM calls + prompts = [item[2] for item in batch] + llm_responses = llm.batch_generate(prompts) + + # Process results for each tape in the batch + for batch_pos, (tape_idx, subagent, _) in enumerate(batch): + llm_call = llm_responses[batch_pos] + + # Create LLM stream llm_stream = LLMStream( - (LLMEvent(output=output) for output in (llm_calls[i].output,)), llm_calls[i].prompt + (LLMEvent(output=llm_call.output) for _ in range(1)), + llm_call.prompt ) - for step in self.generate_steps(tapes[i], llm_stream): - step.metadata.agent = current_subagents[i].full_name + + # Process steps directly using generate_steps + for step in subagent.generate_steps(tapes[tape_idx], llm_stream): if isinstance(step, AgentStep): - step.metadata.prompt_id = llm_calls[i].prompt.id - tapes[i] = tapes[i].append(step) - if self.should_stop(tapes[i]): - active_indices.remove(i) - if self.store_llm_calls: - step.metadata.other["llm_call"] = llm_calls[i] + step.metadata.prompt_id = llm_call.prompt.id + tapes[tape_idx] = tapes[tape_idx].append(step) + + # Store LLM call in metadata if needed + if subagent.store_llm_calls: + step.metadata.other["llm_call"] = llm_call + + # Apply environment reactions if environment is provided + if environment is not None: + tapes[tape_idx] = environment.react(tapes[tape_idx]) + + # Check if agent should stop for this tape + if subagent.should_stop(tapes[tape_idx]): + pending_tape_indices.remove(tape_idx) + n_iterations += 1 + + # Update metadata for all tapes for i in range(len(tapes)): updated_metadata = original_tapes[i].metadata.model_validate( dict( @@ -760,6 +799,7 @@ def run_batch(self: Agent[TapeType], tapes: list[TapeType]) -> list[Tape]: ) ) tapes[i] = tapes[i].model_copy(update=dict(metadata=updated_metadata)) + return tapes def reuse(self, tape: TapeType) -> tuple[TapeType, list[LLMCall]]: