diff --git a/src/graph/error.rs b/src/graph/error.rs index b3016b0..b7bf81d 100644 --- a/src/graph/error.rs +++ b/src/graph/error.rs @@ -12,6 +12,8 @@ pub enum GraphError { node_id: usize, }, MultipleErrors(Vec), + /// Contains the original error message when runtime creation failed + RuntimeCreationFailed(String), } impl std::fmt::Display for GraphError { diff --git a/src/graph/graph.rs b/src/graph/graph.rs index 498e42f..32c8a9a 100644 --- a/src/graph/graph.rs +++ b/src/graph/graph.rs @@ -166,8 +166,48 @@ impl Graph { /// This function is used for the execution of a single dag. pub fn start(&mut self) -> Result<(), GraphError> { + let runtime = tokio::runtime::Runtime::new() + .map_err(|e| GraphError::RuntimeCreationFailed(e.to_string()))?; + runtime.block_on(async { self.async_start().await }) + } + /// Executes a single DAG within an existing async runtime. + /// + /// Use this method when you are already running inside an async context + /// (for example, inside a `tokio::main` function or a task spawned on a + /// Tokio runtime) and you do **not** want `Graph` to create and manage + /// its own Tokio runtime. + /// + /// Unlike [`start`], this method: + /// - Does not create a new Tokio runtime. + /// - Assumes it is called on a thread where a Tokio runtime is already + /// active. + /// - Can be `await`-ed like any other async function. + /// + /// # Requirements + /// + /// - A Tokio runtime must be active on the current thread when this + /// method is called. + /// - The graph must have been properly configured (nodes and edges + /// added) before calling this method. + /// + /// If those conditions are not met, execution may fail at runtime. + /// + /// # Examples + /// + /// ```ignore + /// #[tokio::main] + /// async fn main() -> Result<(), Box> { + /// let mut graph = build_graph_somehow(); + /// + /// // Use `async_start` because we are already inside a Tokio runtime. + /// graph.async_start().await?; + /// + /// Ok(()) + /// } + /// ``` + pub async fn async_start(&mut self) -> Result<(), GraphError> { self.init(); - let is_loop = self.check_loop_and_partition(); + let is_loop = self.check_loop_and_partition().await; if is_loop { return Err(GraphError::GraphLoopDetected); } @@ -175,10 +215,7 @@ impl Graph { if !self.is_active.load(Ordering::Relaxed) { return Err(GraphError::GraphNotActive); } - - tokio::runtime::Runtime::new() - .unwrap() - .block_on(async { self.run().await }) + self.run().await } /// Executes the graph's nodes in a concurrent manner, respecting the block structure. @@ -263,9 +300,10 @@ impl Graph { } } Err(_) => { - // Close all the channels - node_ref.blocking_lock().input_channels().close_all(); - node_ref.blocking_lock().output_channels().close_all(); + // Close all the channels using the async lock (do not use blocking_lock inside runtime) + let mut node_guard = node_ref.lock().await; + node_guard.input_channels().close_all(); + node_guard.output_channels().close_all(); error!("Execution failed [name: {}, id: {}]", node_name, node_id,); let mut errors_lock = errors.lock().await; @@ -309,7 +347,7 @@ impl Graph { /// - Groups nodes into blocks, creating a new block whenever a conditional node / loop is encountered /// /// Returns true if the graph contains a cycle, false otherwise. - pub fn check_loop_and_partition(&mut self) -> bool { + pub async fn check_loop_and_partition(&mut self) -> bool { // Check for cycles using abstract graph let has_cycle = self.abstract_graph.check_loop(); @@ -334,7 +372,9 @@ impl Graph { // Create new block if conditional node / loop encountered let node = self.nodes.get(node_id).unwrap(); - if node.blocking_lock().is_condition() { + // Use an async lock here to avoid blocking the runtime + let node_guard = node.lock().await; + if node_guard.is_condition() { self.blocks.push(current_block); current_block = HashSet::new(); }