diff --git a/python/minisgl/scheduler/config.py b/python/minisgl/scheduler/config.py index 52a799f5..4e1077d9 100644 --- a/python/minisgl/scheduler/config.py +++ b/python/minisgl/scheduler/config.py @@ -1,14 +1,17 @@ from __future__ import annotations from dataclasses import dataclass, field +from typing import Literal from minisgl.engine import EngineConfig +SchedulePolicy = Literal["prefill_first", "decode_first"] + def _get_pid_suffix() -> str: import os - return f".pid={os.getpid()}" + return f".pid={os.getpid()}" @dataclass(frozen=True) @@ -16,6 +19,7 @@ class SchedulerConfig(EngineConfig): max_extend_tokens: int = 8192 cache_type: str = "radix" offline_mode: bool = False + schedule_policy: SchedulePolicy = "prefill_first" # networking config _unique_suffix: str = field(default_factory=_get_pid_suffix) diff --git a/python/minisgl/scheduler/scheduler.py b/python/minisgl/scheduler/scheduler.py index d0c08d83..d42dce91 100644 --- a/python/minisgl/scheduler/scheduler.py +++ b/python/minisgl/scheduler/scheduler.py @@ -19,7 +19,7 @@ from .config import SchedulerConfig from .decode import DecodeManager from .io import SchedulerIOMixin -from .prefill import ChunkedReq, PrefillManager +from .prefill import ChunkedReq, PrefillManager from .table import TableManager if TYPE_CHECKING: @@ -68,6 +68,7 @@ def __init__(self, config: SchedulerConfig): self.finished_reqs: Set[Req] = set() self.tokenizer = load_tokenizer(config.model_path) self.eos_token_id = self.tokenizer.eos_token_id + self.schedule_policy = config.schedule_policy self.token_pool = self.table_manager.token_pool self.prefill_budget = config.max_extend_tokens # self.config = config @@ -217,11 +218,16 @@ def _prepare_batch(self, batch: Batch) -> ForwardInput: ) def _schedule_next_batch(self) -> ForwardInput | None: - # TODO: support other policies: e.g. DECODE first - batch = ( - self.prefill_manager.schedule_next_batch(self.prefill_budget) - or self.decode_manager.schedule_next_batch() - ) + if self.schedule_policy == "decode_first": + batch = ( + self.decode_manager.schedule_next_batch() + or self.prefill_manager.schedule_next_batch(self.prefill_budget) + ) + else: # prefill_first + batch = ( + self.prefill_manager.schedule_next_batch(self.prefill_budget) + or self.decode_manager.schedule_next_batch() + ) return self._prepare_batch(batch) if batch else None def _forward(self, forward_input: ForwardInput) -> ForwardOutput: diff --git a/python/minisgl/server/args.py b/python/minisgl/server/args.py index 3ec88f8d..4de756c4 100644 --- a/python/minisgl/server/args.py +++ b/python/minisgl/server/args.py @@ -11,7 +11,7 @@ from minisgl.utils import init_logger -@dataclass(frozen=True) +@dataclass(frozen=True) class ServerArgs(SchedulerConfig): server_host: str = "127.0.0.1" server_port: int = 1919 @@ -194,7 +194,7 @@ def parse_args(args: List[str], run_shell: bool = False) -> Tuple[ServerArgs, bo " the first one is used for prefill and the second one for decode.", ) - parser.add_argument( + parser.add_argument( "--model-source", type=str, default="huggingface", @@ -210,6 +210,15 @@ def parse_args(args: List[str], run_shell: bool = False) -> Tuple[ServerArgs, bo help="The KV cache management strategy.", ) + parser.add_argument( + "--schedule-policy", + type=str, + default=ServerArgs.schedule_policy, + choices=["prefill_first", "decode_first"], + help="The scheduling policy. 'prefill_first' reduces TTFT for online serving, " + "'decode_first' improves throughput for offline inference.", + ) + parser.add_argument( "--moe-backend", default=ServerArgs.moe_backend,