Skip to content

Commit 9507fc5

Browse files
committed
refactor the code
1 parent eb51399 commit 9507fc5

5 files changed

Lines changed: 36 additions & 20 deletions

File tree

xllm/core/distributed_runtime/llm_engine.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -500,14 +500,8 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
500500
int64_t head_v_dim = args_.linear_value_head_dim();
501501

502502
// Parse mamba_ssm_dtype if specified
503-
int64_t ssm_dtype_size = dtype_size;
504-
if (!args_.mamba_ssm_dtype().empty()) {
505-
auto parsed_ssm_dtype =
506-
try_get_scalar_type_from_string(args_.mamba_ssm_dtype());
507-
if (parsed_ssm_dtype) {
508-
ssm_dtype_size = get_dtype_size(parsed_ssm_dtype.value());
509-
}
510-
}
503+
int64_t ssm_dtype_size =
504+
resolve_ssm_dtype_size(args_.mamba_ssm_dtype(), dtype_size);
511505

512506
int64_t linear_ssm_slot_size =
513507
ssm_dtype_size * n_local_linear_v_heads_ * head_k_dim * head_v_dim;

xllm/core/framework/model/model_args.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ struct ModelArgs {
181181
PROPERTY(int32_t, linear_value_head_dim) = 0;
182182
PROPERTY(int64_t, linear_num_key_heads) = 0;
183183
PROPERTY(int32_t, linear_num_value_heads) = 0;
184+
PROPERTY(std::string, mamba_ssm_dtype);
184185
PROPERTY(int32_t, shared_expert_intermediate_size) = 0;
185186
PROPERTY(float, partial_rotary_factor) = 0.0f;
186187
PROPERTY(std::vector<std::string>, layer_types) = {};
@@ -339,9 +340,6 @@ struct ModelArgs {
339340
PROPERTY(int64_t, mm_image_shortest_edge) = 0;
340341
PROPERTY(int64_t, mm_image_longest_edge) = 0;
341342

342-
// Mamba SSM dtype
343-
PROPERTY(std::string, mamba_ssm_dtype);
344-
345343
// GLM
346344
PROPERTY(int64_t, mm_video_shortest_edge) = 0;
347345
PROPERTY(int64_t, mm_video_longest_edge) = 0;

xllm/core/runtime/worker_impl.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -252,15 +252,9 @@ bool WorkerImpl::allocate_kv_cache(
252252

253253
if (is_linear_layer) {
254254
// Linear attention layer: only allocate conv_cache and ssm_cache
255-
torch::ScalarType ssm_dtype = dtype_;
256255
// Parse mamba_ssm_dtype if specified
257-
if (!args.mamba_ssm_dtype().empty()) {
258-
auto parsed_ssm_dtype =
259-
try_get_scalar_type_from_string(args.mamba_ssm_dtype());
260-
if (parsed_ssm_dtype) {
261-
ssm_dtype = parsed_ssm_dtype.value();
262-
}
263-
}
256+
torch::ScalarType ssm_dtype =
257+
resolve_ssm_dtype(args.mamba_ssm_dtype(), dtype_);
264258

265259
#if defined(USE_NPU)
266260
aclFormat npu_format_type = ACL_FORMAT_ND;

xllm/core/util/tensor_helper.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,4 +360,34 @@ inline int32_t get_dtype_size(torch::ScalarType dtype) {
360360
return static_cast<int32_t>(torch::elementSize(dtype));
361361
}
362362

363+
inline torch::ScalarType resolve_ssm_dtype(
364+
const std::string& mamba_ssm_dtype_str,
365+
torch::ScalarType default_dtype) {
366+
if (mamba_ssm_dtype_str.empty()) {
367+
return default_dtype;
368+
}
369+
auto parsed = try_get_scalar_type_from_string(mamba_ssm_dtype_str);
370+
if (parsed) {
371+
return parsed.value();
372+
}
373+
LOG(WARNING) << "Failed to parse mamba_ssm_dtype='" << mamba_ssm_dtype_str
374+
<< "', falling back to default_dtype: " << default_dtype;
375+
return default_dtype;
376+
}
377+
378+
inline int64_t resolve_ssm_dtype_size(
379+
const std::string& mamba_ssm_dtype_str,
380+
int64_t default_dtype_size) {
381+
if (mamba_ssm_dtype_str.empty()) {
382+
return default_dtype_size;
383+
}
384+
auto parsed = try_get_scalar_type_from_string(mamba_ssm_dtype_str);
385+
if (parsed) {
386+
return get_dtype_size(parsed.value());
387+
}
388+
LOG(WARNING) << "Failed to parse mamba_ssm_dtype='" << mamba_ssm_dtype_str
389+
<< "', falling back to default dtype size";
390+
return default_dtype_size;
391+
}
392+
363393
} // namespace xllm

xllm/models/llm/qwen3_5.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ TORCH_MODULE(Qwen3_5ForCausalLM);
146146
SET_ARG(routed_scaling_factor, 1.0f); \
147147
SET_ARG(stop_token_ids, \
148148
std::unordered_set<int32_t>({args->eos_token_id()})); \
149-
LOAD_ARG_TEXT_OR_ROOT(mamba_ssm_dtype, "mamba_ssm_dtype", "bfloat16")
149+
LOAD_ARG_TEXT_OR_ROOT(mamba_ssm_dtype, "mamba_ssm_dtype", "float32")
150150

151151
#define LOAD_QWEN3_5_TYPE_AND_DTYPE(default_model_type) \
152152
LOAD_ARG_OR(model_type, "model_type", default_model_type); \

0 commit comments

Comments
 (0)