Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/minisgl/scheduler/config.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Literal

from minisgl.engine import EngineConfig

SchedulePolicy = Literal["prefill_first", "decode_first"]


def _get_pid_suffix() -> str:
import os

return f".pid={os.getpid()}"
return f".pid={os.getpid()}"


@dataclass(frozen=True)
class SchedulerConfig(EngineConfig):
max_extend_tokens: int = 8192
cache_type: str = "radix"
offline_mode: bool = False
schedule_policy: SchedulePolicy = "prefill_first"

# networking config
_unique_suffix: str = field(default_factory=_get_pid_suffix)
Expand Down
18 changes: 12 additions & 6 deletions python/minisgl/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .config import SchedulerConfig
from .decode import DecodeManager
from .io import SchedulerIOMixin
from .prefill import ChunkedReq, PrefillManager
from .prefill import ChunkedReq, PrefillManager
from .table import TableManager

if TYPE_CHECKING:
Expand Down Expand Up @@ -68,6 +68,7 @@ def __init__(self, config: SchedulerConfig):
self.finished_reqs: Set[Req] = set()
self.tokenizer = load_tokenizer(config.model_path)
self.eos_token_id = self.tokenizer.eos_token_id
self.schedule_policy = config.schedule_policy
self.token_pool = self.table_manager.token_pool
self.prefill_budget = config.max_extend_tokens
# self.config = config
Expand Down Expand Up @@ -217,11 +218,16 @@ def _prepare_batch(self, batch: Batch) -> ForwardInput:
)

def _schedule_next_batch(self) -> ForwardInput | None:
# TODO: support other policies: e.g. DECODE first
batch = (
self.prefill_manager.schedule_next_batch(self.prefill_budget)
or self.decode_manager.schedule_next_batch()
)
if self.schedule_policy == "decode_first":
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A decode first policy should be:

  1. Form a decode batch first.
  2. Try to schedule a prefill batch with the remaining token budget (prefill_budget - decode_tokens).
    This is actually mix prefill-decode style batching.

batch = (
self.decode_manager.schedule_next_batch()
or self.prefill_manager.schedule_next_batch(self.prefill_budget)
)
else: # prefill_first
batch = (
self.prefill_manager.schedule_next_batch(self.prefill_budget)
or self.decode_manager.schedule_next_batch()
)
return self._prepare_batch(batch) if batch else None

def _forward(self, forward_input: ForwardInput) -> ForwardOutput:
Expand Down
13 changes: 11 additions & 2 deletions python/minisgl/server/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from minisgl.utils import init_logger


@dataclass(frozen=True)
@dataclass(frozen=True)
class ServerArgs(SchedulerConfig):
server_host: str = "127.0.0.1"
server_port: int = 1919
Expand Down Expand Up @@ -194,7 +194,7 @@ def parse_args(args: List[str], run_shell: bool = False) -> Tuple[ServerArgs, bo
" the first one is used for prefill and the second one for decode.",
)

parser.add_argument(
parser.add_argument(
"--model-source",
type=str,
default="huggingface",
Expand All @@ -210,6 +210,15 @@ def parse_args(args: List[str], run_shell: bool = False) -> Tuple[ServerArgs, bo
help="The KV cache management strategy.",
)

parser.add_argument(
"--schedule-policy",
type=str,
default=ServerArgs.schedule_policy,
choices=["prefill_first", "decode_first"],
help="The scheduling policy. 'prefill_first' reduces TTFT for online serving, "
"'decode_first' improves throughput for offline inference.",
)

parser.add_argument(
"--moe-backend",
default=ServerArgs.moe_backend,
Expand Down