File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -500,14 +500,8 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
500500 int64_t head_v_dim = args_.linear_value_head_dim ();
501501
502502 // Parse mamba_ssm_dtype if specified
503- int64_t ssm_dtype_size = dtype_size;
504- if (!args_.mamba_ssm_dtype ().empty ()) {
505- auto parsed_ssm_dtype =
506- try_get_scalar_type_from_string (args_.mamba_ssm_dtype ());
507- if (parsed_ssm_dtype) {
508- ssm_dtype_size = get_dtype_size (parsed_ssm_dtype.value ());
509- }
510- }
503+ int64_t ssm_dtype_size =
504+ resolve_ssm_dtype_size (args_.mamba_ssm_dtype (), dtype_size);
511505
512506 int64_t linear_ssm_slot_size =
513507 ssm_dtype_size * n_local_linear_v_heads_ * head_k_dim * head_v_dim;
Original file line number Diff line number Diff line change @@ -181,6 +181,7 @@ struct ModelArgs {
181181 PROPERTY (int32_t , linear_value_head_dim) = 0 ;
182182 PROPERTY (int64_t , linear_num_key_heads) = 0 ;
183183 PROPERTY (int32_t , linear_num_value_heads) = 0 ;
184+ PROPERTY (std::string, mamba_ssm_dtype);
184185 PROPERTY (int32_t , shared_expert_intermediate_size) = 0 ;
185186 PROPERTY (float , partial_rotary_factor) = 0 .0f ;
186187 PROPERTY (std::vector<std::string>, layer_types) = {};
@@ -339,9 +340,6 @@ struct ModelArgs {
339340 PROPERTY (int64_t , mm_image_shortest_edge) = 0 ;
340341 PROPERTY (int64_t , mm_image_longest_edge) = 0 ;
341342
342- // Mamba SSM dtype
343- PROPERTY (std::string, mamba_ssm_dtype);
344-
345343 // GLM
346344 PROPERTY (int64_t , mm_video_shortest_edge) = 0 ;
347345 PROPERTY (int64_t , mm_video_longest_edge) = 0 ;
Original file line number Diff line number Diff line change @@ -252,15 +252,9 @@ bool WorkerImpl::allocate_kv_cache(
252252
253253 if (is_linear_layer) {
254254 // Linear attention layer: only allocate conv_cache and ssm_cache
255- torch::ScalarType ssm_dtype = dtype_;
256255 // Parse mamba_ssm_dtype if specified
257- if (!args.mamba_ssm_dtype ().empty ()) {
258- auto parsed_ssm_dtype =
259- try_get_scalar_type_from_string (args.mamba_ssm_dtype ());
260- if (parsed_ssm_dtype) {
261- ssm_dtype = parsed_ssm_dtype.value ();
262- }
263- }
256+ torch::ScalarType ssm_dtype =
257+ resolve_ssm_dtype (args.mamba_ssm_dtype (), dtype_);
264258
265259#if defined(USE_NPU)
266260 aclFormat npu_format_type = ACL_FORMAT_ND;
Original file line number Diff line number Diff line change @@ -360,4 +360,34 @@ inline int32_t get_dtype_size(torch::ScalarType dtype) {
360360 return static_cast <int32_t >(torch::elementSize (dtype));
361361}
362362
363+ inline torch::ScalarType resolve_ssm_dtype (
364+ const std::string& mamba_ssm_dtype_str,
365+ torch::ScalarType default_dtype) {
366+ if (mamba_ssm_dtype_str.empty ()) {
367+ return default_dtype;
368+ }
369+ auto parsed = try_get_scalar_type_from_string (mamba_ssm_dtype_str);
370+ if (parsed) {
371+ return parsed.value ();
372+ }
373+ LOG (WARNING) << " Failed to parse mamba_ssm_dtype='" << mamba_ssm_dtype_str
374+ << " ', falling back to default_dtype: " << default_dtype;
375+ return default_dtype;
376+ }
377+
378+ inline int64_t resolve_ssm_dtype_size (
379+ const std::string& mamba_ssm_dtype_str,
380+ int64_t default_dtype_size) {
381+ if (mamba_ssm_dtype_str.empty ()) {
382+ return default_dtype_size;
383+ }
384+ auto parsed = try_get_scalar_type_from_string (mamba_ssm_dtype_str);
385+ if (parsed) {
386+ return get_dtype_size (parsed.value ());
387+ }
388+ LOG (WARNING) << " Failed to parse mamba_ssm_dtype='" << mamba_ssm_dtype_str
389+ << " ', falling back to default dtype size" ;
390+ return default_dtype_size;
391+ }
392+
363393} // namespace xllm
Original file line number Diff line number Diff line change @@ -146,7 +146,7 @@ TORCH_MODULE(Qwen3_5ForCausalLM);
146146 SET_ARG (routed_scaling_factor, 1 .0f ); \
147147 SET_ARG (stop_token_ids, \
148148 std::unordered_set<int32_t >({args->eos_token_id ()})); \
149- LOAD_ARG_TEXT_OR_ROOT (mamba_ssm_dtype, " mamba_ssm_dtype" , " bfloat16 " )
149+ LOAD_ARG_TEXT_OR_ROOT (mamba_ssm_dtype, " mamba_ssm_dtype" , " float32 " )
150150
151151#define LOAD_QWEN3_5_TYPE_AND_DTYPE (default_model_type ) \
152152 LOAD_ARG_OR (model_type, " model_type" , default_model_type); \
You can’t perform that action at this time.
0 commit comments