diff --git a/atoma-bin/atoma_node.rs b/atoma-bin/atoma_node.rs index 6fb87822..16f87281 100644 --- a/atoma-bin/atoma_node.rs +++ b/atoma-bin/atoma_node.rs @@ -373,9 +373,11 @@ async fn main() -> Result<()> { keystore: Arc::new(keystore), address_index, whitelist_sui_addresses_for_fiat: config.service.whitelist_sui_addresses_for_fiat, - too_many_requests: Arc::new(DashMap::new()), + too_many_requests: Arc::new(DashSet::new()), too_many_requests_timeout_ms: u128::from(config.service.too_many_requests_timeout_ms), running_num_requests: Arc::new(RequestCounter::new()), + memory_lower_threshold: config.service.memory_lower_threshold, + memory_upper_threshold: config.service.memory_upper_threshold, }; let chat_completions_service_urls = app_state diff --git a/atoma-service/src/config.rs b/atoma-service/src/config.rs index f7cff89a..e6725ff3 100644 --- a/atoma-service/src/config.rs +++ b/atoma-service/src/config.rs @@ -60,6 +60,11 @@ pub struct AtomaServiceConfig { /// The timeout for the too many requests error in milliseconds. pub too_many_requests_timeout_ms: u64, + + ///Lower threshold for memory usage, if the memory usage goes below this value, the service will not be considered overloaded + pub memory_lower_threshold: f64, + /// Upper threshold for memory usage, if the memory usage goes above this value, the service will be considered overloaded + pub memory_upper_threshold: f64, } impl AtomaServiceConfig { diff --git a/atoma-service/src/handlers/chat_completions.rs b/atoma-service/src/handlers/chat_completions.rs index 56f4d785..20fa6ae9 100644 --- a/atoma-service/src/handlers/chat_completions.rs +++ b/atoma-service/src/handlers/chat_completions.rs @@ -908,6 +908,7 @@ async fn handle_streaming_response( &state.running_num_requests, chat_completions_service_urls, &model.to_lowercase(), + state.memory_upper_threshold, ) .await .map_err(|e| AtomaServiceError::ChatCompletionsServiceUnavailable { @@ -915,9 +916,7 @@ async fn handle_streaming_response( endpoint: endpoint.clone(), })?; if status_code == StatusCode::TOO_MANY_REQUESTS { - state - .too_many_requests - .insert(model.to_string(), Instant::now()); + state.too_many_requests.insert(model.to_string()); return Err(AtomaServiceError::ChatCompletionsServiceUnavailable { message: "Too many requests".to_string(), endpoint: endpoint.clone(), @@ -1341,6 +1340,7 @@ pub mod utils { &state.running_num_requests, chat_completions_service_url_services, model, + state.memory_upper_threshold, ) .await .map_err(|e| AtomaServiceError::ChatCompletionsServiceUnavailable { @@ -1348,9 +1348,7 @@ pub mod utils { endpoint: endpoint.to_string(), })?; if status_code == StatusCode::TOO_MANY_REQUESTS { - state - .too_many_requests - .insert(model.to_string(), Instant::now()); + state.too_many_requests.insert(model.to_string()); return Err(AtomaServiceError::ChatCompletionsServiceUnavailable { message: "Too many requests".to_string(), endpoint: endpoint.to_string(), diff --git a/atoma-service/src/handlers/completions.rs b/atoma-service/src/handlers/completions.rs index a2da3069..0d80a52f 100644 --- a/atoma-service/src/handlers/completions.rs +++ b/atoma-service/src/handlers/completions.rs @@ -881,6 +881,7 @@ async fn handle_streaming_response( &state.running_num_requests, chat_completions_service_urls, model, + state.memory_upper_threshold, ) .await .map_err(|e| AtomaServiceError::ChatCompletionsServiceUnavailable { @@ -888,9 +889,7 @@ async fn handle_streaming_response( endpoint: endpoint.clone(), })?; if status_code == StatusCode::TOO_MANY_REQUESTS { - state - .too_many_requests - .insert(model.to_string(), Instant::now()); + state.too_many_requests.insert(model.to_string()); return Err(AtomaServiceError::ChatCompletionsServiceUnavailable { message: "Too many requests".to_string(), endpoint: endpoint.clone(), @@ -1303,6 +1302,7 @@ pub mod utils { &state.running_num_requests, completions_service_url_services, model, + state.memory_upper_threshold, ) .await .map_err(|e| AtomaServiceError::ChatCompletionsServiceUnavailable { @@ -1310,9 +1310,7 @@ pub mod utils { endpoint: endpoint.to_string(), })?; if status_code == StatusCode::TOO_MANY_REQUESTS { - state - .too_many_requests - .insert(model.to_string(), Instant::now()); + state.too_many_requests.insert(model.to_string()); return Err(AtomaServiceError::ChatCompletionsServiceUnavailable { message: "Too many requests".to_string(), endpoint: endpoint.to_string(), diff --git a/atoma-service/src/handlers/mod.rs b/atoma-service/src/handlers/mod.rs index 7482dd53..085163dd 100644 --- a/atoma-service/src/handlers/mod.rs +++ b/atoma-service/src/handlers/mod.rs @@ -53,9 +53,11 @@ pub const COMPLETION_TOKENS_KEY: &str = "completion_tokens"; const VLLM_RUNNING_REQUESTS_QUERY: &str = "num_requests_running"; const VLLM_QUEUED_REQUESTS_QUERY: &str = "num_requests_waiting"; +const VLLM_MEMORY_USAGE_QUERY: &str = "gpu_cache_usage_perc"; const VLLM_SERVICE_PREFIX: &str = "vllm:"; const SGLANG_RUNNING_REQUESTS_QUERY: &str = "num_running_reqs"; const SGLANG_QUEUED_REQUESTS_QUERY: &str = "num_queue_reqs"; +const SGLANG_MEMORY_USAGE_QUERY: &str = "token_usage"; const SGLANG_SERVICE_PREFIX: &str = "sglang:"; #[derive(Debug, Clone)] @@ -81,6 +83,14 @@ impl InferenceService { } } + #[must_use] + pub const fn get_usage(&self) -> &'static str { + match self { + Self::Vllm => VLLM_MEMORY_USAGE_QUERY, + Self::SgLang => SGLANG_MEMORY_USAGE_QUERY, + } + } + #[must_use] pub const fn get_service_prefix(&self) -> &'static str { match self { @@ -611,8 +621,8 @@ pub mod inference_service_metrics { }); /// Chat completions metrics - #[derive(Debug, Clone)] - struct ChatCompletionsMetrics { + #[derive(Debug, Clone, PartialEq)] + pub struct ChatCompletionsMetrics { /// The model name model: String, /// The chat completions service url @@ -621,10 +631,47 @@ pub mod inference_service_metrics { num_queued_requests: f64, /// The number of running requests num_running_requests: f64, + /// The memory usage in fraction, e.g. 1.00 means 100% memory usage + memory_usage: f64, /// The maximum number of running requests allowed for this url. max_number_of_running_requests: usize, } + impl Eq for ChatCompletionsMetrics {} + + impl PartialOrd for ChatCompletionsMetrics { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + + impl Ord for ChatCompletionsMetrics { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.num_queued_requests + .total_cmp(&other.num_queued_requests) + .then_with(|| { + self.memory_usage + .total_cmp(&other.memory_usage) + .then_with(|| { + self.num_running_requests + .total_cmp(&other.num_running_requests) + }) + }) + } + } + + impl ChatCompletionsMetrics { + #[must_use] + pub fn above_upper_threshold_exceeded(&self, threshold: f64) -> bool { + self.memory_usage > threshold + } + + #[must_use] + pub fn under_lower_threshold(&self, threshold: f64) -> bool { + self.memory_usage <= threshold + } + } + /// Cache structure to store metrics #[derive(Debug, Default)] struct MetricsCache { @@ -787,12 +834,18 @@ pub mod inference_service_metrics { inference_service.get_running_requests_metric_name(), job, )?; + let memory_usage = extract_metric( + &metrics, + inference_service.get_usage(), + job, + )?; Ok(ChatCompletionsMetrics { model: model.clone(), chat_completions_service_url: chat_completions_service_url.clone(), num_queued_requests, num_running_requests, + memory_usage, max_number_of_running_requests: *max_number_of_running_requests, }) }); @@ -838,71 +891,37 @@ pub mod inference_service_metrics { }) } - /// Selects the best available chat completions service URL for a given model based on performance metrics. - /// - /// This function aims to distribute load and ensure optimal response times by choosing - /// the service instance that is currently performing best. The selection process prioritizes - /// services with lower requests running and queue lengths. - /// - /// # Metrics and Selection Logic: + /// Retrieves all chat completions metrics for the specified model. /// - /// 1. **Metrics Source**: Metrics for each service (vLLM or SgLang) are retrieved directly from the inference - /// service URL. - /// - /// 2. **Priority of Metrics for "Best" Service Selection**: - /// * **No Load**: If a service has zero running requests (`num_running_requests` is 0.0), - /// it's considered the best. - /// * **Number of Queued Requests**: If number of running requests are equivalent, the service - /// with the fewest `num_queue_requests` is selected. - /// - /// 3. **Handling Missing or Invalid Metrics**: - /// * If, after checking all services, no valid metrics are found for the specified `model`, - /// a service URL is chosen randomly from the initial list. - /// - /// # Load Thresholds and Behavior: - /// - /// The function defines several thresholds to manage high load scenarios: - /// * `MAX_ALLOWED_NUM_QUEUED_REQUESTS` (1.0) - /// - /// If the determined "best" service (or all services) exceeds these - /// thresholds, the function returns the first URL from the input `chat_completions_service_urls` - /// list along with a `StatusCode::TOO_MANY_REQUESTS`. The `CHAT_COMPLETIONS_TOO_MANY_REQUESTS` - /// metric counter is also incremented for the model. + /// This function fetches metrics from both vLLM and SgLang services, + /// partitions the service URLs based on the job type, and retrieves metrics + /// for each service. It returns a vector of `ChatCompletionsMetrics` for the specified model. /// /// # Arguments /// - /// * `chat_completions_service_urls`: A slice of tuples `(String, String)`, where each tuple - /// represents a service. The first `String` is the service URL, and the second `String` - /// is the job name (e.g., "vllm-service", "sglang-service"), used to determine the - /// metrics querying strategy. - /// * `model`: A string slice representing the name of the model for which the best service - /// URL is being requested. The comparison is case-insensitive. + /// * `chat_completions_service_urls` - A vector of tuples containing the chat completions service URLs, + /// job names, and maximum concurrent requests. + /// * `model` - The model name for which to retrieve metrics. /// /// # Returns /// - /// Returns a `Result<(String, StatusCode), ChatCompletionsMetricsError>`: - /// * `Ok((String, StatusCode::OK))`: On success, containing the URL of the best available - /// service and an OK status. - /// * `Ok((String, StatusCode::TOO_MANY_REQUESTS))`: If the system is determined to be under - /// high load based on the metrics thresholds. The returned `String` will be the first URL - /// from the `chat_completions_service_urls` input. - /// * `Err(ChatCompletionsMetricsError)`: If an error occurs, such as no service URLs - /// being provided or issues during metrics fetching that are not handled by fallback mechanisms. + /// Returns a `Result>` containing the metrics for the specified model. /// /// # Errors /// - /// * `ChatCompletionsMetricsError::NoChatCompletionsServiceUrlsFound`: If the input - /// `chat_completions_service_urls` slice is empty. - /// * Other variants of `ChatCompletionsMetricsError` may be returned if underlying issues - /// occur during metric collection from Prometheus (e.g., network errors, parsing errors), - /// though the function attempts to handle missing individual metrics gracefully.g - #[instrument(level = "info", skip_all, fields(model=model))] - #[allow(clippy::float_cmp)] - pub async fn get_best_available_chat_completions_service_url( - running_num_requests: &RequestCounter, + /// * `ChatCompletionsMetricsError::NoChatCompletionsServiceUrlsFound`: If no chat completions service URLs are provided. + /// * Other variants of `ChatCompletionsMetricsError` may be returned if underlying + /// issues occur during metric collection from Prometheus (e.g., network errors, + /// parsing errors), though the function attempts to handle missing individual metrics gracefully. + #[instrument( + level = "info", + skip(chat_completions_service_urls, model), + fields(model = model) + )] + pub async fn get_all_metrics( chat_completions_service_urls: &[(String, String, usize)], // (url, job, max_concurrent_requests) model: &str, - ) -> Result<(String, StatusCode)> { + ) -> Result> { type ChatCompletionsServiceUrls = Vec<(String, String, usize)>; if chat_completions_service_urls.is_empty() { @@ -1019,6 +1038,7 @@ pub mod inference_service_metrics { chat_completions_service_url, num_queued_requests, num_running_requests, + memory_usage, max_number_of_running_requests, }) => { tracing::info!( @@ -1045,6 +1065,7 @@ pub mod inference_service_metrics { chat_completions_service_url, num_queued_requests, num_running_requests, + memory_usage, max_number_of_running_requests, }); } @@ -1058,7 +1079,85 @@ pub mod inference_service_metrics { } } } + Ok(metrics_results) + } + /// Selects the best available chat completions service URL for a given model based on performance metrics. + /// + /// This function aims to distribute load and ensure optimal response times by choosing + /// the service instance that is currently performing best. The selection process prioritizes + /// services with lower requests running and queue lengths. + /// + /// # Metrics and Selection Logic: + /// + /// 1. **Metrics Source**: Metrics for each service (vLLM or SgLang) are retrieved directly from the inference + /// service URL. + /// + /// 2. **Priority of Metrics for "Best" Service Selection**: + /// * **No Load**: If a service has zero running requests (`num_running_requests` is 0.0), + /// it's considered the best. + /// * **Number of Queued Requests**: If number of running requests are equivalent, the service + /// with the fewest `num_queue_requests` is selected. + /// + /// 3. **Handling Missing or Invalid Metrics**: + /// * If, after checking all services, no valid metrics are found for the specified `model`, + /// a service URL is chosen randomly from the initial list. + /// + /// # Load Thresholds and Behavior: + /// + /// The function defines several thresholds to manage high load scenarios: + /// * `MAX_ALLOWED_NUM_QUEUED_REQUESTS` (1.0) + /// + /// If the determined "best" service (or all services) exceeds these + /// thresholds, the function returns the first URL from the input `chat_completions_service_urls` + /// list along with a `StatusCode::TOO_MANY_REQUESTS`. The `CHAT_COMPLETIONS_TOO_MANY_REQUESTS` + /// metric counter is also incremented for the model. + /// + /// # Arguments + /// + /// * `chat_completions_service_urls`: A slice of tuples `(String, String)`, where each tuple + /// represents a service. The first `String` is the service URL, and the second `String` + /// is the job name (e.g., "vllm-service", "sglang-service"), used to determine the + /// metrics querying strategy. + /// * `model`: A string slice representing the name of the model for which the best service + /// URL is being requested. The comparison is case-insensitive. + /// + /// # Returns + /// + /// Returns a `Result<(String, StatusCode), ChatCompletionsMetricsError>`: + /// * `Ok((String, StatusCode::OK))`: On success, containing the URL of the best available + /// service and an OK status. + /// * `Ok((String, StatusCode::TOO_MANY_REQUESTS))`: If the system is determined to be under + /// high load based on the metrics thresholds. The returned `String` will be the first URL + /// from the `chat_completions_service_urls` input. + /// * `Err(ChatCompletionsMetricsError)`: If an error occurs, such as no service URLs + /// being provided or issues during metrics fetching that are not handled by fallback mechanisms. + /// + /// # Errors + /// + /// * `ChatCompletionsMetricsError::NoChatCompletionsServiceUrlsFound`: If the input + /// `chat_completions_service_urls` slice is empty. + /// * Other variants of `ChatCompletionsMetricsError` may be returned if underlying issues + /// occur during metric collection from Prometheus (e.g., network errors, parsing errors), + /// though the function attempts to handle missing individual metrics gracefully.g + #[instrument(level = "info", skip_all, fields(model=model))] + #[allow(clippy::float_cmp)] + pub async fn get_best_available_chat_completions_service_url( + running_num_requests: &RequestCounter, + chat_completions_service_urls: &[(String, String, usize)], // (url, job, max_concurrent_requests) + model: &str, + memory_upper_threshold: f64, + ) -> Result<(String, StatusCode)> { + let mut metrics_results = get_all_metrics(chat_completions_service_urls, model) + .await + .map_err(|e| { + tracing::error!( + target = "atoma-service", + level = "error", + "Failed to get metrics for model: {model} with error: {e}" + ); + e + })?; if metrics_results.is_empty() { tracing::warn!( target = "atoma-service", @@ -1072,18 +1171,22 @@ pub mod inference_service_metrics { } // Select the best available chat completions service URL based on the number of queued and running requests. - metrics_results.sort_by_key(|metric| { - ( - metric.num_queued_requests as i64, - metric.num_running_requests as i64, - ) - }); + metrics_results.sort(); for metric in metrics_results { if running_num_requests.increment( &metric.chat_completions_service_url, metric.max_number_of_running_requests, ) { + if metric.above_upper_threshold_exceeded(memory_upper_threshold) { + tracing::debug!( + target = "atoma-service", + level = "debug", + "Memory usage for model: {model} is too high: {}", + metric.memory_usage + ); + continue; + } let best_url = metric.chat_completions_service_url.clone(); tracing::info!( target = "atoma-service", diff --git a/atoma-service/src/middleware.rs b/atoma-service/src/middleware.rs index 04fd20fc..701b986a 100644 --- a/atoma-service/src/middleware.rs +++ b/atoma-service/src/middleware.rs @@ -811,7 +811,7 @@ pub async fn verify_permissions( message: "Model is not a string".to_string(), endpoint: endpoint.clone(), })?; - utils::check_if_too_many_requests(&state, model, &endpoint)?; + utils::check_if_too_many_requests(&state, model, &endpoint).await?; if !state.models.contains(&model.to_string()) { return Err(AtomaServiceError::InvalidBody { message: format!("Model not supported, supported models: {:?}", state.models), @@ -1016,6 +1016,7 @@ pub mod utils { completions::RequestModelCompletions, embeddings::RequestModelEmbeddings, image_generations::RequestModelImageGenerations, + inference_service_metrics::get_all_metrics, request_model::{RequestModel, TokensEstimate}, }; @@ -1607,36 +1608,48 @@ pub mod utils { /// - The model has too many requests /// - The elapsed time since the first occurrence is less than the timeout #[instrument(level = "info", skip_all, err)] - pub fn check_if_too_many_requests( + pub async fn check_if_too_many_requests( state: &AppState, model: &str, endpoint: &str, ) -> Result<(), AtomaServiceError> { - match state.too_many_requests.entry(model.to_string()) { - dashmap::mapref::entry::Entry::Occupied(occupied_entry) => { - let elapsed_ms = occupied_entry.get().elapsed().as_millis(); - - if elapsed_ms < state.too_many_requests_timeout_ms { - tracing::info!( - target = "atoma-service", - level = "info", - "Too many requests for model: {model}, endpoint: {endpoint}, elapsed trigger time: {elapsed_ms} and timeout: {}", - state.too_many_requests_timeout_ms - ); - return Err(AtomaServiceError::ChatCompletionsServiceUnavailable { - message: "Too many requests".to_string(), + if state.too_many_requests.get(model).is_some() { + let chat_completions_service_urls = state + .chat_completions_service_urls + .get(&model.to_lowercase()) + .ok_or_else(|| { + AtomaServiceError::InternalError { + message: format!( + "Chat completions service URL not found, likely that model is not supported by the current node: {}", + model + ), endpoint: endpoint.to_string(), - }); - } - occupied_entry.remove(); - } - dashmap::mapref::entry::Entry::Vacant(_) => { + } + })?; + let metrics = get_all_metrics(chat_completions_service_urls, model) + .await + .map_err(|e| AtomaServiceError::InternalError { + message: format!("Failed to get metrics for model {model}, with error: {e}"), + endpoint: endpoint.to_string(), + })?; + if metrics + .iter() + .any(|metric| metric.under_lower_threshold(state.memory_lower_threshold)) + { + state.too_many_requests.remove(model); tracing::debug!( target = "atoma-service", level = "debug", - "Model is not in the `too_many_requests` map, so no action is needed here. Processing can continue." + "Model {} is in the `too_many_requests` map, but metrics indicate that it is no longer exceeding the lower threshold. Removing from the map.", + model ); } + } else { + tracing::debug!( + target = "atoma-service", + level = "debug", + "Model is not in the `too_many_requests` map, so no action is needed here. Processing can continue." + ); } Ok(()) } diff --git a/atoma-service/src/server.rs b/atoma-service/src/server.rs index 125a9d9e..7e132116 100644 --- a/atoma-service/src/server.rs +++ b/atoma-service/src/server.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, sync::Arc, time::Instant}; +use std::{collections::HashMap, sync::Arc}; use atoma_confidential::types::{ ConfidentialComputeDecryptionRequest, ConfidentialComputeDecryptionResponse, @@ -207,13 +207,21 @@ pub struct AppState { pub whitelist_sui_addresses_for_fiat: Vec, /// When was the too many requests triggered for each model. - pub too_many_requests: Arc>, + pub too_many_requests: Arc>, /// The time for which we triiger too many requests since the first occurrence. pub too_many_requests_timeout_ms: u128, /// Number of running requests for each inference service. pub running_num_requests: Arc, + + /// The upper memory threshold for the node. + /// This threshold is used to determine when the node should start rejecting. + pub memory_upper_threshold: f64, + + /// The lower memory threshold for the node. + /// This threshold is used to determine when the node can start accepting requests again. + pub memory_lower_threshold: f64, } /// Creates and configures the main router for the application. diff --git a/atoma-service/src/tests.rs b/atoma-service/src/tests.rs index 9410bfca..b601a696 100644 --- a/atoma-service/src/tests.rs +++ b/atoma-service/src/tests.rs @@ -341,9 +341,11 @@ mod middleware { address_index: 0, stack_retrieve_sender, whitelist_sui_addresses_for_fiat: vec![], - too_many_requests: Arc::new(DashMap::new()), + too_many_requests: Arc::new(DashSet::new()), too_many_requests_timeout_ms: 0, running_num_requests: Arc::new(RequestCounter::new()), + memory_lower_threshold: 1.0, + memory_upper_threshold: 1.0, }, public_key, signature, diff --git a/config.example.toml b/config.example.toml index 3aa7c708..1a519842 100644 --- a/config.example.toml +++ b/config.example.toml @@ -46,6 +46,8 @@ image_generations_service_url = "http://image-generations:80" # List of models to be used by the service, the current value here is just a placeholder, please change it to the models you want to deploy environment = "development" # or "production" (for use in sentry, you need to set the Sentry DSN) heartbeat_url = "my-heartbeat-url" +memory_lower_threshold = 0.75 # Lower threshold for memory usage, if the memory usage goes below this value, the service will not be considered overloaded +memory_upper_threshold = 0.9 # Upper threshold for memory usage, if the memory usage goes above this value, the service will be considered overloaded metrics_update_interval = 30 models = [ "Infermatic/Llama-3.3-70B-Instruct-FP8-Dynamic" ] revisions = [ "main" ]