From d9713798374574377063cd9c73c2b8b6ed6f9ff4 Mon Sep 17 00:00:00 2001 From: Zach Birenbaum Date: Tue, 9 Apr 2024 18:07:54 -0700 Subject: [PATCH] Remove old actions with no listeners Implement scheduler side removal of actions with no listeners. Adds disconnect_timeout_s configuration field with default of 60s. If the client waiting on a given action is disconnected for longer than this duration without reconnecting the scheduler will stop tracking it. This does not remove it from the worker if the job has already been dispatched. fixes TraceMachina#338 --- Cargo.lock | 1 + nativelink-config/src/schedulers.rs | 6 + nativelink-scheduler/src/action_scheduler.rs | 3 + .../src/cache_lookup_scheduler.rs | 2 + nativelink-scheduler/src/grpc_scheduler.rs | 2 + .../src/property_modifier_scheduler.rs | 2 + nativelink-scheduler/src/simple_scheduler.rs | 176 +++++++++++++- .../tests/simple_scheduler_test.rs | 219 +++++++++++++++++- .../tests/utils/mock_scheduler.rs | 2 + nativelink-service/BUILD.bazel | 1 + nativelink-service/Cargo.toml | 1 + nativelink-service/src/execution_server.rs | 27 ++- 12 files changed, 424 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6b21b786e..dbb6d5733 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1803,6 +1803,7 @@ dependencies = [ "prost", "prost-types", "rand", + "scopeguard", "tokio", "tokio-stream", "tonic 0.11.0", diff --git a/nativelink-config/src/schedulers.rs b/nativelink-config/src/schedulers.rs index 09f79d6e0..170b5f273 100644 --- a/nativelink-config/src/schedulers.rs +++ b/nativelink-config/src/schedulers.rs @@ -119,6 +119,12 @@ pub struct SimpleScheduler { /// The strategy used to assign workers jobs. #[serde(default)] pub allocation_strategy: WorkerAllocationStrategy, + + /// Remove action from queue after this much time has elapsed without a listener + /// amount of time in seconds. 
+ /// Default: 60 (seconds) + #[serde(default, deserialize_with = "convert_numeric_with_shellexpand")] + pub disconnect_timeout_s: u64, } /// A scheduler that simply forwards requests to an upstream scheduler. This diff --git a/nativelink-scheduler/src/action_scheduler.rs b/nativelink-scheduler/src/action_scheduler.rs index 5460209d7..5c4f1aa13 100644 --- a/nativelink-scheduler/src/action_scheduler.rs +++ b/nativelink-scheduler/src/action_scheduler.rs @@ -47,6 +47,9 @@ pub trait ActionScheduler: Sync + Send + Unpin { /// Cleans up the cache of recently completed actions. async fn clean_recently_completed_actions(&self); + /// Inform the scheduler a client has disconnected + fn notify_client_disconnected(&self, unique_qualifier: &ActionInfoHashKey); + /// Register the metrics for the action scheduler. fn register_metrics(self: Arc, _registry: &mut Registry) {} } diff --git a/nativelink-scheduler/src/cache_lookup_scheduler.rs b/nativelink-scheduler/src/cache_lookup_scheduler.rs index b1e60cab0..c98bc50d4 100644 --- a/nativelink-scheduler/src/cache_lookup_scheduler.rs +++ b/nativelink-scheduler/src/cache_lookup_scheduler.rs @@ -151,6 +151,8 @@ impl CacheLookupScheduler { #[async_trait] impl ActionScheduler for CacheLookupScheduler { + fn notify_client_disconnected(&self, _unique_qualifier: &ActionInfoHashKey) {} + async fn get_platform_property_manager( &self, instance_name: &str, diff --git a/nativelink-scheduler/src/grpc_scheduler.rs b/nativelink-scheduler/src/grpc_scheduler.rs index 155a67f4f..e8d7ee398 100644 --- a/nativelink-scheduler/src/grpc_scheduler.rs +++ b/nativelink-scheduler/src/grpc_scheduler.rs @@ -152,6 +152,8 @@ impl GrpcScheduler { #[async_trait] impl ActionScheduler for GrpcScheduler { + fn notify_client_disconnected(&self, _unique_qualifier: &ActionInfoHashKey) {} + async fn get_platform_property_manager( &self, instance_name: &str, diff --git a/nativelink-scheduler/src/property_modifier_scheduler.rs 
b/nativelink-scheduler/src/property_modifier_scheduler.rs index 12d4741ac..c265d8803 100644 --- a/nativelink-scheduler/src/property_modifier_scheduler.rs +++ b/nativelink-scheduler/src/property_modifier_scheduler.rs @@ -47,6 +47,8 @@ impl PropertyModifierScheduler { #[async_trait] impl ActionScheduler for PropertyModifierScheduler { + fn notify_client_disconnected(&self, _unique_qualifier: &ActionInfoHashKey) {} + async fn get_platform_property_manager( &self, instance_name: &str, diff --git a/nativelink-scheduler/src/simple_scheduler.rs b/nativelink-scheduler/src/simple_scheduler.rs index 9e22397be..7a4043a2b 100644 --- a/nativelink-scheduler/src/simple_scheduler.rs +++ b/nativelink-scheduler/src/simple_scheduler.rs @@ -18,10 +18,10 @@ use std::collections::BTreeMap; use std::hash::{Hash, Hasher}; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; -use std::time::{Instant, SystemTime}; +use std::time::Instant; use async_trait::async_trait; -use futures::Future; +use futures::future::{BoxFuture, Future}; use hashbrown::{HashMap, HashSet}; use lru::LruCache; use nativelink_config::schedulers::WorkerAllocationStrategy; @@ -38,7 +38,7 @@ use parking_lot::{Mutex, MutexGuard}; use tokio::sync::{watch, Notify}; use tokio::task::JoinHandle; use tokio::time::Duration; -use tracing::{error, warn}; +use tracing::{error, info, warn}; use crate::action_scheduler::ActionScheduler; use crate::platform_property_manager::PlatformPropertyManager; @@ -57,6 +57,10 @@ const DEFAULT_RETAIN_COMPLETED_FOR_S: u64 = 60; /// If this changes, remember to change the documentation in the config. const DEFAULT_MAX_JOB_RETRIES: usize = 3; +/// Default timeout for actions without any listeners. +/// If this changes, remember to change the documentation in the config. +const DEFAULT_DISCONNECT_TIMEOUT_S: u64 = 60; + /// An action that is being awaited on and last known state. 
struct AwaitedAction { action_info: Arc, @@ -173,7 +177,7 @@ impl Workers { } struct CompletedAction { - completed_time: SystemTime, + completed_instant: Instant, state: Arc, } @@ -198,6 +202,41 @@ impl Borrow for CompletedAction { } } +type NowFn = fn() -> Instant; +type SleepFn = fn(Duration) -> BoxFuture<'static, ()>; + +/// Functions that may be injected for testing purposes, during standard control +/// flows these are specified by the new function. +pub struct Callbacks { + /// A function that gets the current time. + now_fn: NowFn, + + /// A function that sleeps for a given Duration. + sleep_fn: SleepFn, +} + +impl Callbacks { + pub fn new(now_fn: NowFn, sleep_fn: SleepFn) -> Self { + Self { now_fn, sleep_fn } + } + + fn now(&self) -> Instant { + (self.now_fn)() + } + + fn sleep(&self, duration: Duration) -> impl Future { + (self.sleep_fn)(duration) + } +} + +impl Default for Callbacks { + fn default() -> Self { + Callbacks { + now_fn: Instant::now, + sleep_fn: |duration| Box::pin(tokio::time::sleep(duration)), + } + } +} struct SimpleSchedulerImpl { // BTreeMap uses `cmp` to do it's comparisons, this is a problem because we want to sort our // actions based on priority and insert timestamp but also want to find and join new actions @@ -227,6 +266,9 @@ struct SimpleSchedulerImpl { /// Notify task<->worker matching engine that work needs to be done. tasks_or_workers_change_notify: Arc, metrics: Arc, + /// How long the server will wait for a client to reconnect before removing the action from the queue. 
+ disconnect_timeout_s: u64, + callbacks: Arc, } impl SimpleSchedulerImpl { @@ -313,11 +355,13 @@ impl SimpleSchedulerImpl { } fn clean_recently_completed_actions(&mut self) { - let expiry_time = SystemTime::now() + let expiry_time = self + .callbacks + .now() .checked_sub(self.retain_completed_for) .unwrap(); self.recently_completed_actions - .retain(|action| action.completed_time > expiry_time); + .retain(|action| action.completed_instant > expiry_time); } fn find_recently_completed_action( @@ -422,10 +466,21 @@ impl SimpleSchedulerImpl { Ok(()) } + fn get_queued_action(&self, unique_qualifier: &ActionInfoHashKey) -> Option<&AwaitedAction> { + self.queued_actions_set + .get(unique_qualifier) + .and_then(|action_info| self.queued_actions.get(action_info)) + } + + fn get_active_action(&self, unique_qualifier: &ActionInfoHashKey) -> Option<&AwaitedAction> { + self.active_actions.get(unique_qualifier) + } + // TODO(blaise.bruer) This is an O(n*m) (aka n^2) algorithm. In theory we can create a map // of capabilities of each worker and then try and match the actions to the worker using // the map lookup (ie. map reduce). fn do_try_match(&mut self) { + println!("do_try_match did run"); // TODO(blaise.bruer) This is a bit difficult because of how rust's borrow checker gets in // the way. We need to conditionally remove items from the `queued_action`. Rust is working // to add `drain_filter`, which would in theory solve this problem, but because we need @@ -619,7 +674,7 @@ impl SimpleSchedulerImpl { // Keep in case this is asked for soon. 
self.recently_completed_actions.insert(CompletedAction { - completed_time: SystemTime::now(), + completed_instant: self.callbacks.now(), state: running_action.current_state, }); @@ -647,7 +702,7 @@ impl SimpleScheduler { #[inline] #[must_use] pub fn new(scheduler_cfg: &nativelink_config::schedulers::SimpleScheduler) -> Self { - Self::new_with_callback(scheduler_cfg, || { + Self::new_with_callback(scheduler_cfg, Callbacks::default(), || { // The cost of running `do_try_match()` is very high, but constant // in relation to the number of changes that have happened. This means // that grabbing this lock to process `do_try_match()` should always @@ -664,6 +719,7 @@ impl SimpleScheduler { F: Fn() -> Fut + Send + Sync + 'static, >( scheduler_cfg: &nativelink_config::schedulers::SimpleScheduler, + callbacks: Callbacks, on_matching_engine_run: F, ) -> Self { let platform_property_manager = Arc::new(PlatformPropertyManager::new( @@ -688,6 +744,11 @@ impl SimpleScheduler { max_job_retries = DEFAULT_MAX_JOB_RETRIES; } + let mut disconnect_timeout_s = scheduler_cfg.disconnect_timeout_s; + if disconnect_timeout_s == 0 { + disconnect_timeout_s = DEFAULT_DISCONNECT_TIMEOUT_S; + } + let tasks_or_workers_change_notify = Arc::new(Notify::new()); let metrics = Arc::new(Metrics::default()); @@ -703,6 +764,8 @@ impl SimpleScheduler { max_job_retries, tasks_or_workers_change_notify: tasks_or_workers_change_notify.clone(), metrics: metrics.clone(), + disconnect_timeout_s, + callbacks: Arc::new(callbacks), })); let weak_inner = Arc::downgrade(&inner); Self { @@ -777,6 +840,97 @@ impl SimpleScheduler { #[async_trait] impl ActionScheduler for SimpleScheduler { + fn notify_client_disconnected(&self, unique_qualifier: &ActionInfoHashKey) { + // TODO: Make this prettier. + // It's a bit tricky to comply with borrow checker + // but it should be possible to make this nicer. 
+ let inner = self.get_inner_lock(); + let Some(action) = inner + .get_queued_action(&unique_qualifier) + .or_else(|| inner.get_active_action(&unique_qualifier)) + else { + warn!( + "Scheduler notified that client disconnected, but failed to find action {}", + unique_qualifier.digest.hash_str() + ); + return; + }; + + if action.notify_channel.receiver_count() != 0 { + return; + } + let sleep_time = Duration::from_secs(inner.disconnect_timeout_s); + let callbacks = inner.callbacks.clone(); + // Drop the mutex guard so we don't hold up access. + drop(inner); + + let weak_inner = Arc::downgrade(&self.inner); + + let unique_qualifier = unique_qualifier.clone(); + + // We create a spawn here which sleeps for disconnect_timeout_s + // and then checks to see if the listener count is still 0. + // If so, it moves to the removal stage, otherwise returns. + tokio::spawn(async move { + callbacks.sleep(sleep_time).await; + let Some(inner_mux) = weak_inner.upgrade() else { + return; + }; + let mut inner = inner_mux.lock(); + let Some(action) = inner + .queued_actions_set + .get(&unique_qualifier) + .and_then(|action_info| inner.queued_actions.get(action_info)) + else { + if inner.active_actions.contains_key(&unique_qualifier) { + warn!("Action was active and could not be killed"); + } else { + info!( + "Action {} completed while client disconnected", + unique_qualifier.digest.hash_str() + ); + } + return; + }; + + // If listener count is 0, remove the action from queued actions + // or active actions, whichever is applicable, and kill the action + // on the worker if it was active. + // If the listener count is not still 0, a client has reconnected + // and there is no more work to be done. 
+ if action.notify_channel.receiver_count() != 0 { + info!( + "Client reconnected before disconnect_timeout elapsed for Action {}", + unique_qualifier.digest.hash_str() + ); + return; + } + + warn!( + "Client disconnect timeout elapsed - Removing action with digest hash {}", + action.action_info.unique_qualifier.digest.hash_str() + ); + + println!("About to remove"); + match inner.get_queued_action(&unique_qualifier) { + Some(_) => { + // We can't use the action info from the above call due to borrow checker. + let action_info = inner + .queued_actions_set + .get(&unique_qualifier) + .unwrap() + .clone(); + inner.queued_actions_set.remove(&action_info); + inner.queued_actions.remove(&action_info); + } + None => { + inner.active_actions.remove(&unique_qualifier); + // TODO: Send kill on worker signal - PR: #842. + } + } + }); + } + async fn get_platform_property_manager( &self, _instance_name: &str, @@ -1024,9 +1178,9 @@ impl MetricsComponent for SimpleScheduler { impl MetricsComponent for CompletedAction { fn gather_metrics(&self, c: &mut CollectorState) { c.publish( - "completed_timestamp", - &self.completed_time, - "The timestamp this action was completed", + "completed_instant_elapsed", + &self.completed_instant.elapsed(), + "The elapsed time since the action was completed", ); c.publish( "current_state", diff --git a/nativelink-scheduler/tests/simple_scheduler_test.rs b/nativelink-scheduler/tests/simple_scheduler_test.rs index 0653d7b73..5d290d5ad 100644 --- a/nativelink-scheduler/tests/simple_scheduler_test.rs +++ b/nativelink-scheduler/tests/simple_scheduler_test.rs @@ -31,7 +31,7 @@ use nativelink_proto::build::bazel::remote::execution::v2::{digest_function, Exe use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{ update_for_worker, ConnectionResult, StartExecute, UpdateForWorker, }; -use nativelink_scheduler::simple_scheduler::SimpleScheduler; +use nativelink_scheduler::simple_scheduler::{Callbacks, SimpleScheduler}; use 
nativelink_scheduler::worker::{Worker, WorkerId}; use nativelink_scheduler::worker_scheduler::WorkerScheduler; use nativelink_util::common::DigestInfo; @@ -95,6 +95,8 @@ async fn setup_action( #[cfg(test)] mod scheduler_tests { + use std::time::Instant; + use pretty_assertions::assert_eq; use super::*; // Must be declared in every module. @@ -107,6 +109,7 @@ mod scheduler_tests { let scheduler = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), + Callbacks::default(), || async move {}, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -160,6 +163,7 @@ mod scheduler_tests { let scheduler = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), + Callbacks::default(), || async move {}, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -224,6 +228,7 @@ mod scheduler_tests { worker_timeout_s: WORKER_TIMEOUT_S, ..Default::default() }, + Callbacks::default(), || async move {}, ); let action_digest1 = DigestInfo::new([99u8; 32], 512); @@ -367,6 +372,7 @@ mod scheduler_tests { let scheduler = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), + Callbacks::default(), || async move {}, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -442,6 +448,7 @@ mod scheduler_tests { let scheduler = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), + Callbacks::default(), || async move {}, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -524,6 +531,7 @@ mod scheduler_tests { let scheduler = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), + Callbacks::default(), || async move {}, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -625,6 +633,7 @@ mod scheduler_tests { const WORKER_ID: WorkerId = WorkerId(0x0010_0010); let scheduler = SimpleScheduler::new_with_callback( 
&nativelink_config::schedulers::SimpleScheduler::default(), + Callbacks::default(), || async move {}, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -666,6 +675,7 @@ mod scheduler_tests { worker_timeout_s: WORKER_TIMEOUT_S, ..Default::default() }, + Callbacks::default(), || async move {}, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -765,6 +775,7 @@ mod scheduler_tests { let scheduler = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), + Callbacks::default(), || async move {}, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -868,6 +879,7 @@ mod scheduler_tests { let scheduler = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), + Callbacks::default(), || async move {}, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -971,6 +983,7 @@ mod scheduler_tests { let scheduler = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), + Callbacks::default(), || async move {}, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -1066,6 +1079,7 @@ mod scheduler_tests { let scheduler = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), + Callbacks::default(), || async move {}, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -1195,6 +1209,7 @@ mod scheduler_tests { let scheduler = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), + Callbacks::default(), || async move {}, ); let action_digest1 = DigestInfo::new([11u8; 32], 512); @@ -1339,6 +1354,7 @@ mod scheduler_tests { let scheduler = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), + Callbacks::default(), || async move {}, ); let action_digest1 = DigestInfo::new([11u8; 32], 512); @@ -1394,6 +1410,7 @@ mod scheduler_tests { max_job_retries: 2, ..Default::default() }, + Callbacks::default(), 
|| async move {}, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -1523,6 +1540,7 @@ mod scheduler_tests { // DropChecker. let scheduler = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), + Callbacks::default(), move || { // This will ensure dropping happens if this function is ever dropped. let _drop_checker = drop_checker.clone(); @@ -1548,6 +1566,7 @@ mod scheduler_tests { let scheduler = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), + Callbacks::default(), || async move {}, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -1602,4 +1621,202 @@ mod scheduler_tests { Ok(()) } + + #[tokio::test] + async fn ensure_actions_with_disconnected_clients_are_dropped() -> Result<(), Error> { + const DISCONNECT_TIMEOUT_S: u64 = 1; + + let scheduler = SimpleScheduler::new_with_callback( + &nativelink_config::schedulers::SimpleScheduler { + disconnect_timeout_s: DISCONNECT_TIMEOUT_S, + ..Default::default() + }, + Callbacks::new(Instant::now, |_| Box::pin(futures::future::ready(()))), + || async move {}, + ); + + let client_rx = setup_action( + &scheduler, + DigestInfo::new([98u8; 32], 512), + PlatformProperties::default(), + make_system_time(1), + ) + .await?; + + // Drop our receiver. + let unique_qualifier = client_rx.borrow().unique_qualifier.clone(); + drop(client_rx); + // Inform the scheduler the client has been dropped. + // Normally this would happen automatically in the service api. + scheduler.notify_client_disconnected(&unique_qualifier); + + // Allow the action removal spawn to run. + tokio::task::yield_now().await; + + // Check to make sure that the action was removed. 
+ assert!( + scheduler + .find_existing_action(&unique_qualifier) + .await + .is_none(), + "Expected action to be removed" + ); + + Ok(()) + } + + #[tokio::test] + async fn ensure_notify_disconnected_does_not_block() -> Result<(), Error> { + const WORKER_ID: WorkerId = WorkerId(0x1234_5678_9111); + const DISCONNECT_TIMEOUT_S: u64 = 1; + + let scheduler = SimpleScheduler::new_with_callback( + &nativelink_config::schedulers::SimpleScheduler { + disconnect_timeout_s: DISCONNECT_TIMEOUT_S, + ..Default::default() + }, + Callbacks::new(Instant::now, |_| Box::pin(futures::future::pending())), + || async move {}, + ); + let action1_digest = DigestInfo::new([98u8; 32], 512); + let action2_digest = DigestInfo::new([99u8; 32], 512); + + let mut rx_from_worker = + setup_new_worker(&scheduler, WORKER_ID, PlatformProperties::default()).await?; + let insert_timestamp = make_system_time(1); + + let client_rx = setup_action( + &scheduler, + action1_digest, + PlatformProperties::default(), + insert_timestamp, + ) + .await?; + + // Drop our receiver. + let unique_qualifier_1 = client_rx.borrow().unique_qualifier.clone(); + drop(client_rx); + + scheduler.notify_client_disconnected(&unique_qualifier_1); + + { + // Other tests check full data. We only care if we got StartAction. + match rx_from_worker.recv().await.unwrap().update { + Some(update_for_worker::Update::StartAction(_)) => { /* Success */ } + v => panic!("Expected StartAction, got : {v:?}"), + } + } + + // Setup a second action so matching engine is scheduled to rerun. + let client_rx = setup_action( + &scheduler, + action2_digest, + PlatformProperties::default(), + insert_timestamp, + ) + .await?; + + { + // Make sure the action sent after dropping the first receiver is registered. 
+ match rx_from_worker.recv().await.unwrap().update { + Some(update_for_worker::Update::StartAction(_)) => { /* Success */ } + v => panic!("Expected StartAction, got : {v:?}"), + } + } + + // Check to make sure that the first action was not removed. + assert!( + scheduler + .find_existing_action(&unique_qualifier_1) + .await + .is_some(), + "Expected action to not be removed" + ); + + let unique_qualifier_2 = client_rx.borrow().unique_qualifier.clone(); + // Check to make sure that the second action was not removed. + assert!( + scheduler + .find_existing_action(&unique_qualifier_2) + .await + .is_some(), + "Expected action to not be removed" + ); + Ok(()) + } + + // #[tokio::test] + // async fn ensure_active_actions_with_client_disconnect_are_dropped() -> Result<(), Error> { + // const WORKER_ID: WorkerId = WorkerId(0x1234_5678_9111); + // const DISCONNECT_TIMEOUT_S: u64 = 1; + + // let scheduler = SimpleScheduler::new_with_callback( + // &nativelink_config::schedulers::SimpleScheduler { + // disconnect_timeout_s: DISCONNECT_TIMEOUT_S, + // ..Default::default() + // }, + // Callbacks::new(Instant::now, |_| Box::pin(futures::future::ready(()))), + // || async move {}, + // ); + // let action_digest = DigestInfo::new([98u8; 32], 512); + // let action_digest = DigestInfo::new([98u8; 32], 512); + + // let mut rx_from_worker = + // setup_new_worker(&scheduler, WORKER_ID, PlatformProperties::default()).await?; + // let insert_timestamp = make_system_time(1); + + // let client_rx = setup_action( + // &scheduler, + // action1_digest, + // PlatformProperties::default(), + // insert_timestamp, + // ) + // .await?; + + // // Drop our receiver. + // let unique_qualifier = client_rx.borrow().unique_qualifier.clone(); + // drop(client_rx); + + // // Allow task<->worker matcher to run. + // tokio::task::yield_now().await; + + // // Inform the scheduler the client has been dropped. + // // Normally this would happen automatically due to Tonic. 
+ // scheduler.notify_client_disconnected(&unique_qualifier); + + // { + // // Other tests check full data. We only care if we got StartAction. + // match rx_from_worker.recv().await.unwrap().update { + // Some(update_for_worker::Update::StartAction(_)) => { /* Success */ } + // v => panic!("Expected StartAction, got : {v:?}"), + // } + // } + + // // Setup a second action so matching engine is scheduled to rerun. + // let client_rx = setup_action( + // &scheduler, + // action2_digest, + // PlatformProperties::default(), + // insert_timestamp, + // ) + // .await?; + // println!("About to drop rx"); + // drop(client_rx); + // // scheduler.notify_client_disconnected_for_test(&unique_qualifier).await; + + // // Allow task<->worker matcher to run. + // tokio::task::yield_now().await; + // println!("yield did run"); + + // // Check to make sure that the action was removed. + // assert!( + // scheduler + // .find_existing_action(&unique_qualifier) + // .await + // .is_none(), + // "Expected action to be removed" + // ); + + // Ok(()) + // } } diff --git a/nativelink-scheduler/tests/utils/mock_scheduler.rs b/nativelink-scheduler/tests/utils/mock_scheduler.rs index 9afd1dd6b..706acba56 100644 --- a/nativelink-scheduler/tests/utils/mock_scheduler.rs +++ b/nativelink-scheduler/tests/utils/mock_scheduler.rs @@ -120,6 +120,8 @@ impl MockActionScheduler { #[async_trait] impl ActionScheduler for MockActionScheduler { + fn notify_client_disconnected(&self, _unique_qualifier: &ActionInfoHashKey) {} + async fn get_platform_property_manager( &self, instance_name: &str, diff --git a/nativelink-service/BUILD.bazel b/nativelink-service/BUILD.bazel index 06e1d4ef2..63eedc52c 100644 --- a/nativelink-service/BUILD.bazel +++ b/nativelink-service/BUILD.bazel @@ -31,6 +31,7 @@ rust_library( "@crates//:parking_lot", "@crates//:prost", "@crates//:rand", + "@crates//:scopeguard", "@crates//:tokio", "@crates//:tokio-stream", "@crates//:tonic", diff --git a/nativelink-service/Cargo.toml 
b/nativelink-service/Cargo.toml index 233f6c3d5..2c2bae302 100644 --- a/nativelink-service/Cargo.toml +++ b/nativelink-service/Cargo.toml @@ -17,6 +17,7 @@ log = "0.4.21" parking_lot = "0.12.1" prost = "0.12.3" rand = "0.8.5" +scopeguard = "1.2.0" tokio = { version = "1.37.0", features = ["sync", "rt"] } tokio-stream = { version = "0.1.15", features = ["sync"] } tonic = { version = "0.11.0", features = ["gzip", "tls"] } diff --git a/nativelink-service/src/execution_server.rs b/nativelink-service/src/execution_server.rs index cd26d4f93..3b3269612 100644 --- a/nativelink-service/src/execution_server.rs +++ b/nativelink-service/src/execution_server.rs @@ -38,11 +38,11 @@ use nativelink_util::digest_hasher::DigestHasherFunc; use nativelink_util::platform_properties::PlatformProperties; use nativelink_util::store_trait::Store; use rand::{thread_rng, Rng}; +use scopeguard::guard; use tokio::sync::watch; use tokio_stream::wrappers::WatchStream; use tonic::{Request, Response, Status}; use tracing::{error, info}; - struct InstanceInfo { scheduler: Arc, cas_store: Arc, @@ -182,11 +182,26 @@ impl ExecutionServer { Server::new(self) } - fn to_execute_stream(receiver: watch::Receiver>) -> Response { - let receiver_stream = Box::pin(WatchStream::new(receiver).map(|action_update| { + fn to_execute_stream( + &self, + receiver: watch::Receiver>, + unique_qualifier: ActionInfoHashKey, + ) -> Response { + let scheduler = self + .instance_infos + .get(&unique_qualifier.instance_name) + .unwrap() + .scheduler + .clone(); + let scope_guard = guard(scheduler, move |scheduler| { + scheduler.notify_client_disconnected(&unique_qualifier) + }); + let receiver_stream = Box::pin(WatchStream::new(receiver).map(move |action_update| { + let _scope_guard = &scope_guard; info!("\x1b[0;31mexecute Resp Stream\x1b[0m: {:?}", action_update); Ok(Into::::into(action_update.as_ref().clone())) })); + tonic::Response::new(receiver_stream) } @@ -230,11 +245,11 @@ impl ExecutionServer { let rx = instance_info 
.scheduler - .add_action(action_info) + .add_action(action_info.clone()) .await .err_tip(|| "Failed to schedule task")?; - Ok(Self::to_execute_stream(rx)) + Ok(self.to_execute_stream(rx, action_info.unique_qualifier)) } async fn inner_wait_execution( @@ -256,7 +271,7 @@ impl ExecutionServer { else { return Err(Status::not_found("Failed to find existing task")); }; - Ok(Self::to_execute_stream(rx)) + Ok(self.to_execute_stream(rx, unique_qualifier)) } }