
Commit 22e8d0d

liuruian authored and committed
Split enable_mm
1 parent 6cae9b1 commit 22e8d0d

33 files changed (+99 -68 lines)

fastdeploy/config.py

Lines changed: 30 additions & 3 deletions
@@ -1980,6 +1980,7 @@ def expand_bsz_map(real_bsz_to_captured_size):
             int(envs.ENABLE_V1_KVCACHE_SCHEDULER) == 0
             and self.model_config is not None
             and self.model_config.enable_mm
+            and self.deploy_modality != DeployModality.TEXT
         ):
             self.max_prefill_batch = 1  # TODO: the V0 multimodal prefill stage currently only supports a parallelism of 1; to be optimized
         else:
@@ -2019,6 +2020,20 @@ def expand_bsz_map(real_bsz_to_captured_size):
         self.check()
         # self.print() # NOTE: it's better to explicitly call .print() when FDConfig is initialized

+    @property
+    def enable_mm_runtime(self) -> bool:
+        return (
+            self.model_config is not None
+            and self.model_config.enable_mm
+            and self.deploy_modality != DeployModality.TEXT
+        )
+
+    @property
+    def enable_rope_3d_runtime(self) -> bool:
+        return self.enable_mm_runtime and (
+            getattr(self.model_config, "rope_3d", False) or getattr(self.model_config, "use_3d_rope", False)
+        )
+
     def _disable_sequence_parallel_moe_if_needed(self, mode_name):
         if self.parallel_config.use_sequence_parallel_moe and self.graph_opt_config.use_cudagraph:
             self.parallel_config.use_sequence_parallel_moe = False
@@ -2057,9 +2072,21 @@ def postprocess(self):
         if self.long_prefill_token_threshold == 0:
             self.long_prefill_token_threshold = int(self.model_config.max_model_len * 0.04)

+        if (
+            self.model_config is not None
+            and self.model_config.enable_mm
+            and self.deploy_modality == DeployModality.TEXT
+        ):
+            if getattr(self.model_config, "rope_3d", False) or getattr(self.model_config, "use_3d_rope", False):
+                logger.info(
+                    "Deploy modality is text; forcing the multimodal-capable model onto the 1D RoPE runtime path."
+                )
+                setattr(self.model_config, "rope_3d", False)
+                setattr(self.model_config, "use_3d_rope", False)
+
         self.cache_config.max_block_num_per_seq = int(self.model_config.max_model_len // self.cache_config.block_size)
         self.cache_config.postprocess(self.get_max_chunk_tokens(), self.scheduler_config.max_num_seqs)
-        if self.model_config is not None and self.model_config.enable_mm and not envs.ENABLE_V1_KVCACHE_SCHEDULER:
+        if self.model_config is not None and self.enable_mm_runtime and not envs.ENABLE_V1_KVCACHE_SCHEDULER:
             self.cache_config.enable_prefix_caching = False
             if (
                 self.structured_outputs_config is not None
@@ -2085,7 +2112,7 @@ def postprocess(self):
                     f"Guided decoding backend '{self.structured_outputs_config.guided_decoding_backend}' is not implemented. [auto, xgrammar, guidance, off]"
                 )

-        if self.model_config.enable_mm:
+        if self.enable_mm_runtime:
             if self.cache_config.max_encoder_cache is None or self.cache_config.max_encoder_cache < 0:
                 self.cache_config.max_encoder_cache = self.scheduler_config.max_num_batched_tokens
             elif self.cache_config.max_encoder_cache != 0:
@@ -2392,7 +2419,7 @@ def get_max_chunk_tokens(self, mm_max_tokens_per_item=None):
             num_tokens = self.scheduler_config.max_num_seqs
         else:
             num_tokens = self.scheduler_config.max_num_batched_tokens
-        if mm_max_tokens_per_item is not None and self.deploy_modality != DeployModality.TEXT:
+        if self.enable_mm_runtime and mm_max_tokens_per_item is not None:
             max_mm_tokens = max(
                 mm_max_tokens_per_item.get("image", 0),
                 mm_max_tokens_per_item.get("video", 0),
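
Taken together, the config.py changes replace ad-hoc model_config.enable_mm checks with two derived properties. The following standalone sketch is not part of the diff; _ModelConfig and _FDConfigSketch are simplified stand-ins for FastDeploy's real classes, but the property bodies mirror the lines added above. It illustrates how the new flags combine the model's multimodal capability with the selected deploy modality:

# Minimal sketch (simplified stand-ins, not FastDeploy's real classes) of how the
# new FDConfig runtime flags are derived.
from dataclasses import dataclass
from enum import Enum


class DeployModality(Enum):
    TEXT = "text"
    MULTIMODAL = "multimodal"


@dataclass
class _ModelConfig:
    enable_mm: bool = True   # model weights are multimodal-capable
    rope_3d: bool = True     # model was trained with 3D RoPE
    use_3d_rope: bool = False


@dataclass
class _FDConfigSketch:
    model_config: _ModelConfig
    deploy_modality: DeployModality

    @property
    def enable_mm_runtime(self) -> bool:
        # Multimodal code paths are active only when the model supports them
        # AND the deployment is not restricted to text.
        return (
            self.model_config is not None
            and self.model_config.enable_mm
            and self.deploy_modality != DeployModality.TEXT
        )

    @property
    def enable_rope_3d_runtime(self) -> bool:
        # 3D RoPE is only meaningful on the multimodal runtime path.
        return self.enable_mm_runtime and (
            getattr(self.model_config, "rope_3d", False)
            or getattr(self.model_config, "use_3d_rope", False)
        )


# A multimodal-capable model deployed in text-only mode runs the text path:
cfg = _FDConfigSketch(_ModelConfig(), DeployModality.TEXT)
assert cfg.enable_mm_runtime is False and cfg.enable_rope_3d_runtime is False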

fastdeploy/engine/async_llm.py

Lines changed: 2 additions & 1 deletion
@@ -294,6 +294,7 @@ def __init__(self, cfg, pid):
             cfg.limit_mm_per_prompt,
             cfg.mm_processor_kwargs,
             cfg.tool_parser,
+            enable_mm_runtime=cfg.enable_mm_runtime,
         )
         # Create data processor
         self.data_processor = self.input_processor.create_processor()
@@ -446,7 +447,7 @@ async def add_request(
             )
         if envs.ZMQ_SEND_BATCH_DATA and self.connection_manager is not None:
             request["zmq_worker_pid"] = self.connection_manager.worker_pid
-            if self.cfg.model_config.enable_mm:
+            if self.cfg.enable_mm_runtime:
                 self.request_client.send_pyobj(request)
             else:
                 self.request_client.send_json(request)

fastdeploy/engine/common_engine.py

Lines changed: 5 additions & 3 deletions
@@ -330,6 +330,7 @@ def create_data_processor(self):
            self.cfg.limit_mm_per_prompt,
            self.cfg.mm_processor_kwargs,
            self.cfg.tool_parser,
+            enable_mm_runtime=self.cfg.enable_mm_runtime,
        )
        self.data_processor = self.input_processor.create_processor()
        self.mm_max_tokens_per_item = self.data_processor.get_mm_max_tokens_per_item(
@@ -601,7 +602,7 @@ def insert_tasks(self, tasks: List[Request], current_id=-1):
                    LoggingEventName.RESCHEDULED_INFERENCE_START, task.request_id, getattr(task, "user", "")
                )
        if not is_prefill:
-            if not self.cfg.model_config.enable_mm:
+            if not self.cfg.enable_mm_runtime:
                self.update_requests_chunk_size(tasks)
            else:
                self.update_mm_requests_chunk_size(tasks)
@@ -1218,7 +1219,7 @@ def _insert_zmq_task_to_scheduler(self):
        while self.running:
            try:
                block = True if len(added_requests) == 0 else False
-                if not self.cfg.model_config.enable_mm:
+                if not self.cfg.enable_mm_runtime:
                    err, data = self.recv_request_server.receive_json_once(block)
                else:
                    err, data = self.recv_request_server.receive_pyobj_once(block)
@@ -1276,6 +1277,7 @@ def _insert_zmq_task_to_scheduler(self):
                err_msg = None
                try:
                    request = Request.from_dict(data)
+
                    request.metrics.scheduler_recv_req_time = time.time()
                    main_process_metrics.requests_number.inc()
                    trace_carrier = data.get("trace_carrier")
@@ -2355,7 +2357,7 @@ def _setting_environ_variables(self):
        if self.cfg.scheduler_config.splitwise_role == "prefill":
            variables["FLAGS_fmt_write_cache_completed_signal"] = 1

-        if self.cfg.model_config.enable_mm:
+        if self.cfg.enable_mm_runtime:
            variables["FLAGS_max_partition_size"] = 1024

        command_prefix = ""
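
The scheduler ingest path above now keys its transport choice off cfg.enable_mm_runtime rather than the raw model flag. A hedged sketch of the underlying idea follows; the functions are stand-ins rather than FastDeploy's ZMQ servers, and the JSON-versus-pickle rationale is an assumption based on the receive_json_once / receive_pyobj_once split shown in the diff:

# Sketch (stand-in, not FastDeploy's EngineWorkerQueue or ZMQ servers).
# Text-only requests are plain dicts and can travel as JSON; multimodal requests
# may carry binary payloads that JSON cannot represent, so the runtime falls
# back to pickled Python objects (assumed rationale for the pyobj path).
import json
import pickle
from typing import Any


def encode_request(request: dict, enable_mm_runtime: bool) -> bytes:
    if not enable_mm_runtime:
        return json.dumps(request).encode("utf-8")  # mirrors send_json / receive_json_once
    return pickle.dumps(request)                    # mirrors send_pyobj / receive_pyobj_once


def decode_request(payload: bytes, enable_mm_runtime: bool) -> Any:
    if not enable_mm_runtime:
        return json.loads(payload.decode("utf-8"))
    return pickle.loads(payload)


# Round-trip example for the text-only path:
req = {"request_id": "r0", "prompt": "hello"}
assert decode_request(encode_request(req, False), False) == req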

fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 3 additions & 3 deletions
@@ -205,11 +205,11 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l
        self.need_block_num_map = dict()

        self.encoder_cache = None
-        if config.model_config.enable_mm and config.cache_config.max_encoder_cache > 0:
+        if config.enable_mm_runtime and config.cache_config.max_encoder_cache > 0:
            self.encoder_cache = EncoderCacheManager(config.cache_config.max_encoder_cache)

        self.processor_cache = None
-        if config.model_config.enable_mm and config.cache_config.max_processor_cache > 0:
+        if config.enable_mm_runtime and config.cache_config.max_processor_cache > 0:
            max_processor_cache_in_bytes = int(config.cache_config.max_processor_cache * 1024 * 1024 * 1024)
            self.processor_cache = ProcessorCacheManager(max_processor_cache_in_bytes)

@@ -550,7 +550,7 @@ def _get_num_new_tokens(self, request, token_budget):
        num_new_tokens = token_budget // self.config.cache_config.block_size * self.config.cache_config.block_size
        request.with_image = False

-        if not self.config.model_config.enable_mm:
+        if not self.config.enable_mm_runtime:
            return num_new_tokens

        inputs = request.multimodal_inputs

fastdeploy/entrypoints/engine_client.py

Lines changed: 3 additions & 1 deletion
@@ -84,7 +84,7 @@ class EngineClient:
    def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers: int = 1, max_logprobs: int = 20):
        self.fd_config = fd_config
        self.tensor_parallel_size = self.fd_config.parallel_config.tensor_parallel_size
-        self.enable_mm = self.fd_config.model_config.enable_mm
+        self.enable_mm = self.fd_config.enable_mm_runtime
        self.max_logprobs = max_logprobs
        input_processor = InputPreprocessor(
            self.fd_config.model_config,
@@ -93,6 +93,7 @@ def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers
            self.fd_config.mm_processor_kwargs,
            self.fd_config.tool_parser,
            self.enable_mm and self.fd_config.cache_config.max_processor_cache > 0,
+            enable_mm_runtime=self.enable_mm,
        )
        self.enable_logprob = self.fd_config.model_config.enable_logprob
        self.data_processor = input_processor.create_processor()
@@ -358,6 +359,7 @@ async def add_requests(self, task):

        task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens"))
        min_tokens = task.get("min_tokens", 1)
+
        if "messages" in task:
            task["messages"] = None
        api_server_logger.info(f"task['max_tokens']:{task['max_tokens']}")

fastdeploy/input/preprocess.py

Lines changed: 4 additions & 1 deletion
@@ -48,6 +48,7 @@ def __init__(
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        tool_parser: str = None,
        enable_processor_cache: bool = False,
+        enable_mm_runtime: Optional[bool] = None,
    ) -> None:
        self.model_config = model_config
        self.model_name_or_path = self.model_config.model
@@ -56,6 +57,7 @@ def __init__(
        self.mm_processor_kwargs = mm_processor_kwargs
        self.tool_parser = tool_parser
        self.enable_processor_cache = enable_processor_cache
+        self.enable_mm_runtime = self.model_config.enable_mm if enable_mm_runtime is None else enable_mm_runtime

    def create_processor(self):
        reasoning_parser_obj = None
@@ -77,10 +79,11 @@ def create_processor(self):
                reasoning_parser_obj=reasoning_parser_obj,
                tool_parser_obj=tool_parser_obj,
                mm_processor_kwargs=self.mm_processor_kwargs,
+                enable_mm_runtime=self.enable_mm_runtime,
            )
        except Exception as e:
            logger.info(f"Plugin input processor not available ({e}), using built-in processor")
-        if not self.model_config.enable_mm:
+        if not self.enable_mm_runtime:
            from fastdeploy.input.text_processor import TextProcessor

            tokenizer_type = "ernie4_5" if ErnieArchitectures.contains_ernie_arch(architecture) else "auto"
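
The new enable_mm_runtime argument defaults to the model's own capability flag, so existing callers are unaffected while deployments can override it. A minimal sketch of that defaulting rule; _PreprocessorSketch is a hypothetical stand-in, and the real InputPreprocessor takes several more arguments:

# Sketch of the defaulting rule added to InputPreprocessor.__init__ (simplified stand-in).
from typing import Optional


class _PreprocessorSketch:
    def __init__(self, model_enable_mm: bool, enable_mm_runtime: Optional[bool] = None) -> None:
        # None -> inherit the model capability flag; explicit bool -> deployment override.
        self.enable_mm_runtime = model_enable_mm if enable_mm_runtime is None else enable_mm_runtime


# Old call sites (no argument) behave as before:
assert _PreprocessorSketch(model_enable_mm=True).enable_mm_runtime is True
# A text-only deployment can force the text processor even for an MM-capable model:
assert _PreprocessorSketch(model_enable_mm=True, enable_mm_runtime=False).enable_mm_runtime is False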

fastdeploy/inter_communicator/engine_worker_queue.py

Lines changed: 0 additions & 2 deletions
@@ -549,7 +549,6 @@ def put_tasks(self, tasks: List[Any]) -> None:
                self.lock.release()
                time.sleep(0.001)
                self.lock.acquire()
-
        if envs.FD_ENABLE_MAX_PREFILL or envs.FD_ENABLE_E2W_TENSOR_CONVERT:
            # multimodal input numpy -> tensor
            to_tensor(tasks[0])
@@ -571,7 +570,6 @@ def get_tasks(self) -> Tuple[List[Any], bool]:
        """
        tasks: List[Any] = list()
        self.lock.acquire()
-
        tasks.extend(self.tasks)
        self.client_read_flag[self.client_id] = 1
        all_client_read: bool = np.sum(self.client_read_flag) == self.num_client

fastdeploy/model_executor/layers/attention/append_attn_backend.py

Lines changed: 1 addition & 3 deletions
@@ -138,9 +138,7 @@ def __init__(
        self.rope_theta: float = (
            10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
        )
-        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
-            fd_config.model_config, "use_3d_rope", False
-        )
+        self.rope_3d: bool = fd_config.enable_rope_3d_runtime
        if fd_config.speculative_config.model_type != "main":
            self.rope_3d = False
        self.causal: bool = getattr(fd_config.model_config, "causal", True)

fastdeploy/model_executor/layers/attention/dsa_attention_backend.py

Lines changed: 1 addition & 1 deletion
@@ -136,7 +136,7 @@ def __init__(
        self.rope_theta: float = (
            10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
        )
-        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
+        self.rope_3d: bool = fd_config.enable_rope_3d_runtime
        self.causal: bool = getattr(fd_config.model_config, "causal", True)
        self.speculative_method: str = fd_config.speculative_config.method
        self.use_speculate: bool = self.speculative_method is not None

fastdeploy/model_executor/layers/attention/flash_attn_backend.py

Lines changed: 1 addition & 3 deletions
@@ -267,9 +267,7 @@ def __init__(

        self.rank, self.device_id = init_rank_and_device_id(fd_config)

-        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
-            fd_config.model_config, "use_3d_rope", False
-        )
+        self.rope_3d: bool = fd_config.enable_rope_3d_runtime
        if fd_config.speculative_config.model_type != "main":
            self.rope_3d = False
        # Note(ZKK): here must be consistent with append_attn_backend.py
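
Across the three attention backends, the per-backend getattr checks on rope_3d / use_3d_rope collapse into the single fd_config.enable_rope_3d_runtime property, while non-main speculative models still force 1D RoPE. A simplified stand-in sketch; the "draft" model type below is only illustrative, the diff tests nothing more specific than != "main":

# Sketch (stand-in for an attention backend __init__): the backend reads the
# centralized runtime flag, and speculative draft/verify models keep 1D RoPE.
class _AttnBackendSketch:
    def __init__(self, enable_rope_3d_runtime: bool, speculative_model_type: str = "main") -> None:
        self.rope_3d = enable_rope_3d_runtime
        if speculative_model_type != "main":
            # Non-main speculative models disable 3D RoPE, matching the diff.
            self.rope_3d = False


assert _AttnBackendSketch(True).rope_3d is True
assert _AttnBackendSketch(True, speculative_model_type="draft").rope_3d is False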
