diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp
index aac678170..81a7719eb 100644
--- a/xllm/core/common/global_flags.cpp
+++ b/xllm/core/common/global_flags.cpp
@@ -104,6 +104,12 @@
 DEFINE_int32(max_tokens_for_graph_mode,
              2048,
              "Maximum number of tokens for graph execution. "
              "If 0, no limit is applied.");
+
+DEFINE_int32(acl_graph_decode_batch_size_limit,
+             16,
+             "Decode batch size threshold for ACL graph on NPU. "
+             "When actual decode batch_size > this value, ACL graph decode "
+             "falls back to eager mode to avoid OOM.");
 
 // --- vlm config ---
 
 DEFINE_int32(limit_image_per_prompt,
diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h
index 2d8977c16..53215e627 100644
--- a/xllm/core/common/global_flags.h
+++ b/xllm/core/common/global_flags.h
@@ -109,6 +109,7 @@ DECLARE_bool(enable_prefill_piecewise_graph);
 DECLARE_bool(enable_graph_vmm_pool);
 
 DECLARE_int32(max_tokens_for_graph_mode);
+DECLARE_int32(acl_graph_decode_batch_size_limit);
 
 DECLARE_bool(enable_chunked_prefill);
diff --git a/xllm/core/common/help_formatter.h b/xllm/core/common/help_formatter.h
index 8f696a981..03ff582e2 100644
--- a/xllm/core/common/help_formatter.h
+++ b/xllm/core/common/help_formatter.h
@@ -47,6 +47,7 @@ const OptionCategory kCommonOptions = {"COMMON OPTIONS",
      "enable_graph_mode_decode_no_padding",
      "enable_prefill_piecewise_graph",
      "max_tokens_for_graph_mode",
+     "acl_graph_decode_batch_size_limit",
      "communication_backend",
      "task"}};
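
The flag added above is a standard gflags definition, so it can be set at launch with --acl_graph_decode_batch_size_limit=32 or overridden in code. A minimal sketch of the programmatic form (only the flag name comes from this diff; the helper function and the value 32 are illustrative):

#include <gflags/gflags.h>

DECLARE_int32(acl_graph_decode_batch_size_limit);

// Hypothetical tuning hook: on devices with plenty of HBM headroom, raise the
// threshold so that larger decode batches still take the ACL graph path.
void raise_acl_graph_decode_limit() {
  FLAGS_acl_graph_decode_batch_size_limit = 32;
}
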
diff --git a/xllm/core/layers/npu/npu_glm4_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_glm4_decoder_layer_impl.cpp
index 24af8f0cb..caff0a721 100644
--- a/xllm/core/layers/npu/npu_glm4_decoder_layer_impl.cpp
+++ b/xllm/core/layers/npu/npu_glm4_decoder_layer_impl.cpp
@@ -86,7 +86,9 @@ NpuGlm4DecoderLayerImpl::NpuGlm4DecoderLayerImpl(const ModelContext& context)
   auto options = context.get_tensor_options();
 
   param_from_args(prefill_param_, model_args, parallel_args, true);
-  param_from_args(decode_param_, model_args, parallel_args, false);
+  param_from_args(decode_graph_param_, model_args, parallel_args, false);
+  decode_eager_param_ = decode_graph_param_;
+  decode_eager_param_.enableAclGraphPagedAttention = false;
   atb_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER);
   placeholder_vec_ = {1};
   dtype_ = c10::typeMetaToScalarType(options.dtype());
@@ -102,7 +104,10 @@ int64_t NpuGlm4DecoderLayerImpl::init_layer() {
   name_ = "glm4_decoder_layer";
   model_name_ = "glm4";
   CHECK_OPERATION_STATUS_RETURN(init_node(prefill_node_, prefill_param_));
-  CHECK_OPERATION_STATUS_RETURN(init_node(decode_node_, decode_param_));
+  CHECK_OPERATION_STATUS_RETURN(
+      init_node(decode_graph_node_, decode_graph_param_));
+  CHECK_OPERATION_STATUS_RETURN(
+      init_node(decode_eager_node_, decode_eager_param_));
   return atb::NO_ERROR;
 }
@@ -164,21 +169,27 @@ torch::Tensor NpuGlm4DecoderLayerImpl::forward(torch::Tensor& x,
                             attn_mask,
                             kv_cache,
                             input_params,
-                            true);
+                            true,
+                            false);
     // mstxRangeEnd(id);
     st = execute_node(prefill_node_, node_id, event, event_flag);
     LOG_IF(FATAL, st != 0) << model_name_
                            << "excute prefill layer fail, error code: " << st;
   } else {
-    build_node_variant_pack(decode_node_,
+    const bool use_graph_decode_input =
+        FLAGS_enable_graph && input_params.graph_buffer.tiling_data.defined();
+    auto& decode_node =
+        use_graph_decode_input ? decode_graph_node_ : decode_eager_node_;
+    build_node_variant_pack(decode_node,
                             x,
                             cos_pos,
                             sin_pos,
                             decode_attn_mask_,
                             kv_cache,
                             input_params,
-                            false);
-    st = execute_node(decode_node_, node_id + 1000, event, event_flag);
+                            false,
+                            use_graph_decode_input);
+    st = execute_node(decode_node, node_id + 1000, event, event_flag);
     LOG_IF(FATAL, st != 0) << model_name_
                            << "excute decode layer fail, error code: " << st;
   }
@@ -194,7 +205,8 @@ void NpuGlm4DecoderLayerImpl::build_node_variant_pack(
     at::Tensor& attn_mask,
     KVCache& kv_cache,
     ModelInputParams& input_params,
-    bool is_prefill) {
+    bool is_prefill,
+    bool use_graph_decode_input) {
   internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x);
   // std::cout<<"node.variantPack.inTensors.size:"<
 get_dtp_inputs(torch::Tensor token_size_per_dp_group,
                int32_t rank,
                at::Device device);
 } // namespace layer
-} // namespace xllm
\ No newline at end of file
+} // namespace xllm
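
The glm4 change above, and the qwen3/qwen3_moe changes that follow, all apply the same split: build the decode parameters once, derive an eager twin that differs only in enableAclGraphPagedAttention, and initialize one node from each. A condensed sketch of the pattern (LayerParam and Node are simplified stand-ins for the atb_speed types; param_from_args and init_node are the layer helpers seen in the hunks):

struct LayerParam {
  bool enableAclGraphPagedAttention = true;  // graph-style paged attention
  // ... the real params carry many more attention/MLP fields
};
struct Node {};  // stands in for atb_speed::Model::Node

struct DecoderLayerSketch {
  LayerParam decode_graph_param_;
  LayerParam decode_eager_param_;
  Node decode_graph_node_;
  Node decode_eager_node_;

  void init() {
    // In the real layers decode_graph_param_ is filled by param_from_args().
    decode_eager_param_ = decode_graph_param_;  // copy every field
    decode_eager_param_.enableAclGraphPagedAttention = false;  // eager twin
    // Then each param initializes its own node:
    //   init_node(decode_graph_node_, decode_graph_param_);
    //   init_node(decode_eager_node_, decode_eager_param_);
  }
};
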
diff --git a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp
index adc146b08..8ad4339e7 100644
--- a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp
+++ b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp
@@ -157,7 +157,9 @@ NpuQwen3DecoderLayerImpl::NpuQwen3DecoderLayerImpl(const ModelContext& context)
   auto options = context.get_tensor_options();
 
   param_from_args(prefill_param_, model_args, parallel_args, true);
-  param_from_args(decode_param_, model_args, parallel_args, false);
+  param_from_args(decode_graph_param_, model_args, parallel_args, false);
+  decode_eager_param_ = decode_graph_param_;
+  decode_eager_param_.enableAclGraphPagedAttention = false;
   atb_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER);
   placeholder_vec_ = {1};
   dtype_ = c10::typeMetaToScalarType(options.dtype());
@@ -188,7 +190,10 @@ int64_t NpuQwen3DecoderLayerImpl::init_layer() {
   name_ = "qwen3_decoder_layer";
   model_name_ = "qwen3";
   CHECK_OPERATION_STATUS_RETURN(init_node(prefill_node_, prefill_param_));
-  CHECK_OPERATION_STATUS_RETURN(init_node(decode_node_, decode_param_));
+  CHECK_OPERATION_STATUS_RETURN(
+      init_node(decode_graph_node_, decode_graph_param_));
+  CHECK_OPERATION_STATUS_RETURN(
+      init_node(decode_eager_node_, decode_eager_param_));
   return atb::NO_ERROR;
 }
@@ -246,13 +251,18 @@ torch::Tensor NpuQwen3DecoderLayerImpl::forward(torch::Tensor& x,
                             kv_cache,
                             input_params,
                             /*is_prefill=*/true,
-                            node_id);
+                            node_id,
+                            /*use_graph_decode_input=*/false);
     // mstxRangeEnd(id);
     st = execute_node(prefill_node_, node_id, event, event_flag);
     LOG_IF(FATAL, st != 0) << model_name_
                            << "excute prefill layer fail, error code: " << st;
   } else {
-    build_node_variant_pack(decode_node_,
+    const bool use_graph_decode_input =
+        FLAGS_enable_graph && input_params.graph_buffer.tiling_data.defined();
+    auto& decode_node =
+        use_graph_decode_input ? decode_graph_node_ : decode_eager_node_;
+    build_node_variant_pack(decode_node,
                             x,
                             cos_pos,
                             sin_pos,
@@ -260,8 +270,9 @@ torch::Tensor NpuQwen3DecoderLayerImpl::forward(torch::Tensor& x,
                             kv_cache,
                             input_params,
                             /*is_prefill=*/false,
-                            node_id);
-    st = execute_node(decode_node_, node_id + 1000, event, event_flag);
+                            node_id,
+                            use_graph_decode_input);
+    st = execute_node(decode_node, node_id + 1000, event, event_flag);
     LOG_IF(FATAL, st != 0) << model_name_
                            << "excute decode layer fail, error code: " << st;
   }
@@ -278,7 +289,8 @@ void NpuQwen3DecoderLayerImpl::build_node_variant_pack(
     KVCache& kv_cache,
     ModelInputParams& input_params,
     bool is_prefill,
-    int node_id) {
+    int node_id,
+    bool use_graph_decode_input) {
   internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x);
   node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER) = internal_tensors_;
   node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 1) =
@@ -342,7 +354,7 @@ void NpuQwen3DecoderLayerImpl::build_node_variant_pack(
         input_params.q_seq_lens_vec.data();
   }
 
-  if (FLAGS_enable_graph && !is_prefill &&
+  if (!is_prefill && use_graph_decode_input &&
       input_params.graph_buffer.tiling_data.defined()) {
     node.variantPack.inTensors.at(input_idx++) =
         atb_speed::Utils::AtTensor2Tensor(
diff --git a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h
index 98609cff0..30440fe98 100644
--- a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h
+++ b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h
@@ -76,7 +76,8 @@ class NpuQwen3DecoderLayerImpl : public BaseLayer {
       KVCache& kv_cache,
       ModelInputParams& input_params,
       bool is_prefill,
-      int node_id);
+      int node_id,
+      bool use_graph_decode_input);
 
   void initialize_parallel_parameters(atb_speed::qwen::QwenLayerParam& param,
                                       const ParallelArgs& parallel_args);
@@ -90,10 +91,12 @@ class NpuQwen3DecoderLayerImpl : public BaseLayer {
   int64_t init_attn_mask();
 
   atb_speed::Model::Node prefill_node_;
-  atb_speed::Model::Node decode_node_;
+  atb_speed::Model::Node decode_graph_node_;
+  atb_speed::Model::Node decode_eager_node_;
   std::string model_name_;
   atb_speed::qwen::QwenLayerParam prefill_param_;
-  atb_speed::qwen::QwenLayerParam decode_param_;
+  atb_speed::qwen::QwenLayerParam decode_graph_param_;
+  atb_speed::qwen::QwenLayerParam decode_eager_param_;
 
   atb::Tensor internal_tensors_;
   atb::Tensor placeholder_;
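
At forward time each layer picks the node that matches how the executor prepared this step's inputs: the graph node is only usable when graph mode is enabled and the executor actually filled graph_buffer.tiling_data, since the graph variant pack expects the extra tiling tensors. A self-contained sketch of that selection (the condition mirrors the hunks above; the free function itself is illustrative):

template <typename Node>
Node& select_decode_node(bool enable_graph_flag,
                         bool tiling_data_defined,
                         Node& decode_graph_node,
                         Node& decode_eager_node) {
  // Graph-only inputs (tiling data, etc.) exist solely when both conditions
  // hold; otherwise the eager node keeps the variant pack layout valid.
  const bool use_graph_decode_input = enable_graph_flag && tiling_data_defined;
  return use_graph_decode_input ? decode_graph_node : decode_eager_node;
}
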
diff --git a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp
index 0d58d2ca8..7a601f370 100644
--- a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp
+++ b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp
@@ -55,7 +55,9 @@ NpuQwen3MoeDecoderLayerImpl::NpuQwen3MoeDecoderLayerImpl(
   dp_local_tp_rank_ = parallel_args.rank() % dp_local_tp_size_;
 
   param_from_args(prefill_param_, model_args, parallel_args, true);
-  param_from_args(decode_param_, model_args, parallel_args, false);
+  param_from_args(decode_graph_param_, model_args, parallel_args, false);
+  decode_eager_param_ = decode_graph_param_;
+  decode_eager_param_.enableAclGraphPagedAttention = false;
   loader_ = std::make_unique(WEIGHT_COUNT_PER_LAYER, context);
   initialize_tensors(options);
@@ -274,7 +276,10 @@ int64_t NpuQwen3MoeDecoderLayerImpl::init_layer() {
   name_ = "qwen3_moe_decoder_layer " + std::to_string(layer_id_);
   model_name_ = "Qwen3_Moe";
   CHECK_OPERATION_STATUS_RETURN(init_node(prefill_node_, prefill_param_));
-  CHECK_OPERATION_STATUS_RETURN(init_node(decode_node_, decode_param_));
+  CHECK_OPERATION_STATUS_RETURN(
+      init_node(decode_graph_node_, decode_graph_param_));
+  CHECK_OPERATION_STATUS_RETURN(
+      init_node(decode_eager_node_, decode_eager_param_));
   return atb::NO_ERROR;
 }
@@ -325,12 +330,17 @@ torch::Tensor NpuQwen3MoeDecoderLayerImpl::forward(
                             attn_mask,
                             kv_cache,
                             input_params,
-                            true);
+                            true,
+                            false);
     st = execute_node(prefill_node_, node_id, event, event_flag);
     LOG_IF(FATAL, st != 0) << model_name_
                            << "excute prefill layer fail, error code: " << st;
   } else {
-    build_node_variant_pack(decode_node_,
+    const bool use_graph_decode_input =
+        FLAGS_enable_graph && input_params.graph_buffer.tiling_data.defined();
+    auto& decode_node =
+        use_graph_decode_input ? decode_graph_node_ : decode_eager_node_;
+    build_node_variant_pack(decode_node,
                             x,
                             residual,
                             cos_pos,
@@ -338,8 +348,9 @@ torch::Tensor NpuQwen3MoeDecoderLayerImpl::forward(
                             /*attn_mask*/ tensor_placeholder_,
                             kv_cache,
                             input_params,
-                            false);
-    st = execute_node(decode_node_, node_id + 1000, event, event_flag);
+                            false,
+                            use_graph_decode_input);
+    st = execute_node(decode_node, node_id + 1000, event, event_flag);
     LOG_IF(FATAL, st != 0) << model_name_
                            << "excute decode layer fail, error code: " << st;
   }
@@ -356,7 +367,8 @@ void NpuQwen3MoeDecoderLayerImpl::build_node_variant_pack(
     torch::Tensor& attn_mask,
     KVCache& kv_cache,
     const ModelInputParams& input_params,
-    bool is_prefill) {
+    bool is_prefill,
+    bool use_graph_decode_input) {
   internal_tensor_ = atb_speed::Utils::AtTensor2Tensor(x);
   int32_t input_idx = 0;
   auto& dp_ep_padding = input_params.dp_ep_padding_data;
@@ -428,7 +440,7 @@ void NpuQwen3MoeDecoderLayerImpl::build_node_variant_pack(
         const_cast<int32_t*>(input_params.q_seq_lens_vec.data());
   }
 
-  if (FLAGS_enable_graph && !is_prefill &&
+  if (!is_prefill && use_graph_decode_input &&
       input_params.graph_buffer.tiling_data.defined()) {
     node.variantPack.inTensors.at(input_idx++) =
         atb_speed::Utils::AtTensor2Tensor(
diff --git a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h
index ed96af006..5601997b0 100644
--- a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h
+++ b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h
@@ -105,7 +105,8 @@ class NpuQwen3MoeDecoderLayerImpl : public BaseLayer {
       torch::Tensor& attn_mask,
      KVCache& kv_cache,
       const ModelInputParams& input_params,
-      bool is_prefill);
+      bool is_prefill,
+      bool use_graph_decode_input);
 
   torch::Tensor block_tables_placeholder_;
   std::string model_name_;
@@ -129,10 +130,12 @@ class NpuQwen3MoeDecoderLayerImpl : public BaseLayer {
   int32_t num_speculative_tokens_ = 0;
 
   atb_speed::qwen::MoeDecoderLayerParam prefill_param_;
-  atb_speed::qwen::MoeDecoderLayerParam decode_param_;
+  atb_speed::qwen::MoeDecoderLayerParam decode_graph_param_;
+  atb_speed::qwen::MoeDecoderLayerParam decode_eager_param_;
 
   atb_speed::Model::Node prefill_node_;
-  atb_speed::Model::Node decode_node_;
+  atb_speed::Model::Node decode_graph_node_;
+  atb_speed::Model::Node decode_eager_node_;
 
   atb::Tensor internal_tensor_;
diff --git a/xllm/core/runtime/acl_graph_executor_impl.cpp b/xllm/core/runtime/acl_graph_executor_impl.cpp
index 728498fc6..7943bc18d 100644
--- a/xllm/core/runtime/acl_graph_executor_impl.cpp
+++ b/xllm/core/runtime/acl_graph_executor_impl.cpp
@@ -984,6 +984,16 @@ ModelOutput AclGraphExecutorImpl::run(const torch::Tensor& tokens,
   // Get actual num_tokens from tokens shape
   const uint32_t n_tokens = tokens_tensor.size(/*dim=*/0);
   const uint32_t actual_batch_size = n_tokens / options_.num_decoding_tokens();
+
+  // Large decode batches create too many/too large ACL graphs and may OOM.
+  // Fall back to eager mode when batch size exceeds the safety threshold.
+  const uint32_t decode_batch_size_limit =
+      std::max(1, FLAGS_acl_graph_decode_batch_size_limit);
+  if (actual_batch_size > decode_batch_size_limit) {
+    COUNTER_INC(num_model_execution_total_eager);
+    return model_->forward(tokens, positions, kv_caches, params);
+  }
+
   const uint32_t bucket_num_tokens = get_bucket_num_tokens(n_tokens);
 
   // Check if conditions are suitable for graph execution (replay or capture)
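
The executor's new early-out is plain arithmetic, so its behavior is easy to check by hand: with num_decoding_tokens == 2 (one extra speculative token per sequence), 34 input tokens mean a decode batch of 17, which exceeds the default limit of 16 and falls back to eager. A standalone sketch of the decision (names mirror the hunk above; the free function is illustrative):

#include <algorithm>
#include <cstdint>

bool falls_back_to_eager(uint32_t n_tokens,
                         uint32_t num_decoding_tokens,
                         int32_t limit_flag) {
  const uint32_t actual_batch_size = n_tokens / num_decoding_tokens;
  // std::max(1, ...) keeps a zero or negative flag value from accidentally
  // disabling graph decode entirely.
  const uint32_t limit = static_cast<uint32_t>(std::max(1, limit_flag));
  return actual_batch_size > limit;
}

// falls_back_to_eager(34, 2, 16) == true   (batch 17 > limit 16)
// falls_back_to_eager(32, 2, 16) == false  (batch 16 <= limit 16)
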
diff --git a/xllm/core/runtime/acl_graph_executor_test.cpp b/xllm/core/runtime/acl_graph_executor_test.cpp
index 9dcc438c6..eeb2b0d7e 100644
--- a/xllm/core/runtime/acl_graph_executor_test.cpp
+++ b/xllm/core/runtime/acl_graph_executor_test.cpp
@@ -542,6 +542,57 @@ TEST_F(AclGraphExecutorTest, DifferentBatchSizes) {
   }
 }
 
+// Test decode batch-size threshold fallback: ACL graph should fall back to
+// eager when batch_size exceeds the configured limit (default: 16).
+TEST_F(AclGraphExecutorTest, DecodeBatchSizeThresholdFallsBackToEager) {
+  constexpr uint32_t batch_size = 17;
+  sequences_.clear();
+  auto batch = std::make_unique<Batch>();
+
+  for (uint32_t i = 0; i < batch_size; ++i) {
+    sequences_.emplace_back(i,
+                            std::vector<int32_t>{static_cast<int32_t>(1 + i),
+                                                 static_cast<int32_t>(3 + i),
+                                                 static_cast<int32_t>(5 + i),
+                                                 static_cast<int32_t>(7 + i)},
+                            input_embedding_,
+                            mm_data_,
+                            fake_decoder_,
+                            seq_params_);
+    auto& sequence = sequences_.back();
+    sequence.add_kv_blocks(block_manager_->allocate(2));
+    sequence.kv_state().incr_kv_cache_tokens_num(/*size=*/4);
+    sequence.append_token(100 + i);
+    batch->add(&sequence);
+  }
+
+  auto forward_input = batch->prepare_forward_input(
+      options_.num_decoding_tokens(), 0, model_args_);
+  forward_input = forward_input.to(*device_, torch::kFloat32);
+
+  auto npu_executor = std::make_unique(
+      model_.get(), model_args_, *device_, options_);
+  auto graph_executor = std::make_unique<::xllm::npu::AclGraphExecutorImpl>(
+      model_.get(), model_args_, *device_, options_);
+
+  auto eager_out = npu_executor->run({forward_input.token_ids},
+                                     {forward_input.positions},
+                                     kv_caches_,
+                                     {forward_input.input_params});
+  auto graph_out = graph_executor->run({forward_input.token_ids},
+                                       {forward_input.positions},
+                                       kv_caches_,
+                                       {forward_input.input_params});
+
+  EXPECT_EQ(graph_out.hidden_states.size(0),
+            batch_size * options_.num_decoding_tokens());
+  EXPECT_EQ(graph_out.hidden_states.size(1), model_args_.hidden_size());
+  EXPECT_TRUE(torch::allclose(eager_out.hidden_states,
+                              graph_out.hidden_states,
+                              /*rtol=*/1e-5,
+                              /*atol=*/1e-6));
+}
+
 // Test ACL graph executor against original NPU executor implementation
 TEST_F(AclGraphExecutorTest, AclGraphExecutorVsBaseExecutorImpl) {
   // Create test batch
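
If future tests want to vary the threshold (for example, to confirm that raising it re-enables the graph path for batch 17), note that FLAGS_acl_graph_decode_batch_size_limit is process-global state and should be restored afterwards. A small RAII guard makes that safe; this sketch is not part of the diff (ScopedDecodeLimit is hypothetical, the flag name is real):

#include <cstdint>
#include <gflags/gflags.h>

DECLARE_int32(acl_graph_decode_batch_size_limit);

// Temporarily override the decode batch-size limit for the guard's lifetime.
class ScopedDecodeLimit {
 public:
  explicit ScopedDecodeLimit(int32_t value)
      : saved_(FLAGS_acl_graph_decode_batch_size_limit) {
    FLAGS_acl_graph_decode_batch_size_limit = value;
  }
  ~ScopedDecodeLimit() { FLAGS_acl_graph_decode_batch_size_limit = saved_; }

 private:
  int32_t saved_;
};

// Usage inside a test body:
//   ScopedDecodeLimit guard(/*value=*/32);  // batch 17 now takes the graph path
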