Skip to content
5 changes: 3 additions & 2 deletions benchmark/online/bench_qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,15 @@ async def main():
PORT = 1919
N = 1000
SCALES = [0.4, 0.5, 0.6, 0.7, 0.8, 1.6] # from fast to slow
async with OpenAI(base_url=f"http://127.0.0.1:{PORT}/v1", api_key="") as client:
PROFILE = False
async with OpenAI(base_url=f"http://127.0.0.1:{PORT}/v1", api_key="123") as client:
MODEL = await get_model_name(client)
tokenizer = AutoTokenizer.from_pretrained(MODEL)
TRACES = read_qwen_trace(download_qwen_trace(URL), tokenizer, n=N, dummy=True)
logger.info(f"Start benchmarking with {N} requests using model {MODEL}...")
for scale in SCALES:
traces = scale_traces(TRACES, scale)
results = await benchmark_trace(client, traces, MODEL)
results = await benchmark_trace(client, traces, MODEL, profile=PROFILE)
process_benchmark_results(results)
logger.info("Benchmarking completed.")

Expand Down
12 changes: 10 additions & 2 deletions benchmark/online/bench_simple.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import argparse
import asyncio
import random
import sys
Expand Down Expand Up @@ -34,6 +35,7 @@ async def generate_task(max_bs: int) -> List[str]:
TEST_BS = [64]
PORT = 1919
MAX_INPUT = 8192
PROFILE = False
# Create the async client
async with OpenAI(base_url=f"http://127.0.0.1:{PORT}/v1", api_key="") as client:
MODEL = await get_model_name(client)
Expand All @@ -46,7 +48,9 @@ async def generate_task(max_bs: int) -> List[str]:
try:
gen_task = asyncio.create_task(generate_task(max(TEST_BS)))
test_msg = generate_prompt(tokenizer, 100)
test_result = await benchmark_one(client, test_msg, 2, MODEL, pbar=False)
test_result = await benchmark_one(
client, test_msg, 2, MODEL, pbar=False, profile=PROFILE
)
if len(test_result.tics) <= 2:
logger.info("Server connection test failed")
return
Expand All @@ -64,7 +68,11 @@ async def generate_task(max_bs: int) -> List[str]:
for batch_size in TEST_BS:
try:
results = await benchmark_one_batch(
client, msgs[:batch_size], output_lengths[:batch_size], MODEL
client,
msgs[:batch_size],
output_lengths[:batch_size],
MODEL,
profile=PROFILE,
)
process_benchmark_results(results)
except Exception as e:
Expand Down
14 changes: 13 additions & 1 deletion python/minisgl/benchmark/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ async def benchmark_one(
pbar: Console | bool = True,
extra_body: Dict[str, Any] | None = None,
input_length: int | None = None, # a hack to force input length
profile: bool = False,
) -> RawResult:
if isinstance(pbar, bool):
pbar = make_console(1, output_length, use_pbar=pbar)
Expand All @@ -220,6 +221,8 @@ async def benchmark_one(
if input_length is not None:
kwargs["input_length_override"] = input_length
kwargs.update(extra_body or {}) # can override kwargs
if profile:
kwargs["profile"] = True
response = await client.chat.completions.create(
model=model,
stream=True,
Expand Down Expand Up @@ -257,6 +260,7 @@ async def benchmark_one_batch(
extra_body: Dict[str, Any] | None = None,
input_lengths: List[int | None] | None = None,
pbar: Console | bool = True,
profile: bool = False,
) -> List[RawResult]:
if isinstance(output_lengths, int):
output_lengths = [output_lengths] * len(prompts)
Expand All @@ -275,6 +279,7 @@ async def benchmark_one_batch(
pbar=pbar,
extra_body=extra_body,
input_length=input_length,
profile=profile,
)
for prompt, output_length, input_length in zip(
prompts, output_lengths, input_lengths, strict=True
Expand All @@ -290,6 +295,7 @@ async def benchmark_trace(
model: str,
*,
pbar: Console | bool = True,
profile: bool = False,
) -> List[RawResult]:
if isinstance(pbar, bool):
sum_output_len = sum(msg.output_length for msg in msgs)
Expand All @@ -301,7 +307,13 @@ async def benchmark_timed(msg: BenchmarkTrace):
target = start + msg.timestamp - offset
await asyncio.sleep(max(0, target - time.perf_counter()))
return await benchmark_one(
client, msg.message, msg.output_length, model, pbar=pbar, input_length=msg.input_length
client,
msg.message,
msg.output_length,
model,
pbar=pbar,
input_length=msg.input_length,
profile=profile,
)

tasks = [benchmark_timed(msg) for msg in msgs]
Expand Down
1 change: 1 addition & 0 deletions python/minisgl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class Req:
uid: int
sampling_params: SamplingParams
cache_handle: BaseCacheHandle
profile: bool = False

def __post_init__(self) -> None:
assert self.input_ids.is_cpu
Expand Down
4 changes: 2 additions & 2 deletions python/minisgl/kernel/pynccl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import functools
import os
import pathlib
from typing import TYPE_CHECKING, Any, Literal

from minisgl.env import ENV
Expand All @@ -24,12 +26,10 @@ def get_buffer(self) -> int: ...
else:
PyNCCLCommunicator = Any


@functools.cache
def _load_nccl_module() -> Module:
return load_aot("pynccl", cuda_files=["pynccl.cu"], extra_ldflags=["-lnccl"])


@functools.cache
def _get_pynccl_wrapper_cls():
import tvm_ffi
Expand Down
2 changes: 2 additions & 0 deletions python/minisgl/message/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class UserMsg(BaseBackendMsg):
uid: int
input_ids: torch.Tensor # CPU 1D int32 tensor
sampling_params: SamplingParams
profile: bool = False



@dataclass
Expand Down
2 changes: 2 additions & 0 deletions python/minisgl/message/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ class TokenizeMsg(BaseTokenizerMsg):
uid: int
text: str | List[Dict[str, str]]
sampling_params: SamplingParams
profile: bool = False


@dataclass
class AbortMsg(BaseTokenizerMsg):
uid: int

3 changes: 2 additions & 1 deletion python/minisgl/scheduler/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def _add_one_req(
uid=pending_req.uid,
cache_handle=cache_handle,
sampling_params=pending_req.sampling_params,
profile=pending_req.profile,
)

def try_add_one(self, pending_req: PendingReq) -> Req | None:
Expand Down Expand Up @@ -121,7 +122,7 @@ class PrefillManager:
pending_list: List[PendingReq] = field(default_factory=list)

def add_one_req(self, req: UserMsg) -> None:
self.pending_list.append(PendingReq(req.uid, req.input_ids, req.sampling_params))
self.pending_list.append(PendingReq(req.uid, req.input_ids, req.sampling_params, req.profile))

def schedule_next_batch(self, prefill_budget: int) -> Batch | None:
if len(self.pending_list) == 0:
Expand Down
43 changes: 41 additions & 2 deletions python/minisgl/scheduler/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, List, NamedTuple, NoReturn, Set, Tuple, TypeAlias

import torch
Expand Down Expand Up @@ -70,11 +71,38 @@ def __init__(self, config: SchedulerConfig):
self.eos_token_id = self.tokenizer.eos_token_id
self.token_pool = self.table_manager.token_pool
self.prefill_budget = config.max_extend_tokens
# self.config = config

# Initialize the I/O mixin
self._active_profile_uid: int | None = None
self._active_profiler = None
super().__init__(config, self.engine.tp_cpu_group)

def _maybe_start_profiler(self, batch: Batch) -> None:
if self._active_profile_uid is not None:
return
profiled_uids = [r.uid for r in batch.reqs if getattr(r, "profile", False)]
if not profiled_uids:
return

from minisgl.utils.profiler import RequestProfiler

uid = profiled_uids[0]
out_dir = os.environ.get("MINISGL_PROFILE_DIR", "/tmp")
self._active_profile_uid = uid
self._active_profiler = RequestProfiler(uid=uid, out_dir=out_dir)
self._active_profiler.start()
logger.warning_rank0("Torch profiler enabled for uid=%s", uid)

def _maybe_stop_profiler(self, finished_uid: int) -> None:
if self._active_profile_uid != finished_uid or self._active_profiler is None:
return
path = self._active_profiler.stop_and_export()
if path is None:
logger.error_rank0("Torch profiler export failed for uid=%s", finished_uid)
else:
logger.warning_rank0("Torch profiler trace written to %s", path)
self._active_profile_uid = None
self._active_profiler = None

def run_when_idle(self) -> None:
"""Called when the scheduler is idle to perform background tasks."""
logger.info_rank0("Scheduler is idle, waiting for new reqs...")
Expand Down Expand Up @@ -131,6 +159,14 @@ def run_forever(self) -> NoReturn:
data = self.overlap_loop(data)

def shutdown(self) -> None:
if self._active_profiler is not None:
path = self._active_profiler.stop_and_export()
if path is None:
logger.error_rank0("Torch profiler export failed during shutdown")
else:
logger.warning_rank0("Torch profiler trace written to %s during shutdown", path)
self._active_profile_uid = None
self._active_profiler = None
torch.cuda.synchronize(self.device)
self.sync_all_ranks()
self.engine.shutdown()
Expand Down Expand Up @@ -158,6 +194,7 @@ def _process_last_data(self, last_data: ForwardData | None) -> None:
# NOTE: overlap scheduling may make the request freed twice, skip second free
if finished and req not in self.finished_reqs:
self.decode_manager.remove_req(req)
self._maybe_stop_profiler(req.uid)
self._free_req_resources(req)
new_finished_reqs.add(req)
elif batch.is_prefill: # for prefill, non-chunk req, cache the prefix
Expand Down Expand Up @@ -192,6 +229,7 @@ def _process_one_msg(self, msg: BaseBackendMsg) -> None:
req_to_free = self.prefill_manager.abort_req(msg.uid)
req_to_free = req_to_free or self.decode_manager.abort_req(msg.uid)
if req_to_free is not None:
self._maybe_stop_profiler(req_to_free.uid)
self._free_req_resources(req_to_free)
else:
logger.error(f"Unknown message type: {type(msg)}")
Expand Down Expand Up @@ -225,6 +263,7 @@ def _schedule_next_batch(self) -> ForwardInput | None:
return self._prepare_batch(batch) if batch else None

def _forward(self, forward_input: ForwardInput) -> ForwardOutput:
self._maybe_start_profiler(forward_input.batch)
batch, sample_args, input_mapping, output_mapping = forward_input
batch.input_ids = self.token_pool[input_mapping]
forward_output = self.engine.forward_batch(batch, sample_args)
Expand Down
1 change: 1 addition & 0 deletions python/minisgl/scheduler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class PendingReq:
uid: int
input_ids: torch.Tensor
sampling_params: SamplingParams
profile: bool = False
chunked_req: ChunkedReq | None = None

@property
Expand Down
5 changes: 5 additions & 0 deletions python/minisgl/server/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class GenerateRequest(BaseModel):
prompt: str
max_tokens: int
ignore_eos: bool = False
profile: bool = False


class Message(BaseModel):
Expand Down Expand Up @@ -82,6 +83,7 @@ class OpenAICompletionRequest(BaseModel):
frequency_penalty: float = 0.0

ignore_eos: bool = False
profile: bool = False


class ModelCard(BaseModel):
Expand Down Expand Up @@ -239,6 +241,7 @@ async def generate(req: GenerateRequest, request: Request):
ignore_eos=req.ignore_eos,
max_tokens=req.max_tokens,
),
profile=req.profile,
)
)

Expand Down Expand Up @@ -275,6 +278,7 @@ async def v1_completions(req: OpenAICompletionRequest, request: Request):
top_k=req.top_k,
top_p=req.top_p,
),
profile=req.profile,
)
)

Expand Down Expand Up @@ -308,6 +312,7 @@ async def shell_completion(req: OpenAICompletionRequest):
top_k=req.top_k,
top_p=req.top_p,
),
profile=req.profile,
)
)

Expand Down
1 change: 1 addition & 0 deletions python/minisgl/tokenizer/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def tokenize_worker(
uid=msg.uid,
input_ids=t,
sampling_params=msg.sampling_params,
profile=getattr(msg, "profile", False),
)
for msg, t in zip(tokenize_msg, tensors, strict=True)
]
Expand Down
Loading