Skip to content

Commit 5a0b770

Browse files
JackTan25LLLLKKKK
authored andcommitted
fix: fix qwen3-next pd
1 parent c8c4e87 commit 5a0b770

2 files changed

Lines changed: 20 additions & 8 deletions

File tree

rtp_llm/cpp/normal_engine/NormalEngine.cc

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
#include <thread>
2020
#include <random>
2121

22+
#if USING_CUDA
23+
#include "c10/cuda/CUDACachingAllocator.h"
24+
#endif
25+
2226
#ifdef __linux__
2327
#include <malloc.h>
2428
#endif
@@ -220,11 +224,13 @@ WarmUpResult NormalEngine::prefillWarmUp(const EngineInitParams& params) {
220224
rtp_llm::setTraceMemory(true);
221225
executor_.reset(new NormalExecutor(params, nullptr, true, false, 0, exec_init_params_));
222226
THROW_IF_STATUSOR_ERROR(preRun(fake_input, preRunMode::prefill_warm_up));
223-
const auto device_status = getGpuExecStatus();
227+
const auto max_consumed = getGpuExecStatus().device_memory_status.max_consumed_bytes;
224228
rtp_llm::setTraceMemory(false);
225229
(void)executor_.reset(nullptr);
226-
return WarmUpResult(
227-
{device_status.device_memory_status.available_bytes, device_status.device_memory_status.max_consumed_bytes});
230+
cudaDeviceSynchronize();
231+
c10::cuda::CUDACachingAllocator::emptyCache();
232+
const auto device_status = getGpuExecStatus();
233+
return WarmUpResult({device_status.device_memory_status.available_bytes, max_consumed});
228234
#endif
229235
}
230236

@@ -250,11 +256,13 @@ WarmUpResult NormalEngine::decodeWarmUp(const EngineInitParams& params) {
250256
}
251257
executor_.reset(new NormalExecutor(params, cache_manager, true, false, 0, exec_init_params_));
252258
THROW_IF_STATUSOR_ERROR(preRun(fake_input, preRunMode::decode_warm_up));
253-
const auto device_status = getGpuExecStatus();
259+
const auto max_consumed = getGpuExecStatus().device_memory_status.max_consumed_bytes;
254260
rtp_llm::setTraceMemory(false);
255261
(void)executor_.reset(nullptr);
256-
return WarmUpResult(
257-
{device_status.device_memory_status.available_bytes, device_status.device_memory_status.max_consumed_bytes});
262+
cudaDeviceSynchronize();
263+
c10::cuda::CUDACachingAllocator::emptyCache();
264+
const auto device_status = getGpuExecStatus();
265+
return WarmUpResult({device_status.device_memory_status.available_bytes, max_consumed});
258266
#endif
259267
}
260268

rtp_llm/cpp/normal_engine/speculative/MtpExecutor.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@ MtpExecutor::MtpExecutor(const EngineInitParams& params,
172172
cache_manager ? std::make_optional(target_cache_layer_layout) : std::nullopt,
173173
params.model_id,
174174
params.parallelism_config,
175-
exec_init_params});
175+
exec_init_params,
176+
cache_manager});
176177

177178
if (params.ffn_disaggregate_config.enable_ffn_disaggregate) {
178179
RTP_LLM_LOG_INFO("using ffn as service");
@@ -207,7 +208,8 @@ MtpExecutor::MtpExecutor(const EngineInitParams& params,
207208
cache_manager ? std::make_optional(draft_cache_layer_layout) : std::nullopt,
208209
mtp_params->model_id,
209210
mtp_params->parallelism_config,
210-
exec_init_params});
211+
exec_init_params,
212+
cache_manager});
211213
if (!params.py_sp_model.is_none()) {
212214
RTP_LLM_LOG_INFO("[speculative decoding] using py model");
213215
draft_model_.reset(new PyWrappedModel(
@@ -349,6 +351,7 @@ absl::Status MtpExecutor::prefillStep(const std::list<GenerateStreamPtr>& stream
349351
maybePrintModelInput(model_input, "prefill post draft model");
350352
const auto& mtp_cache_cfg = cache_manager_->getMTPModuleCacheConfig(0);
351353
model_input.kv_block_stride_bytes = mtp_cache_cfg.kv_block_stride_bytes;
354+
model_input.kv_scale_stride_bytes = mtp_cache_cfg.kv_scale_stride_bytes;
352355
model_input.kv_cache_layer_to_group = draft_kv_cache_layer_to_group;
353356
draft_model_output = std::move(draft_model_->forward(model_input));
354357
}
@@ -631,6 +634,7 @@ absl::Status MtpExecutor::decodeStep(const std::list<GenerateStreamPtr>& streams
631634
maybePrintModelInput(model_input, "decode post draft model");
632635
const auto& mtp_cache_cfg = cache_manager_->getMTPModuleCacheConfig(0);
633636
model_input.kv_block_stride_bytes = mtp_cache_cfg.kv_block_stride_bytes;
637+
model_input.kv_scale_stride_bytes = mtp_cache_cfg.kv_scale_stride_bytes;
634638
model_input.kv_cache_layer_to_group = draft_kv_cache_layer_to_group;
635639
}
636640

0 commit comments

Comments
 (0)