diff --git a/Cargo.lock b/Cargo.lock
index 6b21b786e..dbb6d5733 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1803,6 +1803,7 @@ dependencies = [
  "prost",
  "prost-types",
  "rand",
+ "scopeguard",
  "tokio",
  "tokio-stream",
  "tonic 0.11.0",
diff --git a/nativelink-config/src/schedulers.rs b/nativelink-config/src/schedulers.rs
index 09f79d6e0..170b5f273 100644
--- a/nativelink-config/src/schedulers.rs
+++ b/nativelink-config/src/schedulers.rs
@@ -119,6 +119,12 @@ pub struct SimpleScheduler {
     /// The strategy used to assign workers jobs.
     #[serde(default)]
     pub allocation_strategy: WorkerAllocationStrategy,
+
+    /// Remove an action from the queue once it has had no listening client
+    /// for this many seconds. A value of 0 selects the default.
+    /// Default: 60 (seconds)
+    #[serde(default, deserialize_with = "convert_numeric_with_shellexpand")]
+    pub disconnect_timeout_s: u64,
 }
 
 /// A scheduler that simply forwards requests to an upstream scheduler. This
diff --git a/nativelink-scheduler/src/action_scheduler.rs b/nativelink-scheduler/src/action_scheduler.rs
index 5460209d7..5c4f1aa13 100644
--- a/nativelink-scheduler/src/action_scheduler.rs
+++ b/nativelink-scheduler/src/action_scheduler.rs
@@ -47,6 +47,9 @@ pub trait ActionScheduler: Sync + Send + Unpin {
     /// Cleans up the cache of recently completed actions.
     async fn clean_recently_completed_actions(&self);
 
+    /// Informs the scheduler that a client listening to `unique_qualifier` has disconnected.
+    fn notify_client_disconnected(&self, unique_qualifier: &ActionInfoHashKey);
+
     /// Register the metrics for the action scheduler.
     fn register_metrics(self: Arc<Self>, _registry: &mut Registry) {}
 }
diff --git a/nativelink-scheduler/src/cache_lookup_scheduler.rs b/nativelink-scheduler/src/cache_lookup_scheduler.rs
index b1e60cab0..c98bc50d4 100644
--- a/nativelink-scheduler/src/cache_lookup_scheduler.rs
+++ b/nativelink-scheduler/src/cache_lookup_scheduler.rs
@@ -151,6 +151,10 @@ impl CacheLookupScheduler {
 
 #[async_trait]
 impl ActionScheduler for CacheLookupScheduler {
+    // NOTE(review): this is a no-op, so a disconnect reported through this scheduler never
+    // reaches any scheduler it delegates to - confirm whether it should be forwarded.
+    fn notify_client_disconnected(&self, _unique_qualifier: &ActionInfoHashKey) {}
+
     async fn get_platform_property_manager(
         &self,
         instance_name: &str,
diff --git a/nativelink-scheduler/src/grpc_scheduler.rs b/nativelink-scheduler/src/grpc_scheduler.rs
index 155a67f4f..e8d7ee398 100644
--- a/nativelink-scheduler/src/grpc_scheduler.rs
+++ b/nativelink-scheduler/src/grpc_scheduler.rs
@@ -152,6 +152,8 @@ impl GrpcScheduler {
 
 #[async_trait]
 impl ActionScheduler for GrpcScheduler {
+    fn notify_client_disconnected(&self, _unique_qualifier: &ActionInfoHashKey) {}
+
     async fn get_platform_property_manager(
         &self,
         instance_name: &str,
diff --git a/nativelink-scheduler/src/property_modifier_scheduler.rs b/nativelink-scheduler/src/property_modifier_scheduler.rs
index 12d4741ac..c265d8803 100644
--- a/nativelink-scheduler/src/property_modifier_scheduler.rs
+++ b/nativelink-scheduler/src/property_modifier_scheduler.rs
@@ -47,6 +47,10 @@ impl PropertyModifierScheduler {
 
 #[async_trait]
 impl ActionScheduler for PropertyModifierScheduler {
+    // NOTE(review): this is a no-op, so a disconnect reported through this scheduler never
+    // reaches any scheduler it delegates to - confirm whether it should be forwarded.
+    fn notify_client_disconnected(&self, _unique_qualifier: &ActionInfoHashKey) {}
+
     async fn get_platform_property_manager(
         &self,
         instance_name: &str,
diff --git a/nativelink-scheduler/src/simple_scheduler.rs b/nativelink-scheduler/src/simple_scheduler.rs
index 9e22397be..7a4043a2b 100644
--- a/nativelink-scheduler/src/simple_scheduler.rs
+++ b/nativelink-scheduler/src/simple_scheduler.rs
@@ -18,10 +18,10 @@ use std::collections::BTreeMap;
 use std::hash::{Hash, Hasher};
 use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::Arc;
-use std::time::{Instant, SystemTime};
+use std::time::Instant;
 
 use async_trait::async_trait;
-use futures::Future;
+use futures::future::{BoxFuture, Future};
 use hashbrown::{HashMap, HashSet};
 use lru::LruCache;
 use nativelink_config::schedulers::WorkerAllocationStrategy;
@@ -38,7 +38,7 @@ use parking_lot::{Mutex, MutexGuard};
 use tokio::sync::{watch, Notify};
 use tokio::task::JoinHandle;
 use tokio::time::Duration;
-use tracing::{error, warn};
+use tracing::{error, info, warn};
 
 use crate::action_scheduler::ActionScheduler;
 use crate::platform_property_manager::PlatformPropertyManager;
@@ -57,6 +57,10 @@ const DEFAULT_RETAIN_COMPLETED_FOR_S: u64 = 60;
 /// If this changes, remember to change the documentation in the config.
 const DEFAULT_MAX_JOB_RETRIES: usize = 3;
 
+/// Default timeout for actions without any listeners.
+/// If this changes, remember to change the documentation in the config.
+const DEFAULT_DISCONNECT_TIMEOUT_S: u64 = 60;
+
 /// An action that is being awaited on and last known state.
 struct AwaitedAction {
     action_info: Arc<ActionInfo>,
@@ -173,7 +177,7 @@ impl Workers {
 }
 
 struct CompletedAction {
-    completed_time: SystemTime,
+    completed_instant: Instant,
     state: Arc<ActionState>,
 }
 
@@ -198,6 +202,43 @@ impl Borrow<ActionInfoHashKey> for CompletedAction {
     }
 }
 
+type NowFn = fn() -> Instant;
+type SleepFn = fn(Duration) -> BoxFuture<'static, ()>;
+
+/// Functions that may be injected for testing purposes. In production the
+/// [`Default`] implementation supplies the real clock and `tokio` sleep.
+pub struct Callbacks {
+    /// A function that gets the current time.
+    now_fn: NowFn,
+
+    /// A function that sleeps for a given Duration.
+    sleep_fn: SleepFn,
+}
+
+impl Callbacks {
+    /// Creates a set of callbacks from an explicit clock and sleep function.
+    pub fn new(now_fn: NowFn, sleep_fn: SleepFn) -> Self {
+        Self { now_fn, sleep_fn }
+    }
+
+    fn now(&self) -> Instant {
+        (self.now_fn)()
+    }
+
+    fn sleep(&self, duration: Duration) -> impl Future<Output = ()> {
+        (self.sleep_fn)(duration)
+    }
+}
+
+impl Default for Callbacks {
+    fn default() -> Self {
+        Self {
+            now_fn: Instant::now,
+            sleep_fn: |duration| Box::pin(tokio::time::sleep(duration)),
+        }
+    }
+}
+
 struct SimpleSchedulerImpl {
     // BTreeMap uses `cmp` to do it's comparisons, this is a problem because we want to sort our
     // actions based on priority and insert timestamp but also want to find and join new actions
@@ -227,6 +268,11 @@ struct SimpleSchedulerImpl {
     /// Notify task<->worker matching engine that work needs to be done.
     tasks_or_workers_change_notify: Arc<Notify>,
     metrics: Arc<Metrics>,
+    /// How long, in seconds, the scheduler waits for a client to reconnect before removing
+    /// the action from the queue.
+    disconnect_timeout_s: u64,
+    /// Clock and sleep functions, injectable for tests.
+    callbacks: Arc<Callbacks>,
 }
 
 impl SimpleSchedulerImpl {
@@ -313,11 +359,13 @@ impl SimpleSchedulerImpl {
     }
 
     fn clean_recently_completed_actions(&mut self) {
-        let expiry_time = SystemTime::now()
-            .checked_sub(self.retain_completed_for)
-            .unwrap();
-        self.recently_completed_actions
-            .retain(|action| action.completed_time > expiry_time);
+        // Compare elapsed durations rather than computing `now - retain`: `Instant` subtraction
+        // can underflow shortly after boot, which would otherwise panic here.
+        let now = self.callbacks.now();
+        let retain_completed_for = self.retain_completed_for;
+        self.recently_completed_actions.retain(|action| {
+            now.saturating_duration_since(action.completed_instant) < retain_completed_for
+        });
     }
 
     fn find_recently_completed_action(
@@ -422,6 +470,18 @@ impl SimpleSchedulerImpl {
         Ok(())
     }
 
+    /// Returns the queued action matching `unique_qualifier`, if any.
+    fn get_queued_action(&self, unique_qualifier: &ActionInfoHashKey) -> Option<&AwaitedAction> {
+        self.queued_actions_set
+            .get(unique_qualifier)
+            .and_then(|action_info| self.queued_actions.get(action_info))
+    }
+
+    /// Returns the active action matching `unique_qualifier`, if any.
+    fn get_active_action(&self, unique_qualifier: &ActionInfoHashKey) -> Option<&AwaitedAction> {
+        self.active_actions.get(unique_qualifier)
+    }
+
     // TODO(blaise.bruer) This is an O(n*m) (aka n^2) algorithm. In theory we can create a map
     // of capabilities of each worker and then try and match the actions to the worker using
     // the map lookup (ie. map reduce).
@@ -619,7 +679,7 @@ impl SimpleSchedulerImpl {
 
         // Keep in case this is asked for soon.
         self.recently_completed_actions.insert(CompletedAction {
-            completed_time: SystemTime::now(),
+            completed_instant: self.callbacks.now(),
             state: running_action.current_state,
         });
 
@@ -647,7 +707,7 @@ impl SimpleScheduler {
     #[inline]
     #[must_use]
     pub fn new(scheduler_cfg: &nativelink_config::schedulers::SimpleScheduler) -> Self {
-        Self::new_with_callback(scheduler_cfg, || {
+        Self::new_with_callback(scheduler_cfg, Callbacks::default(), || {
             // The cost of running `do_try_match()` is very high, but constant
             // in relation to the number of changes that have happened. This means
             // that grabbing this lock to process `do_try_match()` should always
@@ -664,6 +724,7 @@ impl SimpleScheduler {
         F: Fn() -> Fut + Send + Sync + 'static,
     >(
         scheduler_cfg: &nativelink_config::schedulers::SimpleScheduler,
+        callbacks: Callbacks,
         on_matching_engine_run: F,
     ) -> Self {
         let platform_property_manager = Arc::new(PlatformPropertyManager::new(
@@ -688,6 +749,11 @@ impl SimpleScheduler {
             max_job_retries = DEFAULT_MAX_JOB_RETRIES;
         }
 
+        let mut disconnect_timeout_s = scheduler_cfg.disconnect_timeout_s;
+        if disconnect_timeout_s == 0 {
+            disconnect_timeout_s = DEFAULT_DISCONNECT_TIMEOUT_S;
+        }
+
         let tasks_or_workers_change_notify = Arc::new(Notify::new());
 
         let metrics = Arc::new(Metrics::default());
@@ -703,6 +769,8 @@ impl SimpleScheduler {
             max_job_retries,
             tasks_or_workers_change_notify: tasks_or_workers_change_notify.clone(),
             metrics: metrics.clone(),
+            disconnect_timeout_s,
+            callbacks: Arc::new(callbacks),
         }));
         let weak_inner = Arc::downgrade(&inner);
         Self {
@@ -777,6 +845,84 @@ impl SimpleScheduler {
 
 #[async_trait]
 impl ActionScheduler for SimpleScheduler {
+    /// Starts the disconnect timeout for `unique_qualifier` if the action has no listeners.
+    ///
+    /// A background task sleeps for `disconnect_timeout_s` and then drops the action from the
+    /// queue unless a client has subscribed again in the meantime. Actions that were already
+    /// handed to a worker are left running, since they cannot be cancelled yet.
+    fn notify_client_disconnected(&self, unique_qualifier: &ActionInfoHashKey) {
+        // Read everything needed in a tight scope so the lock is released before spawning.
+        let (sleep_time, callbacks) = {
+            let inner = self.get_inner_lock();
+            let Some(action) = inner
+                .get_queued_action(unique_qualifier)
+                .or_else(|| inner.get_active_action(unique_qualifier))
+            else {
+                warn!(
+                    "Scheduler notified that client disconnected, but failed to find action {}",
+                    unique_qualifier.digest.hash_str()
+                );
+                return;
+            };
+            // Another client is still listening; nothing to time out.
+            if action.notify_channel.receiver_count() != 0 {
+                return;
+            }
+            (
+                Duration::from_secs(inner.disconnect_timeout_s),
+                inner.callbacks.clone(),
+            )
+        };
+
+        let weak_inner = Arc::downgrade(&self.inner);
+        let unique_qualifier = unique_qualifier.clone();
+
+        tokio::spawn(async move {
+            callbacks.sleep(sleep_time).await;
+            // The scheduler itself may have been dropped while we slept.
+            let Some(inner_mux) = weak_inner.upgrade() else {
+                return;
+            };
+            let mut inner = inner_mux.lock();
+
+            // Clone the key so no borrow of `inner` is held when mutating below.
+            let Some(action_info) = inner.queued_actions_set.get(&unique_qualifier).cloned() else {
+                if inner.active_actions.contains_key(&unique_qualifier) {
+                    // TODO: Send kill on worker signal - PR: #842.
+                    warn!(
+                        "Action {} is active and could not be killed after client disconnect",
+                        unique_qualifier.digest.hash_str()
+                    );
+                } else {
+                    info!(
+                        "Action {} completed while client disconnected",
+                        unique_qualifier.digest.hash_str()
+                    );
+                }
+                return;
+            };
+
+            let has_listeners = inner
+                .queued_actions
+                .get(&action_info)
+                .map_or(false, |action| action.notify_channel.receiver_count() != 0);
+            if has_listeners {
+                info!(
+                    "Client reconnected before disconnect_timeout elapsed for Action {}",
+                    unique_qualifier.digest.hash_str()
+                );
+                return;
+            }
+
+            warn!(
+                "Client disconnect timeout elapsed - Removing action with digest hash {}",
+                unique_qualifier.digest.hash_str()
+            );
+            inner.queued_actions_set.remove(&action_info);
+            inner.queued_actions.remove(&action_info);
+        });
+    }
+
     async fn get_platform_property_manager(
         &self,
         _instance_name: &str,
@@ -1024,9 +1170,9 @@ impl MetricsComponent for SimpleScheduler {
 impl MetricsComponent for CompletedAction {
     fn gather_metrics(&self, c: &mut CollectorState) {
         c.publish(
-            "completed_timestamp",
-            &self.completed_time,
-            "The timestamp this action was completed",
+            "completed_instant_elapsed",
+            &self.completed_instant.elapsed(),
+            "The elapsed time since the action was completed",
         );
         c.publish(
             "current_state",
diff --git a/nativelink-scheduler/tests/simple_scheduler_test.rs b/nativelink-scheduler/tests/simple_scheduler_test.rs
index 0653d7b73..5d290d5ad 100644
--- a/nativelink-scheduler/tests/simple_scheduler_test.rs
+++ b/nativelink-scheduler/tests/simple_scheduler_test.rs
@@ -31,7 +31,7 @@ use nativelink_proto::build::bazel::remote::execution::v2::{digest_function, Exe
 use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
     update_for_worker, ConnectionResult, StartExecute, UpdateForWorker,
 };
-use nativelink_scheduler::simple_scheduler::SimpleScheduler;
+use nativelink_scheduler::simple_scheduler::{Callbacks, SimpleScheduler};
 use nativelink_scheduler::worker::{Worker, WorkerId};
 use nativelink_scheduler::worker_scheduler::WorkerScheduler;
 use nativelink_util::common::DigestInfo;
@@ -95,6 +95,8 @@ async fn setup_action(
 
 #[cfg(test)]
 mod scheduler_tests {
+    use std::time::Instant;
+
     use pretty_assertions::assert_eq;
 
     use super::*; // Must be declared in every module.
@@ -107,6 +109,7 @@ mod scheduler_tests {
 
         let scheduler = SimpleScheduler::new_with_callback(
             &nativelink_config::schedulers::SimpleScheduler::default(),
+            Callbacks::default(),
             || async move {},
         );
         let action_digest = DigestInfo::new([99u8; 32], 512);
@@ -160,6 +163,7 @@ mod scheduler_tests {
 
         let scheduler = SimpleScheduler::new_with_callback(
             &nativelink_config::schedulers::SimpleScheduler::default(),
+            Callbacks::default(),
             || async move {},
         );
         let action_digest = DigestInfo::new([99u8; 32], 512);
@@ -224,6 +228,7 @@ mod scheduler_tests {
                 worker_timeout_s: WORKER_TIMEOUT_S,
                 ..Default::default()
             },
+            Callbacks::default(),
             || async move {},
         );
         let action_digest1 = DigestInfo::new([99u8; 32], 512);
@@ -367,6 +372,7 @@ mod scheduler_tests {
 
         let scheduler = SimpleScheduler::new_with_callback(
             &nativelink_config::schedulers::SimpleScheduler::default(),
+            Callbacks::default(),
             || async move {},
         );
         let action_digest = DigestInfo::new([99u8; 32], 512);
@@ -442,6 +448,7 @@ mod scheduler_tests {
 
         let scheduler = SimpleScheduler::new_with_callback(
             &nativelink_config::schedulers::SimpleScheduler::default(),
+            Callbacks::default(),
             || async move {},
         );
         let action_digest = DigestInfo::new([99u8; 32], 512);
@@ -524,6 +531,7 @@ mod scheduler_tests {
 
         let scheduler = SimpleScheduler::new_with_callback(
             &nativelink_config::schedulers::SimpleScheduler::default(),
+            Callbacks::default(),
             || async move {},
         );
         let action_digest = DigestInfo::new([99u8; 32], 512);
@@ -625,6 +633,7 @@ mod scheduler_tests {
         const WORKER_ID: WorkerId = WorkerId(0x0010_0010);
         let scheduler = SimpleScheduler::new_with_callback(
             &nativelink_config::schedulers::SimpleScheduler::default(),
+            Callbacks::default(),
             || async move {},
         );
         let action_digest = DigestInfo::new([99u8; 32], 512);
@@ -666,6 +675,7 @@ mod scheduler_tests {
                 worker_timeout_s: WORKER_TIMEOUT_S,
                 ..Default::default()
             },
+            Callbacks::default(),
             || async move {},
         );
         let action_digest = DigestInfo::new([99u8; 32], 512);
@@ -765,6 +775,7 @@ mod scheduler_tests {
 
         let scheduler = SimpleScheduler::new_with_callback(
             &nativelink_config::schedulers::SimpleScheduler::default(),
+            Callbacks::default(),
             || async move {},
         );
         let action_digest = DigestInfo::new([99u8; 32], 512);
@@ -868,6 +879,7 @@ mod scheduler_tests {
 
         let scheduler = SimpleScheduler::new_with_callback(
             &nativelink_config::schedulers::SimpleScheduler::default(),
+            Callbacks::default(),
             || async move {},
         );
         let action_digest = DigestInfo::new([99u8; 32], 512);
@@ -971,6 +983,7 @@ mod scheduler_tests {
 
         let scheduler = SimpleScheduler::new_with_callback(
             &nativelink_config::schedulers::SimpleScheduler::default(),
+            Callbacks::default(),
             || async move {},
         );
         let action_digest = DigestInfo::new([99u8; 32], 512);
@@ -1066,6 +1079,7 @@ mod scheduler_tests {
 
         let scheduler = SimpleScheduler::new_with_callback(
             &nativelink_config::schedulers::SimpleScheduler::default(),
+            Callbacks::default(),
             || async move {},
         );
         let action_digest = DigestInfo::new([99u8; 32], 512);
@@ -1195,6 +1209,7 @@ mod scheduler_tests {
 
         let scheduler = SimpleScheduler::new_with_callback(
             &nativelink_config::schedulers::SimpleScheduler::default(),
+            Callbacks::default(),
             || async move {},
         );
         let action_digest1 = DigestInfo::new([11u8; 32], 512);
@@ -1339,6 +1354,7 @@ mod scheduler_tests {
 
         let scheduler = SimpleScheduler::new_with_callback(
             &nativelink_config::schedulers::SimpleScheduler::default(),
+            Callbacks::default(),
             || async move {},
         );
         let action_digest1 = DigestInfo::new([11u8; 32], 512);
@@ -1394,6 +1410,7 @@ mod scheduler_tests {
                 max_job_retries: 2,
                 ..Default::default()
             },
+            Callbacks::default(),
             || async move {},
         );
         let action_digest = DigestInfo::new([99u8; 32], 512);
@@ -1523,6 +1540,7 @@ mod scheduler_tests {
         // DropChecker.
         let scheduler = SimpleScheduler::new_with_callback(
             &nativelink_config::schedulers::SimpleScheduler::default(),
+            Callbacks::default(),
             move || {
                 // This will ensure dropping happens if this function is ever dropped.
                 let _drop_checker = drop_checker.clone();
@@ -1548,6 +1566,7 @@ mod scheduler_tests {
 
         let scheduler = SimpleScheduler::new_with_callback(
             &nativelink_config::schedulers::SimpleScheduler::default(),
+            Callbacks::default(),
             || async move {},
         );
         let action_digest = DigestInfo::new([99u8; 32], 512);
@@ -1602,4 +1621,130 @@ mod scheduler_tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn ensure_actions_with_disconnected_clients_are_dropped() -> Result<(), Error> {
+        const DISCONNECT_TIMEOUT_S: u64 = 1;
+
+        let scheduler = SimpleScheduler::new_with_callback(
+            &nativelink_config::schedulers::SimpleScheduler {
+                disconnect_timeout_s: DISCONNECT_TIMEOUT_S,
+                ..Default::default()
+            },
+            Callbacks::new(Instant::now, |_| Box::pin(futures::future::ready(()))),
+            || async move {},
+        );
+
+        let client_rx = setup_action(
+            &scheduler,
+            DigestInfo::new([98u8; 32], 512),
+            PlatformProperties::default(),
+            make_system_time(1),
+        )
+        .await?;
+
+        // Drop our receiver.
+        let unique_qualifier = client_rx.borrow().unique_qualifier.clone();
+        drop(client_rx);
+        // Inform the scheduler the client has been dropped.
+        // Normally this would happen automatically in the service api.
+        scheduler.notify_client_disconnected(&unique_qualifier);
+
+        // Allow the action removal spawn to run.
+        tokio::task::yield_now().await;
+
+        // Check to make sure that the action was removed.
+        assert!(
+            scheduler
+                .find_existing_action(&unique_qualifier)
+                .await
+                .is_none(),
+            "Expected action to be removed"
+        );
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn ensure_notify_disconnected_does_not_block() -> Result<(), Error> {
+        const WORKER_ID: WorkerId = WorkerId(0x1234_5678_9111);
+        const DISCONNECT_TIMEOUT_S: u64 = 1;
+
+        let scheduler = SimpleScheduler::new_with_callback(
+            &nativelink_config::schedulers::SimpleScheduler {
+                disconnect_timeout_s: DISCONNECT_TIMEOUT_S,
+                ..Default::default()
+            },
+            Callbacks::new(Instant::now, |_| Box::pin(futures::future::pending())),
+            || async move {},
+        );
+        let action1_digest = DigestInfo::new([98u8; 32], 512);
+        let action2_digest = DigestInfo::new([99u8; 32], 512);
+
+        let mut rx_from_worker =
+            setup_new_worker(&scheduler, WORKER_ID, PlatformProperties::default()).await?;
+        let insert_timestamp = make_system_time(1);
+
+        let client_rx = setup_action(
+            &scheduler,
+            action1_digest,
+            PlatformProperties::default(),
+            insert_timestamp,
+        )
+        .await?;
+
+        // Drop our receiver.
+        let unique_qualifier_1 = client_rx.borrow().unique_qualifier.clone();
+        drop(client_rx);
+
+        scheduler.notify_client_disconnected(&unique_qualifier_1);
+
+        {
+            // Other tests check full data. We only care if we got StartAction.
+            match rx_from_worker.recv().await.unwrap().update {
+                Some(update_for_worker::Update::StartAction(_)) => { /* Success */ }
+                v => panic!("Expected StartAction, got : {v:?}"),
+            }
+        }
+
+        // Setup a second action so matching engine is scheduled to rerun.
+        let client_rx = setup_action(
+            &scheduler,
+            action2_digest,
+            PlatformProperties::default(),
+            insert_timestamp,
+        )
+        .await?;
+
+        {
+            // Make sure the action sent after dropping the first receiver is registered.
+            match rx_from_worker.recv().await.unwrap().update {
+                Some(update_for_worker::Update::StartAction(_)) => { /* Success */ }
+                v => panic!("Expected StartAction, got : {v:?}"),
+            }
+        }
+
+        // Check to make sure that the first action was not removed.
+        assert!(
+            scheduler
+                .find_existing_action(&unique_qualifier_1)
+                .await
+                .is_some(),
+            "Expected first action to still exist"
+        );
+
+        let unique_qualifier_2 = client_rx.borrow().unique_qualifier.clone();
+        // Check to make sure that the second action was not removed.
+        assert!(
+            scheduler
+                .find_existing_action(&unique_qualifier_2)
+                .await
+                .is_some(),
+            "Expected second action to still exist"
+        );
+        Ok(())
+    }
+
+    // TODO: Add `ensure_active_actions_with_client_disconnect_are_dropped` once the scheduler
+    // can signal workers to kill active actions (PR: #842).
 }
diff --git a/nativelink-scheduler/tests/utils/mock_scheduler.rs b/nativelink-scheduler/tests/utils/mock_scheduler.rs
index 9afd1dd6b..706acba56 100644
--- a/nativelink-scheduler/tests/utils/mock_scheduler.rs
+++ b/nativelink-scheduler/tests/utils/mock_scheduler.rs
@@ -120,6 +120,8 @@ impl MockActionScheduler {
 
 #[async_trait]
 impl ActionScheduler for MockActionScheduler {
+    fn notify_client_disconnected(&self, _unique_qualifier: &ActionInfoHashKey) {}
+
     async fn get_platform_property_manager(
         &self,
         instance_name: &str,
diff --git a/nativelink-service/BUILD.bazel b/nativelink-service/BUILD.bazel
index 06e1d4ef2..63eedc52c 100644
--- a/nativelink-service/BUILD.bazel
+++ b/nativelink-service/BUILD.bazel
@@ -31,6 +31,7 @@ rust_library(
         "@crates//:parking_lot",
         "@crates//:prost",
         "@crates//:rand",
+        "@crates//:scopeguard",
         "@crates//:tokio",
         "@crates//:tokio-stream",
         "@crates//:tonic",
diff --git a/nativelink-service/Cargo.toml b/nativelink-service/Cargo.toml
index 233f6c3d5..2c2bae302 100644
--- a/nativelink-service/Cargo.toml
+++ b/nativelink-service/Cargo.toml
@@ -17,6 +17,7 @@ log = "0.4.21"
 parking_lot = "0.12.1"
 prost = "0.12.3"
 rand = "0.8.5"
+scopeguard = "1.2.0"
 tokio = { version = "1.37.0", features = ["sync", "rt"] }
 tokio-stream = { version = "0.1.15", features = ["sync"] }
 tonic = { version = "0.11.0", features = ["gzip", "tls"] }
diff --git a/nativelink-service/src/execution_server.rs b/nativelink-service/src/execution_server.rs
index cd26d4f93..3b3269612 100644
--- a/nativelink-service/src/execution_server.rs
+++ b/nativelink-service/src/execution_server.rs
@@ -38,6 +38,7 @@ use nativelink_util::digest_hasher::DigestHasherFunc;
 use nativelink_util::platform_properties::PlatformProperties;
 use nativelink_util::store_trait::Store;
 use rand::{thread_rng, Rng};
+use scopeguard::guard;
 use tokio::sync::watch;
 use tokio_stream::wrappers::WatchStream;
 use tonic::{Request, Response, Status};
@@ -182,8 +183,29 @@ impl ExecutionServer {
         Server::new(self)
     }
 
-    fn to_execute_stream(receiver: watch::Receiver<Arc<ActionState>>) -> Response<ExecuteStream> {
-        let receiver_stream = Box::pin(WatchStream::new(receiver).map(|action_update| {
+    /// Wraps `receiver` in the response stream handed back to the client.
+    ///
+    /// The stream owns a guard that calls `notify_client_disconnected` on the instance's
+    /// scheduler when the stream is dropped, so the scheduler learns the client went away.
+    fn to_execute_stream(
+        &self,
+        receiver: watch::Receiver<Arc<ActionState>>,
+        unique_qualifier: ActionInfoHashKey,
+    ) -> Response<ExecuteStream> {
+        // Avoid a panic path in the request handler: if the instance is unknown there is
+        // simply no scheduler to notify.
+        let scheduler = self
+            .instance_infos
+            .get(&unique_qualifier.instance_name)
+            .map(|instance_info| instance_info.scheduler.clone());
+        let disconnect_guard = scheduler.map(|scheduler| {
+            guard(scheduler, move |scheduler| {
+                scheduler.notify_client_disconnected(&unique_qualifier);
+            })
+        });
+        let receiver_stream = Box::pin(WatchStream::new(receiver).map(move |action_update| {
+            // Reference the guard so the closure (and thus the stream) takes ownership of it.
+            let _disconnect_guard = &disconnect_guard;
             info!("\x1b[0;31mexecute Resp Stream\x1b[0m: {:?}", action_update);
             Ok(Into::<Operation>::into(action_update.as_ref().clone()))
         }));
@@ -230,11 +252,11 @@ impl ExecutionServer {
 
         let rx = instance_info
             .scheduler
-            .add_action(action_info)
+            .add_action(action_info.clone())
             .await
             .err_tip(|| "Failed to schedule task")?;
 
-        Ok(Self::to_execute_stream(rx))
+        Ok(self.to_execute_stream(rx, action_info.unique_qualifier))
     }
 
     async fn inner_wait_execution(
@@ -256,7 +278,7 @@ impl ExecutionServer {
         else {
             return Err(Status::not_found("Failed to find existing task"));
         };
-        Ok(Self::to_execute_stream(rx))
+        Ok(self.to_execute_stream(rx, unique_qualifier))
     }
 }
 